Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lambda support in the reverse mode #1126

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

gojakuch
Copy link
Collaborator

@gojakuch gojakuch commented Oct 29, 2024

Potentially fixes: #1054

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 10 out of 21. Check the log or trigger a new build to see more.

@@ -403,6 +403,16 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy,
#define CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam clang::ParsedAttributesView::none(),
#endif

#if CLANG_VERSION_MAJOR > 12
#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function-like macro 'CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind' used; consider a 'constexpr' template function [cppcoreguidelines-macro-usage]

#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind(                 \
        ^

ReverseModeVisitor::diffLambdaCXXRecordDecl(const CXXRecordDecl* Original) {
// Create a new Lambda CXXRecordDecl that is going to represent a pullback
CXXRecordDecl* Cloned = CXXRecordDecl::CreateLambda(
m_Context, const_cast<DeclContext*>(Original->getDeclContext()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]

        m_Context, const_cast<DeclContext*>(Original->getDeclContext()),
                   ^


// Create operator() as a pullback
for (auto* Method : Original->methods()) {
if (CXXMethodDecl* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use auto when initializing with a template cast to avoid duplicating the type name [modernize-use-auto]

Suggested change
if (CXXMethodDecl* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {
if (auto* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {

std::vector<Expr *> children_Exp;
std::vector<Expr *> children_Exp_dx;

for (auto children : children_iterator_range) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto children' can be declared as 'const auto *children' [llvm-qualified-auto]

Suggested change
for (auto children : children_iterator_range) {
for (const auto *children : children_iterator_range) {


for (auto children : children_iterator_range) {
// auto children_expr = const_cast<clang::Expr*>(dyn_cast<clang::Expr>(children));
auto children_expr = dyn_cast<clang::Expr>(children);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto children_expr' can be declared as 'const auto *children_expr' [llvm-qualified-auto]

Suggested change
auto children_expr = dyn_cast<clang::Expr>(children);
const auto *children_expr = dyn_cast<clang::Expr>(children);

// ============== CAP

// FIXME: ideally, we need to create a reverse_forw lambda and not copy the original one for the forward pass.
auto forwardLambdaClass = LE->getLambdaClass();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto forwardLambdaClass' can be declared as 'auto *forwardLambdaClass' [llvm-qualified-auto]

Suggested change
auto forwardLambdaClass = LE->getLambdaClass();
auto *forwardLambdaClass = LE->getLambdaClass();

LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(),
LE->hasExplicitResultType(),
true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: too many arguments to function call, expected 7, have 8 [clang-diagnostic-error]

                            true);
                            ^
Additional context

llvm/include/clang/Sema/Sema.h:7176: 'buildLambdaScope' declared here

  void buildLambdaScope(sema::LambdaScopeInfo *LSI, CXXMethodDecl *CallOperator,
       ^

LE->hasExplicitResultType(),
true);

auto forwardLE = LambdaExpr::Create(m_Context,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto forwardLE' can be declared as 'auto *forwardLE' [llvm-qualified-auto]

Suggested change
auto forwardLE = LambdaExpr::Create(m_Context,
auto *forwardLE = LambdaExpr::Create(m_Context,


std::vector<LambdaCapture> children_LC_Exp_dx;

for (auto children_expr : children_Exp_dx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto children_expr' can be declared as 'auto *children_expr' [llvm-qualified-auto]

Suggested change
for (auto children_expr : children_Exp_dx) {
for (auto *children_expr : children_Exp_dx) {

for (auto children_expr : children_Exp_dx) {
if(isa<CXXConstructExpr>(children_expr)) {

auto tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto tmp' can be declared as 'auto *tmp' [llvm-qualified-auto]

Suggested change
auto tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();
auto *tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 10 out of 11. Check the log or trigger a new build to see more.

auto tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();

if (isa<DeclRefExpr>(tmp)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto VD' can be declared as 'auto *VD' [llvm-qualified-auto]

Suggested change
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());
auto *VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());


if (isa<DeclRefExpr>(tmp)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
children_LC_Exp_dx.emplace_back(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD);

children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
}
if(isa<ParenExpr>(tmp)) {
auto PE = dyn_cast<ParenExpr>(tmp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto PE' can be declared as 'auto *PE' [llvm-qualified-auto]

Suggested change
auto PE = dyn_cast<ParenExpr>(tmp);
auto *PE = dyn_cast<ParenExpr>(tmp);

}
if(isa<ParenExpr>(tmp)) {
auto PE = dyn_cast<ParenExpr>(tmp);
auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto OCE' can be declared as 'auto *OCE' [llvm-qualified-auto]

Suggested change
auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());
auto *OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());

auto PE = dyn_cast<ParenExpr>(tmp);
auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());

auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto VD' can be declared as 'auto *VD' [llvm-qualified-auto]

Suggested change
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());
auto *VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());

auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());

auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
children_LC_Exp_dx.emplace_back(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD);

}
}
if (isa<DeclRefExpr>(children_expr)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto VD' can be declared as 'auto *VD' [llvm-qualified-auto]

Suggested change
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());
auto *VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());

}
if (isa<DeclRefExpr>(children_expr)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
children_LC_Exp_dx.emplace_back(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD);

LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(),
LE->hasExplicitResultType(),
true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: too many arguments to function call, expected 7, have 8 [clang-diagnostic-error]

                              true);
                              ^
Additional context

llvm/include/clang/Sema/Sema.h:7176: 'buildLambdaScope' declared here

  void buildLambdaScope(sema::LambdaScopeInfo *LSI, CXXMethodDecl *CallOperator,
       ^

// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
if (utils::IsReferenceOrPointerArg(arg) ||
if (utils::IsReferenceOrPointerArg(arg)||
!m_DiffReq.shouldHaveAdjoint(PVD)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: repeated branch in conditional chain [bugprone-branch-clone]

          !m_DiffReq.shouldHaveAdjoint(PVD)) {
                                             ^
Additional context

lib/Differentiator/ReverseModeVisitor.cpp:2191: end of the original

      } else if (isArgLambda) {
       ^

lib/Differentiator/ReverseModeVisitor.cpp:2191: clone 1 starts here

      } else if (isArgLambda) {
                              ^

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
1 participant