Skip to content

Commit

Permalink
Improve CallExpr analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Nov 25, 2024
1 parent ef4ff13 commit 2f11458
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 188 deletions.
19 changes: 12 additions & 7 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ struct DiffRequest {
} m_TbrRunInfo;

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> ToBeRecorded;
// std::set<const clang::VarDecl*> ToBeRecorded;
bool HasAnalysisRun = false;
} m_ActivityRunInfo;

public:
static std::set<const clang::VarDecl*> AllVariedDecls;

/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
Expand Down Expand Up @@ -145,12 +147,15 @@ struct DiffRequest {
bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;

void setToBeRecorded(std::set<const clang::VarDecl*> init) {
this->m_ActivityRunInfo.ToBeRecorded = init;
}
std::set<const clang::VarDecl*> getToBeRecorded() const {
return this->m_ActivityRunInfo.ToBeRecorded;
}
// void setToBeRecorded(std::set<const clang::VarDecl*> init) {
// this->m_ActivityRunInfo.ToBeRecorded = init;
// }
// std::set<const clang::VarDecl*> getToBeRecorded() const {
// for(auto i: m_ActivityRunInfo.ToBeRecorded){
// AllVariedDecls.insert(i);
// }
// //return this->m_ActivityRunInfo.ToBeRecorded;
// }
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
35 changes: 18 additions & 17 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
FunctionDecl* FD = CE->getDirectCallee();
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
bool restoreMarking = m_Marking;
bool restoreVaried = m_Varied;
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
Expand All @@ -130,25 +132,24 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
while (innermostType->isPointerType())
innermostType = innermostType->getPointeeType();

if ((parType->isReferenceType() ||
utils::isArrayOrPointerType(parType)) &&
!innermostType.isConstQualified()) {
m_Marking = true;
m_Varied = true;
}

m_Varied = false;
m_Marking = false;
TraverseStmt(par);
if ((parType->isReferenceType() ||
utils::isArrayOrPointerType(parType)) &&
!innermostType.isConstQualified()) {
m_Marking = false; //?
m_Varied = false;
}

if ((m_Varied || !innermostType.isConstQualified()))
if (m_Varied)
m_VariedDecls.insert(FDparam[i]);
else if ((parType->isReferenceType() ||
(utils::isArrayOrPointerType(parType) &&
!innermostType.isConstQualified()))) {
m_Varied = true;
m_Marking = true;
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
}
m_Varied = restoreVaried;
m_Marking = restoreMarking;
}

return true;
}

Expand All @@ -161,10 +162,10 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
innermost = innermost->getPointeeType();
if (VDTy->isPointerType() && !innermost.isConstQualified()) {
copyVarToCurBlock(cast<VarDecl>(D));
continue;
m_Varied = true;
} else if (VDTy->isArrayType()) {
copyVarToCurBlock(cast<VarDecl>(D));
continue;
m_Varied = true;
}

if (Expr* init = cast<VarDecl>(D)->getInit()) {
Expand Down
Loading

0 comments on commit 2f11458

Please sign in to comment.