Skip to content

Commit

Permalink
Remove excessive FD and request parameters from DeriveVectorMode (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored Nov 13, 2024
1 parent d3292eb commit e494c10
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 13 deletions.
5 changes: 1 addition & 4 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,10 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
///\brief Produces the first derivative of a given function with
/// respect to multiple parameters.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns The differentiated and potentially created enclosing
/// context.
///
DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DeriveVectorMode();

/// Builds an overload for the vector mode function that has derived params
/// for all the arguments of the requested function and it calls the original
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
result = V.DerivePushforward();
} else if (request.Mode == DiffMode::vector_forward_mode) {
VectorForwardModeVisitor V(*this, request);
result = V.DeriveVectorMode(FD, request);
result = V.DeriveVectorMode();
} else if (request.Mode == DiffMode::experimental_vector_pushforward) {
VectorPushForwardModeVisitor V(*this, request);
result = V.DerivePushforward();
Expand Down
12 changes: 4 additions & 8 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,16 @@ void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) {
m_IndVarCountExpr = IndVarCountExpr;
}

DerivativeAndOverload
VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request);
DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::vector_forward_mode);

DiffParams args{};
DiffInputVarsInfo DVI;
DVI = request.DVI;
for (auto dParam : DVI)
for (const auto& dParam : m_DiffReq.DVI)
args.push_back(dParam.param);

// Generate name for the derivative function.
std::string derivedFnName = request.BaseFunctionName + "_dvec";
std::string derivedFnName = m_DiffReq.BaseFunctionName + "_dvec";
if (args.size() != FD->getNumParams()) {
for (auto arg : args) {
auto it = std::find(FD->param_begin(), FD->param_end(), arg);
Expand Down

0 comments on commit e494c10

Please sign in to comment.