Skip to content

Commit

Permalink
A few initial simplifications.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Sep 15, 2024
1 parent 937b884 commit 7675053
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/math/pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
// Fetch the pow eval algo.
const auto pea = get_pow_eval_algo(f);

// Codegen the exponent.
auto *expo = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);

// Fetch the internal vector type.
auto *vec_t = make_vector_type(fp_t, batch_size);

if (order == 0u) {
return pea.eval_f(
s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size)});
return pea.eval_f(s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), expo});
}

// Special case for sqrt().
Expand All @@ -514,7 +519,6 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
}

// The general case.
auto &builder = s.builder();

// NOTE: iteration in the [0, order) range
// (i.e., order *not* included).
Expand All @@ -526,21 +530,17 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
// Compute the scalar factor: order * num - j * (num + 1).
auto scal_f = [&]() -> llvm::Value * {
if constexpr (std::is_same_v<U, number>) {
return vector_splat(
builder,
llvm_codegen(s, fp_t,
number_like(s, fp_t, static_cast<double>(order)) * num
- number_like(s, fp_t, static_cast<double>(j)) * (num + number_like(s, fp_t, 1.))),
batch_size);
return llvm_codegen(s, vec_t,
number_like(s, fp_t, static_cast<double>(order)) * num
- number_like(s, fp_t, static_cast<double>(j))
* (num + number_like(s, fp_t, 1.)));
} else {
auto pc = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);
auto *jvec = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(j))), batch_size);
auto *ordvec
= vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(order))), batch_size);
auto *onevec = vector_splat(builder, llvm_codegen(s, fp_t, number(1.)), batch_size);
auto *jvec = llvm_codegen(s, vec_t, number(static_cast<double>(j)));
auto *ordvec = llvm_codegen(s, vec_t, number(static_cast<double>(order)));
auto *onevec = llvm_codegen(s, vec_t, number(1.));

auto tmp1 = llvm_fmul(s, ordvec, pc);
auto tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, pc, onevec));
auto tmp1 = llvm_fmul(s, ordvec, expo);
auto tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, expo, onevec));

return llvm_fsub(s, tmp1, tmp2);
}
Expand All @@ -554,12 +554,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
auto *ret_acc = pairwise_sum(s, sum);

// Compute the final divisor: order * (zero-th derivative of u_idx).
auto *ord_f = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(order))), batch_size);
auto *ord_f = llvm_codegen(s, vec_t, number(static_cast<double>(order)));
auto *b0 = taylor_fetch_diff(arr, u_idx, 0, n_uvars);
auto *div = llvm_fmul(s, ord_f, b0);

// Compute and return the result: ret_acc / div.
return llvm_fdiv(s, ret_acc, div);
auto *ret = llvm_fdiv(s, ret_acc, div);

return ret;
}

// All the other cases.
Expand Down

0 comments on commit 7675053

Please sign in to comment.