forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CPUFallback.cpp
191 lines (173 loc) · 8.99 KB
/
CPUFallback.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#include <ATen/native/CPUFallback.h>
#include <sstream>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <ATen/Functions.h>
namespace at { namespace native {
// Moves a list of tensors to the CPU, preserving their order.
//
// Undefined tensors are forwarded unchanged instead of being handed to at::_to_cpu:
// otherwise every backend that provides its own _to_cpu kernel would have to cope with
// undefined inputs. All defined tensors are still converted together in one _to_cpu call.
std::vector<at::Tensor> to_cpu(const at::TensorList& tensors) {
  const size_t count = tensors.size();
  std::vector<at::Tensor> result(count);
  std::vector<at::Tensor> defined_inputs;

  // Pass 1: undefined tensors go straight to their output slot; defined ones are
  // collected (in order) for the batched conversion.
  for (size_t pos = 0; pos < count; ++pos) {
    if (!tensors[pos].defined()) {
      result[pos] = tensors[pos];
      continue;
    }
    defined_inputs.push_back(tensors[pos]);
  }

  auto converted = at::_to_cpu(defined_inputs);

  // Pass 2: scatter the converted tensors back into the slots of the defined inputs.
  size_t next_converted = 0;
  for (size_t pos = 0; pos < count; ++pos) {
    if (tensors[pos].defined()) {
      result[pos] = std::move(converted[next_converted]);
      ++next_converted;
    }
  }
  return result;
}
// Decides which device the CPU-computed output tensor(s) should be moved back to.
//
// Convention: the first (defined) Tensor argument picks the device. Barring that, the
// first defined tensor found in the TensorList arguments is used. Returns nullopt when no
// defined tensor exists at all (e.g. torch.cat() on an empty list).
//
// Both parameters are only read, so they are taken by const reference. The list argument
// used to be taken by value, which copied the vector (and bumped every List's refcount) on
// each call. Undefined tensors are skipped because calling .device() on them throws.
c10::optional<c10::Device> compute_target_device(const std::vector<at::Tensor>& t_args, const std::vector<c10::List<at::Tensor>>& tlist_args) {
  for (const auto& t : t_args) {
    if (t.defined()) {
      return t.device();
    }
  }
  // We need to loop through all of the (potentially multiple) TensorList arguments
  // in case, e.g., the first one is empty but the second is not.
  for (const auto& tens_list : tlist_args) {
    for (const auto i : c10::irange(tens_list.size())) {
      const at::Tensor t = tens_list.get(i);
      if (t.defined()) {
        return t.device();
      }
    }
  }
  return c10::nullopt;
}
// Emits the non-fatal warning used when a view operator falls back to CPU. The fallback
// can only return a copy on the original device, not a real view of the input's storage.
// See Note [CPU Fallback Does Not Handle View Operators] in cpu_fallback below.
static void warn_view_op_cpu_fallback(const c10::OperatorHandle& op, const c10::optional<c10::Device>& tgt_device) {
  std::stringstream dev_str;
  if (tgt_device) {
    dev_str << *tgt_device;
  } else {
    dev_str << "<none>";
  }
  // NB: unlike TORCH_CHECK, TORCH_WARN takes no condition argument. Every argument is a
  // message piece, so a leading `false` would be printed into the warning as "0".
  TORCH_WARN("The operator ", op.schema().operator_name(), " appears to be a view operator, ",
    "but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
    "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
}

// Boxed fallback that runs `op` with its CPU kernel on behalf of a backend that has no
// kernel for it:
//   1. Replace every Tensor / TensorList argument on the stack with a CPU copy.
//   2. Redispatch the op to the CPU kernel.
//   3. Copy in-place mutations of mutable-alias arguments (`Tensor(a!)`, `Tensor(a!)[]`)
//      from the CPU copies back into the original inputs.
//   4. Move CPU outputs back to the original device. Mutable-alias Tensor outputs are
//      replaced by the original input tensor they alias.
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto& schema_args = op.schema().arguments();
  const auto num_arguments = schema_args.size();
  auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // Original Tensor arguments and their positions in the op's argument list.
  std::vector<at::Tensor> tensor_args;
  std::vector<size_t> tensor_args_indices;
  // Original TensorList arguments, their positions, and the CPU copies that replace them
  // on the stack. The CPU copies are kept so that in-place mutations can be copied back.
  std::vector<c10::List<at::Tensor>> tensorlist_args;
  std::vector<size_t> tensorlist_args_indices;
  std::vector<c10::List<at::Tensor>> tensorlist_cpu_args;

  // Step 1: Convert all non-CPU tensor inputs into CPU tensors
  // and put them on the stack at the correct indices.
  for (const auto idx : c10::irange(arguments.size())) {
    // NB: `arguments` is a view into `*stack`, so `ivalue` aliases the very stack slot that
    // is overwritten below. Everything needed from the original value must be read before
    // that write. Otherwise we would record the CPU copy instead of the original list, and
    // compute_target_device() would wrongly pick the CPU as the output device.
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      tensor_args.push_back(ivalue.toTensor());
      tensor_args_indices.push_back(idx);
    } else if (ivalue.isTensorList()) {
      // Note: we copy each TensorList argument to CPU individually out of convenience,
      // but XLA would benefit from materializing all tensor and TensorList args onto the CPU at the same time.
      // We can improve this if we need better perf for XLA's CPU fallbacks.
      c10::List<at::Tensor> original_list = ivalue.toTensorList();
      c10::List<at::Tensor> cpu_list(to_cpu(original_list.vec()));
      tensorlist_args.push_back(std::move(original_list));
      tensorlist_args_indices.push_back(idx);
      // c10::List copies share the underlying storage, so this entry observes whatever
      // the CPU kernel does to the list placed on the stack.
      tensorlist_cpu_args.push_back(cpu_list);
      (*stack)[arguments_begin + idx] = c10::IValue(std::move(cpu_list));
    }
  }
  // XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
  auto cpu_tensors = to_cpu(tensor_args);
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    const auto idx = tensor_args_indices[i];
    (*stack)[arguments_begin + idx] = c10::IValue(cpu_tensors[i]);
  }

