diff --git a/src/math/pow.cpp b/src/math/pow.cpp index 399f62aef..4478e01d4 100644 --- a/src/math/pow.cpp +++ b/src/math/pow.cpp @@ -496,9 +496,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp // Fetch the pow eval algo. const auto pea = get_pow_eval_algo(f); + // Codegen the exponent. + auto *expo = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size); + + // Fetch the internal vector type. + auto *vec_t = make_vector_type(fp_t, batch_size); + if (order == 0u) { - return pea.eval_f( - s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size)}); + return pea.eval_f(s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), expo}); } // Special case for sqrt(). @@ -514,7 +519,6 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp } // The general case. - auto &builder = s.builder(); // NOTE: iteration in the [0, order) range // (i.e., order *not* included). @@ -526,21 +530,17 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp // Compute the scalar factor: order * num - j * (num + 1). auto scal_f = [&]() -> llvm::Value * { if constexpr (std::is_same_v) { - return vector_splat( - builder, - llvm_codegen(s, fp_t, - number_like(s, fp_t, static_cast(order)) * num - - number_like(s, fp_t, static_cast(j)) * (num + number_like(s, fp_t, 1.))), - batch_size); + return llvm_codegen(s, vec_t, + number_like(s, fp_t, static_cast(order)) * num + - number_like(s, fp_t, static_cast(j)) + * (num + number_like(s, fp_t, 1.))); } else { - auto pc = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size); - auto *jvec = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast(j))), batch_size); - auto *ordvec - = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast(order))), batch_size); - auto *onevec = vector_splat(builder, llvm_codegen(s, fp_t, number(1.)), batch_size); + auto *jvec = llvm_codegen(s, vec_t, number(static_cast(j))); + auto *ordvec = llvm_codegen(s, vec_t, number(static_cast(order))); + auto *onevec = llvm_codegen(s, vec_t, number(1.)); - auto tmp1 = llvm_fmul(s, ordvec, pc); - auto tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, pc, onevec)); + auto tmp1 = llvm_fmul(s, ordvec, expo); + auto tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, expo, onevec)); return llvm_fsub(s, tmp1, tmp2); } @@ -554,12 +554,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp auto *ret_acc = pairwise_sum(s, sum); // Compute the final divisor: order * (zero-th derivative of u_idx). - auto *ord_f = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast(order))), batch_size); + auto *ord_f = llvm_codegen(s, vec_t, number(static_cast(order))); auto *b0 = taylor_fetch_diff(arr, u_idx, 0, n_uvars); auto *div = llvm_fmul(s, ord_f, b0); // Compute and return the result: ret_acc / div. - return llvm_fdiv(s, ret_acc, div); + auto *ret = llvm_fdiv(s, ret_acc, div); + + return ret; } // All the other cases.