forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ForeachUtils.h
159 lines (136 loc) · 6.84 KB
/
ForeachUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#pragma once
#include <ATen/ATen.h>
#include <c10/util/irange.h>
namespace at {
namespace native {
namespace {
// Check foreach API restrictions
// - Tensor lists must be non-empty.
// - All tensors in all lists must have the same dtype.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
// Validates a single foreach TensorList: it must contain at least one tensor,
// and every tensor must have the same dtype as the first one.
void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
  const auto reference_dtype = tensors[0].dtype();
  for (const auto idx : c10::irange(tensors.size())) {
    TORCH_CHECK(tensors[idx].dtype() == reference_dtype, "All tensors in the tensor list must have the same dtype.");
  }
}
// Validates a TensorList paired with a ScalarList: the tensors must satisfy the
// single-list restrictions, and there must be exactly one scalar per tensor.
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  const bool lengths_match = tensors.size() == scalars.size();
  TORCH_CHECK(lengths_match, "Tensor list must have same number of elements as scalar list.");
}
// Validates a pair of TensorLists for a binary foreach op.
// - Both lists must be non-empty and have the same length.
// - Every tensor in either list must have the dtype of tensors1[0].
// - tensors1[k] and tensors2[k] must have identical sizes for every k.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
  TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
  const auto n1 = tensors1.size();
  const auto n2 = tensors2.size();
  TORCH_CHECK(n1 == n2, "Tensor lists must have the same number of tensors, got ", n1, " and ", n2);
  const auto reference_dtype = tensors1[0].dtype();
  for (const auto k : c10::irange(n1)) {
    const auto& lhs = tensors1[k];
    const auto& rhs = tensors2[k];
    TORCH_CHECK(lhs.dtype() == reference_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(rhs.dtype() == reference_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(lhs.sizes() == rhs.sizes(), "Corresponding tensors in lists must have the same size, got ", lhs.sizes(), " and ", rhs.sizes());
  }
}
// Validates three TensorLists for a ternary foreach op.
// - Each list must be non-empty, and all three must have the same length.
// - Every tensor in every list must have the dtype of tensors1[0]. tensors3 is
//   checked as well: check_fast_path_restrictions assumes all lists share a
//   dtype and only inspects tensorLists[0] for type promotion. An unchecked
//   third list could therefore reach the fast path with a different dtype.
// - tensors1[i], tensors2[i] and tensors3[i] must have identical sizes.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
  TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors3.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
  TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size());
  auto expected_dtype = tensors1[0].dtype();
  for (const auto i : c10::irange(tensors1.size())) {
    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors3[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
    TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes());
  }
}
// Validates three TensorLists plus a ScalarList: the three-list restrictions
// must hold, and the ScalarList must have one entry per tensor.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  const auto num_tensors = tensors1.size();
  const auto num_scalars = scalars.size();
  TORCH_CHECK(num_tensors == num_scalars, "Tensor list must have same number of elements as scalar list, got ", num_tensors, " and ", num_scalars);
}
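// To go via the 'fast' path (see check_fast_path_restrictions below), several conditions must be satisfied:
// - All tensors must be on the same device
// - All tensors must have strided layout
// - All tensors must be non-overlapping and dense
// - Corresponding tensors in the lists must have the same strides
// - A scalar list with more than one element must not contain complex scalars
// - Resulting tensor must have the same dtype as the input one (see will_promote_tensor)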
// Returns true if combining `tensor` with `scalar` would yield a result dtype
// that differs from the tensor's own scalar type.
// If does_op_promote_integer_inputs_to_float is set (e.g. division, where
// integer inputs produce a float result), any integral or bool tensor counts as
// promoting, regardless of the scalar.
bool will_promote_tensor(const Tensor& tensor, const Scalar& scalar, bool does_op_promote_integer_inputs_to_float = false) {
  const auto input_dtype = tensor.scalar_type();
  if (does_op_promote_integer_inputs_to_float &&
      at::isIntegralType(input_dtype, /*includeBool*/ true)) {
    return true;
  }
  return at::result_type(tensor, scalar) != input_dtype;
}
// Decides whether the given foreach arguments are eligible for the "fast" path.
// Returns false (so the caller takes the slow path) unless all of these hold:
// - every tensor has the device and dtype of tensorLists[0][0], has strided
//   layout, and is non-overlapping and dense;
// - tensorLists[l][k] has the same strides as tensorLists[0][k] for all l, k;
// - the op does not change the dtype of tensorLists[0][k]. This covers integral
//   or bool inputs when does_op_promote_integer_inputs_to_float is set, and
//   promotion against the matching scalar in scalarList;
// - a scalarList with more than one element contains no complex scalar.
//
// Please make sure to call check_foreach_api_restrictions before calling this
// method. It assumes every TensorList has the same length and that a
// multi-element scalarList has one scalar per tensor; those lengths are not
// re-checked here.
bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  // Without a first tensor there is nothing to anchor the device/dtype checks
  // on; defer to the slow path rather than indexing out of bounds.
  if (tensorLists.empty() || tensorLists[0].empty()) {
    return false;
  }
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();
  auto is_tensor_okay = [&](const Tensor& tensor) {
    // The dtype comparison is defensive. The promotion checks below only look
    // at tensorLists[0], so every other list must be guaranteed to match it.
    return tensor.dtype() == expected_dtype &&
           tensor.device() == expected_device &&
           tensor.layout() == at::kStrided &&
           tensor.is_non_overlapping_and_dense();
  };
  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }
  // Corresponding tensors in all lists must have the same strides. List 0 is
  // the reference, so the comparison starts at list 1.
  const auto num_tensors = tensorLists[0].size();
  for (const auto i : c10::irange(1, tensorLists.size())) {
    for (const auto j : c10::irange(num_tensors)) {
      if (tensorLists[0][j].strides() != tensorLists[i][j].strides()) {
        return false;
      }
    }
  }
  // All tensors share one dtype (checked above), so type promotion only needs
  // to be checked for {tensorLists[0][0], tensorLists[0][1], ...} against
  // scalarList.
  for (const auto i : c10::irange(num_tensors)) {
    const auto& tensor = tensorLists[0][i];
    if (does_op_promote_integer_inputs_to_float &&
        at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
      return false;
    }
    if (scalarList.size() == 1) {
      if (will_promote_tensor(tensor, scalarList[0])) {
        return false;
      }
    } else if (scalarList.size() > 1) {
      // Complex scalar list is not supported due to the limit for kernel launch argument (4KB)
      if (scalarList[i].isComplex()) {
        return false;
      }
      if (will_promote_tensor(tensor, scalarList[i])) {
        return false;
      }
    }
  }
  return true;
}
// Reports whether a foreach op over `tensorLists` (and the optional
// `scalarList`) may take the fast path. It always returns false when
// __HIP_PLATFORM_HCC__ is defined. Otherwise it forwards to
// check_fast_path_restrictions, whose preconditions apply.
bool can_use_fast_route(ArrayRef<TensorList> tensorLists,
                        ArrayRef<Scalar> scalarList = {},
                        bool does_op_promote_integer_inputs_to_float = false) {
#if defined(__HIP_PLATFORM_HCC__)
  // Fast path disabled for HIP builds; silence unused-parameter warnings.
  (void)tensorLists;
  (void)scalarList;
  (void)does_op_promote_integer_inputs_to_float;
  return false;
#else
  return check_fast_path_restrictions(tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
#endif
}
// Convenience overload for two TensorLists and no scalars. It forwards to the
// ArrayRef<TensorList> overload, which already returns false when
// __HIP_PLATFORM_HCC__ is defined, so no separate HIP guard is needed here.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, bool does_op_promote_integer_inputs_to_float = false) {
  // A named C array (not a braced init-list) so the ArrayRef below views
  // storage that outlives the call.
  const TensorList lists[] = {tensors1, tensors2};
  return can_use_fast_route(ArrayRef<TensorList>(lists), {}, does_op_promote_integer_inputs_to_float);
}
}
}} // at::native