Skip to content

Commit

Permalink
Add bitmasked-option for tbr analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Mar 8, 2024
1 parent d7e5434 commit da1bac3
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 83 deletions.
13 changes: 9 additions & 4 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,25 @@ enum order {
enum opts {
use_enzyme = 1 << ORDER_BITS,
vector_mode = 1 << (ORDER_BITS + 1),

// Storing two bits for tbr analysis.
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),
}; // enum opts

constexpr unsigned GetDerivativeOrder(unsigned const bitmasked_opts) {
constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) {
return bitmasked_opts & ORDER_MASK;
}

constexpr bool HasOption(unsigned const bitmasked_opts, unsigned const option) {
constexpr bool HasOption(const unsigned bitmasked_opts, const unsigned option) {
return (bitmasked_opts & option) == option;
}

constexpr unsigned GetBitmaskedOpts() { return 0; }
constexpr unsigned GetBitmaskedOpts(unsigned const first) { return first; }
constexpr unsigned GetBitmaskedOpts(const unsigned first) { return first; }
template <typename... Opts>
constexpr unsigned GetBitmaskedOpts(unsigned const first, Opts... opts) {
constexpr unsigned GetBitmaskedOpts(const unsigned first, Opts... opts) {
return first | GetBitmaskedOpts(opts...);
}

Expand Down
10 changes: 9 additions & 1 deletion include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ namespace clad {
using DiffSchedule = llvm::SmallVector<DiffRequest, 16>;
using DiffInterval = std::vector<clang::SourceRange>;

struct RequestOptions {
/// This is a flag to indicate the default behaviour to enable/disable
/// TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
/// The source interval where clad was activated.
///
Expand All @@ -101,9 +107,11 @@ namespace clad {
const clang::FunctionDecl* m_TopMostFD = nullptr;
clang::Sema& m_Sema;

RequestOptions& m_Options;

public:
DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S);
DiffSchedule& plans, clang::Sema& S, RequestOptions& opts);
bool VisitCallExpr(clang::CallExpr* E);

private:
Expand Down
54 changes: 22 additions & 32 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <unsigned... BitMaskedOpts /*To check for enzyme*/,
typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
Expand All @@ -376,9 +375,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <unsigned... BitMaskedOpts /*To check for enzyme*/,
typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
Expand All @@ -397,38 +395,34 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = HessianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by hessian*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
}

/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = HessianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by hessian*/,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code, f);
}

/// Generates function which computes jacobian matrix of the given function
Expand All @@ -438,38 +432,34 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by Jacobian*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
}

/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by Jacobian*/,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code, f);
}

template <typename ArgSpec = const char*, typename F,
Expand Down
55 changes: 35 additions & 20 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ namespace clad {
}

DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S)
DiffSchedule& plans, clang::Sema& S,
RequestOptions& opts)
: m_Interval(Interval), m_DiffPlans(plans), m_TopMostFD(nullptr),
m_Sema(S) {
m_Sema(S), m_Options(opts) {

if (Interval.empty())
return;
Expand Down Expand Up @@ -556,27 +557,52 @@ namespace clad {
return true;
DiffRequest request{};

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
bool enable_tbr_in_req = false;
bool disable_tbr_in_req = false;
if (!A->getAnnotation().equals("E") && FD->getTemplateSpecializationArgs()) {
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();

// Set option for TBR analysis.
enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;

Check warning on line 582 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L580-L582

Added lines #L580 - L582 were not covered by tests
}
if (enable_tbr_in_req || disable_tbr_in_req) {
// override the default value of TBR analysis.
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
}

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
if (derivative_order == 0) {
derivative_order = 1; // default to first order derivative.
}
request.RequestedDerivativeOrder = derivative_order;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) {
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
if (enable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not yet supported in forward mode.");
return true;

Check warning on line 605 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L603-L605

Added lines #L603 - L605 were not covered by tests
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;
Expand All @@ -601,17 +627,6 @@ namespace clad {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
request.Mode = DiffMode::reverse;

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
// reverse vector mode is not yet supported.
Expand Down
4 changes: 2 additions & 2 deletions test/Analyses/TBR.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang -mllvm -debug-only=clad-tbr -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s
// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s
// REQUIRES: asserts
//CHECK-NOT: {{.*error|warning|note:.*}}

Expand All @@ -13,7 +13,7 @@ double f1(double x) {

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
auto F##grad = clad::gradient<clad::opts::enable_tbr>(F);\
F##grad.execute(x, result);\
printf("{%.2f}\n", result[0]); \
}
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ int main() {
d_structPointer.execute(5, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 1.00

auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x");
auto d_cStyleMemoryAlloc = clad::gradient<clad::opts::disable_tbr>(cStyleMemoryAlloc, "x");
d_x = 0;
d_cStyleMemoryAlloc.execute(5, 7, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 4.00
Expand Down
7 changes: 7 additions & 0 deletions test/Misc/Args.C
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
// CHECK_HELP-NEXT: -fdump-derived-fn
// CHECK_HELP-NEXT: -fdump-derived-fn-ast
// CHECK_HELP-NEXT: -fgenerate-source-file
// CHECK_HELP-NEXT: -fno-validate-clang-version
// CHECK_HELP-NEXT: -enable-tbr
// CHECK_HELP-NEXT: -disable-tbr
// CHECK_HELP-NEXT: -fcustom-estimation-model
// CHECK_HELP-NEXT: -fprint-num-diff-errors
// CHECK_HELP-NEXT: -help
Expand All @@ -23,3 +26,7 @@
// RUN: -Xclang %t.so %S/../../demos/ErrorEstimation/CustomModel/test.cpp \
// RUN: -I%S/../../include 2>&1 | FileCheck --check-prefix=CHECK_SO_INVALID %s
// CHECK_SO_INVALID: Failed to load '{{.*.so}}', {{.*}}. Aborting.

// RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-tbr \
// RUN: -Xclang -plugin-arg-clad -Xclang -disable-tbr %s 2>&1 | FileCheck --check-prefix=CHECK_TBR %s
// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together
24 changes: 17 additions & 7 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,15 @@ namespace clad {
if (m_HandleTopLevelDeclInternal)
return true;

RequestOptions opts{};
SetRequestOptions(opts);
DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema(),
opts);

if (requests.empty())
return true;

// FIXME: flags have to be set manually since DiffCollector's constructor
// does not have access to m_DO.
if (m_DO.EnableTBRAnalysis)
for (DiffRequest& request : requests)
request.EnableTBRAnalysis = true;

// FIXME: Remove the PerformPendingInstantiations altogether. We should
// somehow make the relevant functions referenced.
// Instantiate all pending for instantiations templates, because we will
Expand Down Expand Up @@ -318,6 +315,19 @@ namespace clad {
m_HasRuntime = !R.empty();
return m_HasRuntime;
}

void CladPlugin::SetRequestOptions(RequestOptions& opts) {
/// ---------- Set TBR analysis flag ----------
// If user has explicitly specified the mode for TBR analysis, use it.
if (m_DO.EnableTBRAnalysis || m_DO.DisableTbrAnalysis) {
opts.EnableTBRAnalysis =
m_DO.EnableTBRAnalysis && !m_DO.DisableTbrAnalysis;
} else {
// If user has not specified the mode for TBR analysis, use the default
// mode.
opts.EnableTBRAnalysis = false; // Default mode.
}
}
} // end namespace plugin

clad::CladTimerGroup::CladTimerGroup()
Expand Down
Loading

0 comments on commit da1bac3

Please sign in to comment.