Update to the discussion

We have changed the design and usage API concerning derived types. Clad now generates a specialisation of `clad::DerivativeOf` for each derived type. Users can tell Clad to generate a derived type for any pair of types with `clad::BuildTangentType`:

```cpp
// This declaration triggers Clad to generate the `clad::DerivativeOf<double, ComplexNumber>` specialization.
clad::BuildTangentType<double, ComplexNumber> buildDoubleWrtComplexNumber;
```

Users can also get type information for a derived type using `clad::TangentTypeInfo`:

```cpp
// d_c is a variable of type clad::DerivativeOf<double, ComplexNumber>
typename clad::TangentTypeInfo<double, ComplexNumber>::type d_c;
```

This new syntax provides a more intuitive, cleaner way to use derived types. It also does not suffer from the limitations of the previous design, such as not being able to describe nested types. A few important things to note:
This write-up discusses the library design and decisions for the feature of differentiating with respect to aggregate types in Clad. We will start with the current progress and the core ideas used, followed by some new ideas and designs, inspired by several other AD libraries (notably those in Swift and Julia), that we can use to improve the current design.
Current progress
Derived Types
The derivative of a function with respect to a real variable ‘a’ is the rate of change of the function's output with respect to ‘a’. We can extend this concept to aggregate types: the derivative of a function with respect to an aggregate type should contain the derivatives with respect to all the data members of that aggregate type. The types of variables that can store such derivatives will be referred to as derived types throughout this document.
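Written out for a scalar function $f$ and an aggregate input $x$ whose real-valued data members are $x_1, \dots, x_n$ (members that are themselves aggregates expand recursively), the derivative collects one partial derivative per member:

$$
\frac{\partial f}{\partial x} = \left( \frac{\partial f}{\partial x_1},\ \frac{\partial f}{\partial x_2},\ \dots,\ \frac{\partial f}{\partial x_n} \right)
$$

A value of a derived type is a container for exactly this tuple.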
For example, suppose we are differentiating a variable `a` of type `double` with respect to a variable `c` of type `ComplexNumber`. Then the derivative should contain both the derivative of `a` with respect to `c.real` and the derivative of `a` with respect to `c.im`.

The current design follows Clad's core ideology, that is, differentiating existing algorithms without the need to modify them.
Currently, users need to provide a forward declaration for each derived type that Clad requires to perform differentiation.

This forward declaration instructs Clad to provide a body for the class declaration such that the class can store the derivative of a `double` variable with respect to a `ComplexNumber` variable. Clad will synthesize the following definition for this class:
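The exact synthesized code is not reproduced here. The following is a plausible sketch, assuming Clad emits one `double` member per data member of `ComplexNumber`, named after that member; the layout and the zero defaults are assumptions for illustration.

```cpp
// The user's aggregate type.
struct ComplexNumber {
  double real;
  double im;
};

// User-provided forward declaration: asks Clad for a derived type.
class __clad_double_wrt_ComplexNumber;

// Sketch of a definition Clad could synthesize (assumed layout):
// one derivative slot per data member of ComplexNumber.
class __clad_double_wrt_ComplexNumber {
 public:
  double real = 0;  // derivative with respect to ComplexNumber::real
  double im = 0;    // derivative with respect to ComplexNumber::im
};
```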
Now, variables of this type (class) represent derivatives of a `double` variable with respect to a `ComplexNumber` variable. Thus, `__clad_double_wrt_ComplexNumber` is an example of a derived type.

These types are designed to be used directly by users. Users will need them to store the derivatives obtained from Clad's differentiation functions and to access those derivatives. Since these are ordinary variables, users will also be able to pass the derivatives to other routines after computing them.
Example Usage:
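As an illustration (hand-written, not Clad output), the snippet below computes what a derived function for `fn` might return and shows the derived type being stored, read, and passed to another routine like any ordinary variable. `fn`, `fn_dc`, `print_gradient`, and the member layout of the derived type are hypothetical.

```cpp
#include <cstdio>

struct ComplexNumber {
  double real;
  double im;
};

// Derived type for "double with respect to ComplexNumber" (assumed layout).
struct __clad_double_wrt_ComplexNumber {
  double real;
  double im;
};

// Hypothetical function to differentiate: a = c.real * c.real + c.im
double fn(ComplexNumber c) { return c.real * c.real + c.im; }

// Hand-written stand-in for a Clad-generated derivative of fn w.r.t. c.
__clad_double_wrt_ComplexNumber fn_dc(ComplexNumber c) {
  return {2 * c.real, 1.0};  // (da/dc.real, da/dc.im)
}

// Derived types are ordinary values, so they can be passed to other routines.
void print_gradient(const __clad_double_wrt_ComplexNumber& d) {
  std::printf("da/dc.real = %g, da/dc.im = %g\n", d.real, d.im);
}
```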
The forward declaration is required because these types need to be visible to the users so that they can easily use them.
Some more examples of derived types:
Here, notice that we are reusing derived types as data members of other, more complex derived types. This reuse makes derived types easier to use and understand.
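The sketch below shows the kind of reuse meant here, assuming that a `ComplexNumber`-with-respect-to-`ComplexNumber` derived type stores one `double`-with-respect-to-`ComplexNumber` member per output field. The type names follow the `__clad_A_wrt_B` pattern used above but are illustrative, as is the complex-squaring example.

```cpp
struct ComplexNumber {
  double real;
  double im;
};

// Derivative of one double w.r.t. a ComplexNumber: one slot per input member.
struct __clad_double_wrt_ComplexNumber {
  double real;
  double im;
};

// Derivative of a ComplexNumber w.r.t. a ComplexNumber reuses the type above:
// one member per *output* field, each holding derivatives w.r.t. every input field.
struct __clad_ComplexNumber_wrt_ComplexNumber {
  __clad_double_wrt_ComplexNumber real;  // d(out.real) / d(c)
  __clad_double_wrt_ComplexNumber im;    // d(out.im) / d(c)
};

// Example function: complex squaring, out = c * c.
ComplexNumber square(ComplexNumber c) {
  return {c.real * c.real - c.im * c.im, 2 * c.real * c.im};
}

// Hand-written derivative of square w.r.t. c, filled into the nested derived type.
__clad_ComplexNumber_wrt_ComplexNumber square_dc(ComplexNumber c) {
  __clad_ComplexNumber_wrt_ComplexNumber d;
  d.real = {2 * c.real, -2 * c.im};  // out.real = re^2 - im^2
  d.im = {2 * c.im, 2 * c.real};     // out.im = 2 * re * im
  return d;
}
```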
Derived types for any pair of types are both easy to generate algorithmically and simple to use.
The core idea used when differentiating statements that contain user-defined types is that, no matter what data structures an expression uses, once it is broken down into simpler building blocks it always performs simple arithmetic on directly differentiable types (real numbers). This is also the core idea behind automatic differentiation in general.
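A small forward-mode sketch of this idea: differentiating with respect to an aggregate reduces to seeding each real-valued member in turn and propagating ordinary scalar derivatives through the arithmetic. The `Dual`, `DualComplex`, and `DoubleWrtComplex` types are illustrative assumptions; Clad itself works by source transformation rather than operator overloading.

```cpp
// A value paired with its derivative along one seed direction.
struct Dual {
  double val;
  double der;
};

Dual operator+(Dual a, Dual b) { return {a.val + b.val, a.der + b.der}; }
Dual operator*(Dual a, Dual b) {
  return {a.val * b.val, a.der * b.val + a.val * b.der};  // product rule
}

struct ComplexNumber {
  double real;
  double im;
};

// ComplexNumber whose members also carry derivatives.
struct DualComplex {
  Dual real;
  Dual im;
};

// a = c.real * c.real + c.im, written once over Dual members.
Dual fn(DualComplex c) { return c.real * c.real + c.im; }

// Derived type: one slot per input member.
struct DoubleWrtComplex {
  double real;
  double im;
};

// One scalar forward pass per member of c, each seeding that member with 1.
DoubleWrtComplex fn_dc(ComplexNumber c) {
  DoubleWrtComplex d;
  d.real = fn({{c.real, 1.0}, {c.im, 0.0}}).der;
  d.im = fn({{c.real, 0.0}, {c.im, 1.0}}).der;
  return d;
}
```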
Breaking design changes
Change in the signature of forward mode derived functions
The signature of forward-mode derived functions has been changed from

```cpp
ReturnType fn_darg0(ArgType1 arg1, ArgType2 arg2, …);
```

to one that returns a pointer to the derived object. This has an undesired consequence: users will directly or indirectly need to `delete` the memory allocated for the pointer returned by the derived function.

Example usage:
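The following mock illustrates the ownership burden described above, assuming the derived function allocates the derived object with `new` and returns it as `void*`. `fn_darg0` and `call_fn_darg0` here are hand-written and hypothetical; wrapping the cast pointer in `std::unique_ptr` is one way a caller could make the `delete` automatic.

```cpp
#include <memory>

struct ComplexNumber {
  double real;
  double im;
};

struct __clad_double_wrt_ComplexNumber {
  double real;
  double im;
};

// Hypothetical derived function of a = c.real * c.real + c.im w.r.t. c.
// The derived type is not part of the signature, so the result is returned
// type-erased as void* pointing to a heap-allocated object.
void* fn_darg0(ComplexNumber c) {
  return new __clad_double_wrt_ComplexNumber{2 * c.real, 1.0};
}

// Caller side: cast back to the known derived type and take ownership, so the
// memory is released automatically instead of with a manual delete.
std::unique_ptr<__clad_double_wrt_ComplexNumber> call_fn_darg0(ComplexNumber c) {
  return std::unique_ptr<__clad_double_wrt_ComplexNumber>(
      static_cast<__clad_double_wrt_ComplexNumber*>(fn_darg0(c)));
}
```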
This design change is required because we cannot compute the return type of the derived function at compile time: we cannot determine, at compile time, the type of the argument with respect to which the function needs to be differentiated.
To put this into more concrete terms, consider this example:
Here, we cannot determine the type of the argument with respect to which we are differentiating `fn`, because we cannot work with strings at compile time and we cannot obtain the names of function parameters.

Change in the signature of reverse-mode derived functions
The signature of reverse-mode derived functions has also been changed. In the new signature, `clad::array_ref<OriginalFnReturnType>` has been replaced by `clad::array_ref<void>`. This is again because we cannot compute the derived types at compile time.

Shortcomings of the forward declaration approach
The forward declaration approach works well for simple cases, but it does not extend well to classes inside namespaces, nested classes, and template classes.
Therefore, we need a better design that does not have these limitations.
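A brief illustration of the limitation: a forward declaration needs a single identifier such as `__clad_double_wrt_ComplexNumber`, and spelling one for `geometry::Point`, `Outer::Inner`, or `Vec<float>` would require an ad hoc mangling scheme. Keying derived types on the types themselves through template specialisation, as the updated design at the top of this page does with `clad::DerivativeOf`, avoids the problem. The types and the `DerivedType` template below are illustrative assumptions, not Clad code.

```cpp
namespace geometry {
struct Point { double x, y; };
}  // namespace geometry

struct Outer {
  struct Inner { double w; };
};

template <typename T>
struct Vec { T a, b; };

// Hypothetical derived-type template keyed on (output type, input type).
template <typename Y, typename X>
struct DerivedType;

// Namespaced and nested types need no special spelling when used as keys.
template <>
struct DerivedType<double, geometry::Point> { double x, y; };

template <>
struct DerivedType<double, Outer::Inner> { double w; };

// A partial specialisation covers every instantiation of Vec at once.
template <typename T>
struct DerivedType<double, Vec<T>> { double a, b; };
```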
Improvement Ideas
Next, I will discuss some possible improvements; they need further discussion before they can be refined and accepted.
An attribute for non-differentiable fields
Not all fields of a class/struct need to be differentiated; a class/struct can have non-differentiable fields as well. Computing derivatives of those fields wastes computational resources. We can define an attribute that users can use to mark a class field as non-differentiable. For example:
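The attribute's spelling is still open. In the sketch below, `clad::non_differentiable` is a hypothetical name that appears only in a comment, so the snippet compiles with any compiler. The sketch shows the intended effect: the derived type simply has no slot for a field marked non-differentiable, so no derivative work is spent on it. `Particle`, `f`, and `f_dp` are illustrative.

```cpp
struct Particle {
  double position;
  double velocity;
  // Would be marked [[clad::non_differentiable]] (hypothetical attribute name):
  // an identifier, not a real-valued quantity.
  int id;
};

// Derived type of double w.r.t. Particle: slots only for differentiable fields.
struct DoubleWrtParticle {
  double position;
  double velocity;
  // no slot for `id`
};

// Example function: f = position * velocity.
double f(const Particle& p) { return p.position * p.velocity; }

// Hand-written derivative of f w.r.t. p.
DoubleWrtParticle f_dp(const Particle& p) {
  return {p.velocity, p.position};  // (df/dposition, df/dvelocity)
}
```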
Maintain signature of the forward mode derived functions
We can compute the return type of the forward-mode derived function using advanced metaprogramming techniques, provided that the independent argument is specified using its index only.
With this approach, the forward declaration of the derived type will also generate an additional structure containing the information about the derived type needed to store the derivative of type A with respect to type B.
To put it into more concrete terms, consider this example:
This approach is possible because we can compute the type of the ith parameter of a function using metaprogramming.
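A compile-time sketch of this idea, assuming the independent argument is given as a template index `I`: the parameter type is extracted from the function type, and the derivative's return type is looked up in a derived-type mapping that plays the role of the additional structure described above. `ParamType`, `ForwardReturn`, `DerivedType`, and `fn_darg1` are hypothetical names, not part of Clad.

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

struct ComplexNumber {
  double real;
  double im;
};

// Hypothetical derived-type mapping (stands in for generated structures).
template <typename Y, typename X>
struct DerivedType;
template <>
struct DerivedType<double, double> { using type = double; };
template <>
struct DerivedType<double, ComplexNumber> {
  struct type { double real, im; };
};

// Type of the I-th parameter of a function type, computed at compile time.
template <typename F, std::size_t I>
struct ParamType;
template <std::size_t I, typename R, typename... Args>
struct ParamType<R(Args...), I> {
  using type = typename std::tuple_element<I, std::tuple<Args...>>::type;
};

// Return type of the forward-mode derivative of F w.r.t. its I-th parameter.
template <typename F, std::size_t I>
struct ForwardReturn;
template <std::size_t I, typename R, typename... Args>
struct ForwardReturn<R(Args...), I> {
  using Param = typename ParamType<R(Args...), I>::type;
  using type = typename DerivedType<R, Param>::type;
};

double fn(double x, ComplexNumber c) { return x * c.real + c.im; }

static_assert(std::is_same<ForwardReturn<decltype(fn), 0>::type, double>::value,
              "derivative w.r.t. x is a double");
static_assert(std::is_same<ForwardReturn<decltype(fn), 1>::type,
                           DerivedType<double, ComplexNumber>::type>::value,
              "derivative w.r.t. c is the ComplexNumber derived type");

// Hand-written derivative w.r.t. parameter 1, with a statically known return type.
ForwardReturn<decltype(fn), 1>::type fn_darg1(double x, ComplexNumber /*c*/) {
  return {x, 1.0};  // (dfn/dc.real, dfn/dc.im)
}
```

Because the return type is an ordinary type rather than `void*`, no heap allocation or manual `delete` is needed on the caller's side.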
Advantages of using this approach:
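- Derived functions do not need to return `void*`.
- The `clad::differentiate` API does not need to change.
- Users do not need to `delete` the allocated memory for the derived objects.

Disadvantages of using this approach: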
Restrict to differentiating scalar types with respect to aggregate types.
The current design was made keeping in mind that we need to extend the implementation to support differentiating aggregate types with respect to aggregate types. If we restrict ourselves to differentiating scalar types with respect to aggregate types, we can simplify a lot of things and hopefully make the design more intuitive and the implementation much easier.
Other libraries in the field, notably JAX (Python), Swift's built-in automatic differentiation support, and Zygote.jl, also only support differentiating scalar types with respect to aggregate types.
Any suggestions or comments regarding this discussion are welcome. Please feel free to ask any questions.