diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index fa14e5629..f1835844e 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -148,6 +148,8 @@ namespace clad { clang::StringLiteral* CreateStringLiteral(clang::ASTContext& C, llvm::StringRef str); + bool isLambdaQType(clang::QualType QT); + /// Returns true if `QT` is Array or Pointer Type, otherwise returns false. bool isArrayOrPointerType(clang::QualType QT); diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index efd3d629c..1fb98efba 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -403,6 +403,16 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy, #define CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam clang::ParsedAttributesView::none(), #endif +#if CLANG_VERSION_MAJOR > 12 +#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \ + LAMBDACXXRECORDDECL) \ + LAMBDACXXRECORDDECL->getLambdaDependencyKind() +#else +#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \ + LAMBDACXXRECORDDECL) \ + LAMBDACXXRECORDDECL->isDependentLambda() +#endif + // Clang 12 add one extra param (FPO) that we get from Node in Create method of: // ImplicitCastExpr, CStyleCastExpr, CXXStaticCastExpr and CXXFunctionalCastExpr diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a73513f94..385d60ba3 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -357,6 +357,11 @@ namespace clad { llvm::SmallVectorImpl& outputArgs, clang::Expr* CUDAExecConfig = nullptr); + clang::CXXRecordDecl* + diffLambdaCXXRecordDecl(const clang::CXXRecordDecl* Original); + clang::CXXMethodDecl* + DifferentiateCallOperatorIfLambda(const clang::CXXRecordDecl* RD); + public: ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request); virtual ~ReverseModeVisitor(); @@ -383,6 +388,7 @@ namespace clad { DerivativeAndOverload DerivePullback(); StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); + StmtDiff VisitLambdaExpr(const clang::LambdaExpr* LE); StmtDiff VisitCallExpr(const clang::CallExpr* CE); virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 210f82112..7454a5e06 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -293,7 +293,7 @@ namespace clad { clang::Scope* scope, clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false); /// Builds variable declaration to be used inside the derivative /// body. /// \param[in] Type The type of variable declaration to build. @@ -310,7 +310,7 @@ namespace clad { clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false); /// Builds variable declaration to be used inside the derivative /// body. /// \param[in] Type The type of variable declaration to build. @@ -326,7 +326,7 @@ namespace clad { clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false); /// Builds variable declaration to be used inside the derivative /// body in the derivative function global scope. clang::VarDecl* diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 350eeea07..cd40405d4 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -304,6 +304,13 @@ namespace clad { return false; } + bool isLambdaQType(QualType QT) { + if (const RecordType* RT = QT->getAs()) + if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) + return RD->isLambda(); + return false; + } + bool IsReferenceOrPointerArg(const Expr* arg) { // The argument is passed by reference if it's passed as an L-value. // However, if arg is a MaterializeTemporaryExpr, then arg is a diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 026164498..c7739e1fb 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -26,6 +26,7 @@ #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" +#include "clang/Sema/ScopeInfo.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaInternal.h" #include "clang/Sema/Template.h" @@ -1667,6 +1668,288 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(FL)); } + CXXMethodDecl* ReverseModeVisitor::DifferentiateCallOperatorIfLambda( + const clang::CXXRecordDecl* RD) { + if (RD) { + CXXRecordDecl* constructedType = RD->getDefinition(); + bool isLambda = constructedType->isLambda(); + if (isLambda) { + for (const auto* method : constructedType->methods()) { + if (const auto* cxxMethod = dyn_cast(method)) { + if (cxxMethod->isOverloadedOperator() && + cxxMethod->getOverloadedOperator() == OO_Call) { + + DiffRequest req; + req.Function = cxxMethod; + req.Mode = DiffMode::experimental_pullback; + req.BaseFunctionName = utils::ComputeEffectiveFnName(cxxMethod); + // Silence diag outputs in nested derivation process. + req.VerboseDiags = false; + + return dyn_cast(m_Builder.Derive(req).derivative); + } + } + } + } + } + return nullptr; + } + + CXXRecordDecl* + ReverseModeVisitor::diffLambdaCXXRecordDecl(const CXXRecordDecl* Original) { + // Create a new Lambda CXXRecordDecl that is going to represent a pullback + CXXRecordDecl* Cloned = CXXRecordDecl::CreateLambda( + m_Context, const_cast(Original->getDeclContext()), + Original->getLambdaTypeInfo(), Original->getBeginLoc(), + CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind(Original), + Original->isGenericLambda(), LCD_ByRef); + + // Copy the fields if any (FieldDecl) + for (auto* Field : Original->fields()) { + FieldDecl* NewField = FieldDecl::Create( + m_Context, // AST context + Cloned, // Owning class (Cloned CXXRecordDecl) + Field->getBeginLoc(), // Start location of field + Field->getLocation(), // End location of field + Field->getIdentifier(), // Field's name + Field->getType(), // Field's type + Field->getTypeSourceInfo(), // Type source info + Field->getBitWidth(), // Bit width (Expr*), nullptr if not a bitfield + Field->isMutable(), // Is the field mutable? + Field->getInClassInitStyle() // In-class initialization style + ); + + NewField->setAccess( + Field + ->getAccess()); // Set access specifier (public/private/protected) + Cloned->addDecl( + NewField); // Add the new field to the cloned CXXRecordDecl + } + + // Create operator() as a pullback + for (auto* Method : Original->methods()) { + if (CXXMethodDecl* OriginalOpCall = dyn_cast(Method)) { + if (OriginalOpCall->getOverloadedOperator() == OO_Call) { + auto* diffedOpCall = DifferentiateCallOperatorIfLambda(Original); + if (diffedOpCall) { + diffedOpCall->setAccess(OriginalOpCall->getAccess()); + // Cloned->addDecl(diffedOpCall); + + CXXMethodDecl* ClonedOpCall = CXXMethodDecl::Create( + m_Context, Cloned, diffedOpCall->getBeginLoc(), + OriginalOpCall->getNameInfo(), + diffedOpCall + ->getType(), // Function type (return type + parameters) + diffedOpCall->getTypeSourceInfo(), + diffedOpCall->getStorageClass() + CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(diffedOpCall), + diffedOpCall->isInlineSpecified(), // Inline specifier + clad_compat::Function_GetConstexprKind( + diffedOpCall), // Constexpr specifier + diffedOpCall->getEndLoc() //, + // diffedOpCall->getTrailingRequiresClause() + ); + + llvm::SmallVector params; + for (unsigned i = 0; i < diffedOpCall->param_size(); ++i) { + ParmVarDecl* p = diffedOpCall->getParamDecl(i); + ParmVarDecl* NewParam = ParmVarDecl::Create( + m_Context, ClonedOpCall, p->getBeginLoc(), p->getLocation(), + p->getIdentifier(), p->getType(), p->getTypeSourceInfo(), + p->getStorageClass(), p->getDefaultArg()); + params.push_back(NewParam); + } + ClonedOpCall->setParams(params); + + // Copy the method body if it exists + if (diffedOpCall->hasBody()) { + Stmt* body = diffedOpCall->getBody(); + Stmt* ClonedBody = Clone(body); + ClonedOpCall->setBody(ClonedBody); + } + + ClonedOpCall->setAccess(OriginalOpCall->getAccess()); + Cloned->addDecl(ClonedOpCall); + + break; // we get into an infinite loop otherwise + } + } + } + } + + // Step 4: Finish defining the class + Cloned->completeDefinition(); + + return Cloned; + } + + StmtDiff ReverseModeVisitor::VisitLambdaExpr(const clang::LambdaExpr* LE) { + // ============== CAP + + auto children_iterator_range = LE->children(); + std::vector children_Exp; + std::vector children_Exp_dx; + + for (auto children : children_iterator_range) { + // auto children_expr = const_cast(dyn_cast(children)); + auto children_expr = dyn_cast(children); + if (children_expr) { + children_Exp.push_back(dyn_cast(Clone(children_expr))); + + // children_Exp_dx.push_back(children_expr); + + // if(isa(children_expr)) { + // std::string constructedTypeName = QualType::getAsString(dyn_cast(children_expr)->getType().split(), PrintingPolicy{ {} }); + // // if (!utils::IsKokkosTeamPolicy(constructedTypeName) && !utils::IsKokkosRange(constructedTypeName) && !utils::IsKokkosMember(constructedTypeName)) { + // auto children_exprV = Visit(children_expr); + // auto children_expr_copy = dyn_cast(Clone(children_expr)); + // children_expr_copy->setArg(0, children_exprV.getExpr_dx()); + // children_Exp_dx.push_back(children_expr_copy); + // // } + // } + // else if(isa(children_expr)) { + + // } + // else { + // auto children_exprV = Visit(children_expr); + // if (children_exprV.getExpr_dx()) { + // children_Exp_dx.push_back(children_exprV.getExpr_dx()); + // } + // } + } + } + + llvm::ArrayRef childrenRef_Exp = + clad_compat::makeArrayRef(children_Exp.data(), children_Exp.size()); + + llvm::ArrayRef childrenRef_Exp_dx; // = + // clad_compat::makeArrayRef(children_Exp_dx.data(), children_Exp_dx.size()); + + // ============== CAP + + // FIXME: ideally, we need to create a reverse_forw lambda and not copy the original one for the forward pass. + auto forwardLambdaClass = LE->getLambdaClass(); + + clang::LambdaIntroducer cloneIntro; + cloneIntro.Default = forwardLambdaClass->getLambdaCaptureDefault(); + cloneIntro.Range.setBegin(LE->getBeginLoc()); + cloneIntro.Range.setEnd(LE->getEndLoc()); + + clang::AttributeFactory cloneAttrFactory; + const clang::DeclSpec cloneDS(cloneAttrFactory); + clang::Declarator cloneD( + cloneDS, CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam + CLAD_COMPAT_CLANG12_Declarator_LambdaExpr); + clang::sema::LambdaScopeInfo* cloneLSI = m_Sema.PushLambdaScope(); + beginScope(clang::Scope::BlockScope | clang::Scope::FnScope | + clang::Scope::DeclScope); + m_Sema.ActOnStartOfLambdaDefinition( + cloneIntro, cloneD, + clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec( + getCurrentScope(), cloneDS)); + + cloneLSI->CallOperator = forwardLambdaClass->getLambdaCallOperator(); + + m_Sema.buildLambdaScope(cloneLSI, + cloneLSI->CallOperator, + LE->getIntroducerRange(), + LE->getCaptureDefault(), + LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), + LE->hasExplicitResultType(), + true); + + auto forwardLE = LambdaExpr::Create(m_Context, + forwardLambdaClass, + LE->getIntroducerRange(), + LE->getCaptureDefault(), + LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), + LE->hasExplicitResultType(), + childrenRef_Exp, + LE->getEndLoc(), + false); + + clang::LambdaExpr* reverseLE = nullptr; + CXXRecordDecl* diffedCXXRec = diffLambdaCXXRecordDecl(forwardLambdaClass); + + endScope(); + + clang::LambdaIntroducer Intro; + Intro.Default = forwardLambdaClass->getLambdaCaptureDefault(); + Intro.Range.setBegin(LE->getBeginLoc()); + Intro.Range.setEnd(LE->getEndLoc()); + + clang::AttributeFactory AttrFactory; + const clang::DeclSpec DS(AttrFactory); + clang::Declarator D( + DS, CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam + CLAD_COMPAT_CLANG12_Declarator_LambdaExpr); + clang::sema::LambdaScopeInfo* LSI = m_Sema.PushLambdaScope(); + beginScope(clang::Scope::BlockScope | clang::Scope::FnScope | + clang::Scope::DeclScope); + m_Sema.ActOnStartOfLambdaDefinition( + Intro, D, + clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec( + getCurrentScope(), DS)); + + LSI->CallOperator = diffedCXXRec->getLambdaCallOperator(); + + // ============== CAP + + std::vector children_LC_Exp_dx; + + for (auto children_expr : children_Exp_dx) { + if(isa(children_expr)) { + + auto tmp = dyn_cast(children_expr)->getArg(0)->IgnoreImpCasts(); + + if (isa(tmp)) { + auto VD = dyn_cast(dyn_cast(tmp)->getDecl()); + children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); + } + if(isa(tmp)) { + auto PE = dyn_cast(tmp); + auto OCE = dyn_cast(PE->getSubExpr()); + + auto VD = dyn_cast(dyn_cast(OCE->getArg(0))->getDecl()); + children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); + } + } + if (isa(children_expr)) { + auto VD = dyn_cast(dyn_cast(children_expr)->getDecl()); + children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); + } + } + // assert(children_Exp_dx.size() == children_LC_Exp_dx.size() && "Wrong number of captures"); + + llvm::ArrayRef childrenRef_LC_Exp_dx;// = + // clad_compat::makeArrayRef(children_LC_Exp_dx.data(), children_LC_Exp_dx.size()); + + // diffedCXXRec->setCaptures(m_Context, childrenRef_LC_Exp_dx); + + // ============== CAP + + m_Sema.buildLambdaScope(LSI, + LSI->CallOperator, + LE->getIntroducerRange(), + LCD_ByRef, + LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), + LE->hasExplicitResultType(), + true); + + reverseLE = LambdaExpr::Create( + m_Context, diffedCXXRec, LE->getIntroducerRange(), + LCD_ByRef, LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), LE->hasExplicitResultType(), + childrenRef_Exp_dx, LE->getEndLoc(), false); + + endScope(); + + return {forwardLE, reverseLE}; + } + StmtDiff ReverseModeVisitor::VisitCallExpr(const CallExpr* CE) { const FunctionDecl* FD = CE->getDirectCallee(); if (!FD) { @@ -1883,6 +2166,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); const auto* MD = dyn_cast(FD); + bool isLambda = (MD ? isLambdaCallOperator(MD) : false); // Method operators have a base like methods do but it's included in the // call arguments so we have to shift the indexing of call arguments. bool isMethodOperatorCall = MD && isa(CE); @@ -1890,17 +2174,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (std::size_t i = static_cast(isMethodOperatorCall), e = CE->getNumArgs(); i != e; ++i) { + llvm::errs() << "i: " << i << '\n'; const Expr* arg = CE->getArg(i); const auto* PVD = FD->getParamDecl( i - static_cast(isMethodOperatorCall)); StmtDiff argDiff{}; + + bool isArgLambda = clad::utils::isLambdaQType(arg->getType()); // is this argument a lambda? + // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerArg(arg) || + if (utils::IsReferenceOrPointerArg(arg)|| !m_DiffReq.shouldHaveAdjoint(PVD)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); + } else if (isArgLambda) { + // TODO: this block is now the same as the one above, but we might want to actually save the differentiated lambda into a declaration first here. This way we wouldn't create new lambdas for the derivative every time the user passes the same lambda as an argument. + argDiff = Visit(arg); + CallArgDx.push_back(argDiff.getExpr_dx()); } else { // Create temporary variables corresponding to derivative of each // argument, so that they can be referred to when arguments is visited. @@ -1909,7 +2201,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // same as the call expression as it is the type used to declare the // _gradX array QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema); - VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy)); + VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, isLambda); PreCallStmts.push_back(BuildDeclStmt(dArgDecl)); CallArgDx.push_back(BuildDeclRef(dArgDecl)); // Visit using uninitialized reference. @@ -2011,13 +2303,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. if (MD) { - if (isLambdaCallOperator(MD)) { - QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( - FD->getDeclContext()->getOuterLexicalRecordContext())); - baseDiff = - StmtDiff(Clone(dyn_cast(CE)->getArg(0)), - new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc)); - } else if (MD->isInstance()) { + if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); @@ -2026,21 +2312,33 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseDiff = Visit(baseOriginalE); baseExpr = baseDiff.getExpr(); - Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); - baseDiff.updateStmt(baseDiffStore); + if (!isLambda) { + Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); + baseDiff.updateStmt(baseDiffStore); + } + + llvm::errs() << "diff base: "; + baseExpr->dumpPretty(m_Context); + llvm::errs() << " "; + baseDiff.getExpr_dx()->dumpPretty(m_Context); + llvm::errs() << " "; + llvm::errs() << "\n"; + Expr* baseDerivative = baseDiff.getExpr_dx(); if (!baseDerivative->getType()->isPointerType()) baseDerivative = BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative); - DerivedCallOutputArgs.push_back(baseDerivative); + if (!isLambda) + DerivedCallOutputArgs.push_back(baseDerivative); } } for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; + llvm::errs() << "i: " << idx << '\n'; QualType paramTy = FD->getParamDecl(idx)->getType(); if (!argDerivative || utils::isArrayOrPointerType(paramTy) || - isCladArrayType(argDerivative->getType())) + isCladArrayType(argDerivative->getType()) || clad::utils::isLambdaQType(paramTy)) gradArgExpr = argDerivative; else gradArgExpr = @@ -2108,12 +2406,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) .get(); - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, Loc, - pullbackCallArgs, Loc, CUDAExecConfig) - .get(); - } else { + OverloadedDerivedFn = m_Sema + .ActOnCallExpr(getCurrentScope(), selfRef, + Loc, pullbackCallArgs, Loc) + .get(); + } else if (!isLambda) { if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( pullbackCallArgs, PreCallStmts, dfdx()); @@ -2136,10 +2433,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (MD && isLambdaCallOperator(MD)) { - if (const auto* paramDecl = FD->getParamDecl(i)) - pullbackRequest.DVI.push_back(paramDecl); - } else if (DerivedCallOutputArgs[i + isaMethod]) + if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; @@ -2196,7 +2490,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - if (OverloadedDerivedFn) { + if (isLambda) { + Stmts& block = getCurrentBlock(direction::reverse); + Stmts::iterator it = std::begin(block) + insertionPoint; + // Insert PreCallStmts + it = block.insert(it, PreCallStmts.begin(), PreCallStmts.end()); + it += PreCallStmts.size(); + // Insert the call + Expr* baseEdx = baseDiff.getExpr_dx(); // The pullback lambda + const CXXRecordDecl* EdxRD = baseEdx->getType()->getAsCXXRecordDecl(); + auto* CMD = const_cast(EdxRD->getLambdaCallOperator()); + NestedNameSpecifierLoc NNS(CMD->getQualifier(), + /*Data=*/nullptr); + auto DAP = DeclAccessPair::make(CMD, CMD->getAccess()); + auto* memberExpr = MemberExpr::Create( + m_Context, Clone(baseEdx), /*isArrow=*/false, Loc, NNS, noLoc, + CMD, DAP, CMD->getNameInfo(), + /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + NOUR_None)); + OverloadedDerivedFn = m_Sema + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, + pullbackCallArgs, Loc) + .get(); + + // OverloadedDerivedFn = BuildCallExprToMemFn( + // baseEdx, FD->getName(), pullbackCallArgs, Loc); + it = block.insert(it, OverloadedDerivedFn); + it++; + } else if (OverloadedDerivedFn) { // Derivative was found. FunctionDecl* fnDecl = dyn_cast(OverloadedDerivedFn) ->getDirectCallee(); @@ -2991,6 +3314,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } + // Lambda function declaractions should be of auto type; + bool isLambda = false; + if (const RecordType* RT = VDType->getAs()) { + if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) + isLambda = RD->isLambda(); + } + if (isLambda) + VDCloneType = VDDerivedType = m_Context.getAutoDeductType(); + // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) isInitializedByNewExpr = true; @@ -3013,7 +3345,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } VDDerived = BuildGlobalVarDecl( VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, - nullptr, VarDecl::InitializationStyle::CInit); + m_Context.getTrivialTypeSourceInfo(VDDerivedType), + VarDecl::InitializationStyle::CInit); } else { // If VD is a reference to a local variable, then the initial value is set // to the derived variable of the corresponding local variable. @@ -3023,8 +3356,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // `VDDerivedType` is the corresponding non-reference type and the initial // value is set to 0. // Otherwise, for non-reference types, the initial value is set to 0. - if (!VDDerivedInit) + if (!(VDDerivedInit || isLambda)) { VDDerivedInit = getZeroInit(VDType); + } else if (isLambda) { + if (const Expr* init = VD->getInit()) { + initDiff = Visit(init); + VDDerivedInit = initDiff.getExpr_dx(); + } /* else ==> invalid lambda */ + } // `specialThisDiffCase` is only required for correctly differentiating // the following code: @@ -3104,10 +3443,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDDerivedInit = getZeroInit(VDDerivedType); } } - if (initializeDerivedVar) + if (initializeDerivedVar) { VDDerived = BuildGlobalVarDecl( VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, - nullptr, VD->getInitStyle()); + m_Context.getTrivialTypeSourceInfo(VDDerivedType), + VD->getInitStyle()); + } } if (!m_DiffReq.shouldHaveAdjoint((VD))) @@ -3194,9 +3535,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()), VD->isDirectInit()); else - VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(), - initDiff.getExpr(), VD->isDirectInit(), - nullptr, VD->getInitStyle()); + VDClone = BuildGlobalVarDecl( + VDCloneType, VD->getNameAsString(), initDiff.getExpr(), + VD->isDirectInit(), m_Context.getTrivialTypeSourceInfo(VDCloneType), + VD->getInitStyle()); if (isPointerType && derivedVDE) { if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, @@ -3316,8 +3658,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass; // If the DeclStmt is not empty, check the first declaration in case it is a - // lambda function. This case it is treated separately for now and we don't - // create a variable for its derivative. + // lambda function. This case it is treated differently. bool isLambda = false; const auto* declsBegin = DS->decls().begin(); if (declsBegin != DS->decls().end() && isa(*declsBegin)) { @@ -3327,12 +3668,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, QT = QT->getPointeeType(); auto* typeDecl = QT->getAsCXXRecordDecl(); - // We should also simply copy the original lambda. The differentiation - // of lambdas is happening in the `VisitCallExpr`. For now, only the - // declarations with lambda expressions without captures are supported. isLambda = typeDecl && typeDecl->isLambda(); - if (isLambda || - (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) { + if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) { for (auto* D : DS->decls()) if (auto* VD = dyn_cast(D)) decls.push_back(VD); @@ -3350,7 +3687,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (auto* VD = dyn_cast(D)) { DeclDiff VDDiff; - if (!isLambda) VDDiff = DifferentiateVarDecl(VD); // Here, we move the declaration to the function global scope. diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index e8fce3628..5b2fe0577 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -106,21 +106,30 @@ namespace clad { VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) { return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit, - TSI, IS); + TSI, IS, pushCodeSynthCtxt); } VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, Scope* Scope, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) { // add namespace specifier in variable declaration if needed. Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type); auto* VD = VarDecl::Create( m_Context, m_Sema.CurContext, m_DiffReq->getLocation(), m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None); + bool isLambda = false; if (Init) { + if (const RecordType* RT = Init->getType()->getAs()) { + if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) + isLambda = RD->isLambda(); + } + if (isLambda || pushCodeSynthCtxt) { + clang::Sema::CodeSynthesisContext csc; + m_Sema.pushCodeSynthesisContext(csc); + } m_Sema.AddInitializerToDecl(VD, Init, DirectInit); VD->setInitStyle(IS); } else { @@ -129,6 +138,8 @@ namespace clad { m_Sema.FinalizeDeclaration(VD); // Add the identifier to the scope and IdResolver m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false); + if (Init && (isLambda || pushCodeSynthCtxt)) + m_Sema.popCodeSynthesisContext(); return VD; } @@ -141,9 +152,9 @@ namespace clad { VarDecl* VisitorBase::BuildVarDecl(QualType Type, llvm::StringRef prefix, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) { return BuildVarDecl(Type, CreateUniqueIdentifier(prefix), Init, DirectInit, - TSI, IS); + TSI, IS, pushCodeSynthCtxt); } VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,