forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathVariableFallbackKernel.cpp
111 lines (91 loc) · 3.57 KB
/
VariableFallbackKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/VariableHooksInterface.h>
#include <torch/library.h>
/*
 * This file implements a variable fallback kernel for custom operators.
 * Since tensors always have an Autograd dispatch key set, but custom
 * operators usually don't register a kernel for Autograd, the dispatcher
 * will call into this fallback kernel instead.
 * Note that this is not a correct autograd implementation. On mobile builds
 * it simply falls through to the custom operator implementation; on other
 * builds it defers to the VariableHooks autograd-not-implemented fallback
 * (or, when no variable hooks are available, redispatches below autograd).
 * If you want a custom operator to work with autograd, you need to use
 * autograd::Function so that the custom operator implementation knows how to
 * do autograd.
 * Note also that ops from native_functions.yaml register their own variable
 * kernels, so this is never called for them.
 */
// TODO This whole file should be deleted and replaced with the mechanism
// described in https://github.com/pytorch/pytorch/issues/29548
using c10::Stack;

namespace {

// NOTE [mobile/edge builds and the autograd fallback]
// To save on binary size, some of the mobile configs don't include the
// autograd kernels for built-in operators (VariableTypeEverything.cpp).
// For the mobile build:
// - we don't care about having a nice autograd fallback that warns if
//   an operator has incorrect autograd support. If you're running
//   a custom operator on mobile then it's already too late for us to warn
//   or error on it.
// - for perf reasons, we do not want mobile to go through autograd_fallback
//   for all operators (the boxing/unboxing adds overhead).
// As a result, on mobile we set the fallback to the fallthrough.
#ifdef C10_MOBILE

#define AUTOGRAD_FALLBACK torch::CppFunction::makeFallthrough()

#else

// Boxed fallback kernel installed for the Autograd* dispatch keys registered
// below (non-mobile builds only).
//
// PyTorch has separate builds, some of which don't include autograd.
// So we define some behavior for when autograd isn't included and
// go through a layer of indirection (VariableHooksInterface) when it is.
// See aten/src/ATen/core/VariableHooksInterface.h for more details.
//
// - No variable hooks registered: redispatch `op` with `dispatch_keys`
//   masked by `c10::after_autograd_keyset`, i.e. continue past the
//   autograd layer.
// - Otherwise: delegate to the hooks' basic_autograd_not_implemented_fallback.
//
// This definition lives inside the non-mobile branch on purpose: it has
// internal linkage and is only referenced through AUTOGRAD_FALLBACK, so
// defining it unconditionally would leave an unused function in C10_MOBILE
// builds (triggering -Wunused-function, fatal under -Werror).
void autograd_fallback(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  if (!at::impl::HasVariableHooks()) {
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
    return;
  }
  at::impl::GetVariableHooks()->basic_autograd_not_implemented_fallback(
      op, dispatch_keys, stack);
}

#define AUTOGRAD_FALLBACK \
  torch::CppFunction::makeFromBoxedFunction<&autograd_fallback>()

#endif // C10_MOBILE

// Register AUTOGRAD_FALLBACK for the Autograd backend dispatch keys.
// NB: But not the private use ones; maybe the extension wants
// to override it themselves!
TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradCPU, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradXPU, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradCUDA, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradLazy, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradMPS, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradMeta, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

// see Note [ADInplaceOrView key]
// ADInplaceOrView always falls through, on every build.
TORCH_LIBRARY_IMPL(_, ADInplaceOrView, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

#undef AUTOGRAD_FALLBACK

} // namespace