forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FuncTorchTLS.h
46 lines (38 loc) · 1.85 KB
/
FuncTorchTLS.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#pragma once
#include <memory>
#include <c10/macros/Macros.h>
namespace at { namespace functorch {
// NOTE [functorch TLS in pytorch/pytorch]
//
// functorch lives out-of-tree. However, it has some TLS that needs to be
// propagated. The solution for that is we store a pointer to the TLS
// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
// include whatever functorch needs.
//
// We need to store a pointer due to the indirection:
// inside functorch, we will create a subclass of FuncTorchTLSBase called
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined yet.
//
// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
// functorch, we will assign a FuncTorchTLSImpl* to the FuncTorchTLSBase*.
// We can't directly pass around FuncTorchTLSBase (without a pointer) because
// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
// more elements.
struct TORCH_API FuncTorchTLSBase {
// Abstract interface for the functorch TLS object. It declares no data
// members and every operation below is pure virtual; the actual state is
// expected to live in a subclass defined inside functorch (see
// NOTE [functorch TLS in pytorch/pytorch]).
//
// Virtual destructor so that a subclass instance can be destroyed through a
// FuncTorchTLSBase pointer, such as the std::unique_ptr<FuncTorchTLSBase>
// values used by the functions declared in this header.
virtual ~FuncTorchTLSBase() = default;
// Returns a newly created copy of this object, owned by the caller through
// the returned unique_ptr. How deep the copy goes is up to the subclass.
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
// functorch doesn't always work with autograd.Function.
// This is a hook to get into functorch -- functorch will determine
// if it should raise an error message
// NOTE(review): the meaning of the int64_t return value is not documented
// here -- confirm against the functorch implementation.
virtual int64_t checkSupportsAutogradFunction() const = 0;
// NOTE(review): presumably analogous hooks that let functorch decide whether
// to raise an error for in-place operations involving requires_grad and for
// retain_grad, respectively -- inferred from the names only; confirm against
// the functorch implementation.
virtual void checkSupportsInplaceRequiresGrad() const = 0;
virtual void checkSupportsRetainGrad() const = 0;
};
// Returns a deep copy of the functorch TLS. Ownership of the copy is handed
// to the caller through the returned unique_ptr.
// NOTE(review): what is returned when no functorch TLS has been set is not
// visible from this header (presumably a null unique_ptr) -- confirm in the
// implementation.
TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
// Sets the functorch TLS from `state`. Always does a deep copy, so `state`
// is only read here (it is passed as a pointer-to-const).
// NOTE(review): how a null `state` is handled is not visible from this
// header -- confirm in the implementation.
TORCH_API void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state);
// Gets a mutable reference to the functorch TLS. The reference is to the
// std::unique_ptr itself (non-const), so a caller can both modify the
// pointed-to object and replace or reset the pointer.
// NOTE(review): whether the pointer may be null on first access is not
// visible from this header -- confirm in the implementation.
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
}}