  // Step 2: Call the underlying CPU implementation of the operator
  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);

  // Step 3: We need to take special care to handle mutable aliases properly:
  // If any input tensors are mutable aliases, we need to
  // directly copy the updated data on the CPU tensors back to the original inputs.
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    const AliasInfo* alias_info = schema_args[tensor_args_indices[i]].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      at::_copy_from_and_resize(cpu_tensors[i], tensor_args[i]);
    }
  }
  // The same applies, element by element, to mutable TensorList inputs (`Tensor(a!)[]`).
  for (const auto i : c10::irange(tensorlist_args_indices.size())) {
    const AliasInfo* alias_info = schema_args[tensorlist_args_indices[i]].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      const auto& original_list = tensorlist_args[i];
      const auto& cpu_list = tensorlist_cpu_args[i];
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(original_list.size() == cpu_list.size());
      for (const auto j : c10::irange(original_list.size())) {
        const at::Tensor original = original_list.get(j);
        // to_cpu() leaves undefined entries undefined; there is nothing to copy for them.
        if (original.defined()) {
          at::_copy_from_and_resize(cpu_list.get(j), original);
        }
      }
    }
  }

  // Step 4: Convert any CPU output tensors back to the original input device.
  // For mutable alias'd outputs, we also need to take special care
  // to move the ORIGINAL input tensor back onto the stack, in place of
  // the temporary CPU output tensor that we created.
  //
  // Note [CPU Fallback Does Not Handle View Operators]
  // Also note that we are incapable of handling immutable aliases properly.
  // Why?
  // Schemas with immutable alias'd tensor outputs correspond to view operators.
  // For example, the `view_as` schema from native_functions.yaml:
  // `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
  // We can't handle these ops properly, because view ops are supposed to return
  // a NEW tensor that shares the SAME storage as the original tensor.
  // However, the new tensor that we created cannot share the same storage,
  // since it lives on CPU and the original tensor lives on a different device.
  // Because of that, we warn if someone attempts to call the
  // CPU fallback on a view operator (this is to maintain BC for view ops for XLA
  // that fall back to CPU).
  const auto& schema_returns = op.schema().returns();
  const auto num_returns = schema_returns.size();
  auto returns = torch::jit::last(stack, num_returns);
  const auto returns_begin = stack->size() - num_returns;

  for (const auto idx : c10::irange(returns.size())) {
    const AliasInfo* alias_info = schema_returns[idx].alias_info();
    const bool is_mutable_alias = alias_info != nullptr && alias_info->isWrite();

    if (returns[idx].isTensor()) {
      const auto& return_tens = returns[idx].toTensor();
      if (!return_tens.defined()) {
        continue;
      }
      if (is_mutable_alias) {
        // Case (1): mutable alias case. Move the input ivalue directly onto the stack
        // in place of the existing cpu output tensor.
        bool found_alias = false;
        // We could store some extra metadata on the function schema to avoid the loop here
        // if we need to improve perf.
        for (const auto i : c10::irange(tensor_args_indices.size())) {
          const auto input_tensor_idx = tensor_args_indices[i];
          const auto& input_tensor = cpu_tensors[i];
          const AliasInfo* input_alias_info = schema_args[input_tensor_idx].alias_info();
          // Checked above; adding assert to guard against breakage of the below condition due to changing the above if test.
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr);
          if (input_tensor.defined() && (alias_info == input_alias_info || (input_alias_info != nullptr && *alias_info == *input_alias_info))) {
            // We've found the original input tensor that aliases with the current output.
            // Wrap it in an IValue and put it directly on the stack.
            (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
            found_alias = true;
            break;
          }
        }
        TORCH_CHECK(found_alias, "The operator ", op.schema().operator_name(), " appears to have invalid alias information. ",
                    "Found a return tensor argument with a mismatched mutable alias: ", schema_returns[idx]);
      } else {
        c10::optional<c10::Device> tgt_device = compute_target_device(tensor_args, tensorlist_args);
        if (alias_info != nullptr) {
          // Immutable alias (view) case: warn, since we're copying and not creating a view.
          // If this operator is needed, the backend should provide a kernel for it.
          // See Note [CPU Fallback Does Not Handle View Operators]
          warn_view_op_cpu_fallback(op, tgt_device);
        }
        // Case (2): copy case. Copy the cpu output tensor to the original device.
        // We technically might not have a target device, e.g. if you call torch.cat() with an empty list
        // In that case, we shouldn't have any tensors to schlep across devices anyway.
        if (tgt_device) {
          (*stack)[returns_begin + idx] = c10::IValue(return_tens.to(*tgt_device));
        }
      }
    } else if (returns[idx].isTensorList() && !is_mutable_alias) {
      // Case (2) for TensorList outputs (e.g. ops returning `Tensor[]`). Each defined
      // element is copied back to the target device; undefined elements are kept as-is.
      // NOTE(review): mutable-alias TensorList outputs (`Tensor(a!)[]`) are still left as
      // returned by the CPU kernel.
      c10::optional<c10::Device> tgt_device = compute_target_device(tensor_args, tensorlist_args);
      if (alias_info != nullptr) {
        // See Note [CPU Fallback Does Not Handle View Operators]
        warn_view_op_cpu_fallback(op, tgt_device);
      }
      if (tgt_device) {
        const std::vector<at::Tensor> cpu_outputs = returns[idx].toTensorList().vec();
        std::vector<at::Tensor> device_outputs;
        device_outputs.reserve(cpu_outputs.size());
        for (const auto& cpu_output : cpu_outputs) {
          device_outputs.push_back(cpu_output.defined() ? cpu_output.to(*tgt_device) : cpu_output);
        }
        (*stack)[returns_begin + idx] = c10::IValue(c10::List<at::Tensor>(device_outputs));
      }
    }
  }
}
} // namespace native
} // namespace at