forked from Stonepia/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathForeachBinaryOpScalarTensor.cu
149 lines (137 loc) · 5.23 KB
/
ForeachBinaryOpScalarTensor.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_foreach_mul_native.h>
#include <ATen/ops/empty_like_native.h>
#endif
namespace at::native {
// Out-of-place CUDA fast path for `_foreach_<op>(TensorList, Tensor scalar)`.
//
// For every tensor in `tensors`, allocates an output with
// `at::native::empty_like` and launches `multi_tensor_apply<2>` with
// `BinaryOpScalarTensorFunctor`, combining each input with the 0-dim `scalar`
// through `Op<opmath_t>` (`opmath_t` = `at::opmath_type<T>`). The scalar is
// handed to the functor as a data pointer (`scalar.data_ptr<T>()`) rather than
// being read on the host.
//
// Template parameters:
//   T  - element type used to dispatch (the caller dispatches on
//        tensors[0].scalar_type()). NOTE(review): dtype agreement between
//        `tensors` and `scalar` is checked by the calling macro, not here.
//   Op - binary functor template (e.g. std::multiplies), instantiated on
//        opmath_t.
//
// Preconditions: `tensors` is non-empty (tensors[0] is read unchecked here;
// callers run check_foreach_api_restrictions first).
//
// Returns: one new tensor per input, holding Op(input, scalar).
// Throws (TORCH_CHECK): if `scalar` is not a 0-dim single-element tensor, or
// if it is on a different device than tensors[0].
template <typename T, template <class> class Op>
std::vector<Tensor> foreach_binary_op(
    TensorList tensors,
    const Tensor& scalar) {
  TORCH_CHECK(
      scalar.dim() == 0 && scalar.numel() == 1,
      "scalar tensor expected to be 0 dim but it has ",
      scalar.dim(),
      " dimensions and ",
      scalar.numel(),
      " elements.");
  TORCH_CHECK(
      tensors[0].device() == scalar.device(),
      "scalar tensor expected to be on ",
      tensors[0].device(),
      " but is on ",
      scalar.device());

  // One output per input, allocated via at::native::empty_like(t).
  std::vector<at::Tensor> vec_res;
  vec_res.reserve(tensors.size());
  for (const auto& t : tensors) {
    vec_res.emplace_back(at::native::empty_like(t));
  }

  // multi_tensor_apply takes a list of tensor lists: index 0 holds the inputs
  // and index 1 the outputs (matching depth = 2 / res_arg_index = 1 below).
  std::vector<std::vector<at::Tensor>> tensor_lists;
  tensor_lists.reserve(2);
  tensor_lists.emplace_back(tensors.vec());
  tensor_lists.emplace_back(std::move(vec_res));

  using opmath_t = at::opmath_type<T>;
  multi_tensor_apply<2>(
      tensor_lists,
      BinaryOpScalarTensorFunctor<
          T,
          /* depth */ 2,
          /* r_args_depth */ 1,
          /* res_arg_index */ 1>(),
      Op<opmath_t>(),
      scalar.data_ptr<T>());

  // Move the outputs out rather than copying them: `tensor_lists` is a local
  // that dies here, and returning a vector *element* is not implicitly moved,
  // so a plain `return tensor_lists[1];` would bump and then drop the refcount
  // of every output tensor.
  return std::move(tensor_lists[1]);
}
// In-place CUDA fast path for `_foreach_<op>_(TensorList, Tensor scalar)`.
//
// Overwrites every tensor in `tensors` with Op(tensor, scalar) by launching
// `multi_tensor_apply<1>` with `BinaryOpScalarTensorFunctor` (depth 1, result
// written back to list 0). Arithmetic runs in `at::opmath_type<T>`. The scalar
// is passed to the functor as a data pointer rather than read on the host.
// After the launch, the version counters of `tensors` are bumped with
// increment_version.
//
// Template parameters:
//   T  - element type used to dispatch. NOTE(review): dtype agreement with
//        `scalar` is checked by the calling macro, not here.
//   Op - binary functor template (e.g. std::multiplies), instantiated on the
//        opmath type.
//
// Preconditions: `tensors` is non-empty (tensors[0] is read unchecked).
// Throws (TORCH_CHECK): if `scalar` is not a 0-dim single-element tensor, or
// if it is on a different device than tensors[0].
template <typename T, template <class> class Op>
void foreach_binary_op_(TensorList tensors, const Tensor& scalar) {
  const bool scalar_is_0d_single =
      scalar.dim() == 0 && scalar.numel() == 1;
  TORCH_CHECK(
      scalar_is_0d_single,
      "scalar tensor expected to be 0 dim but has ",
      scalar.dim(),
      " dimensions and ",
      scalar.numel(),
      " elements.");

  const auto target_device = tensors[0].device();
  const bool same_device = target_device == scalar.device();
  TORCH_CHECK(
      same_device,
      "scalar tensor is expected to be on ",
      target_device,
      " but is on ",
      scalar.device());

  // A single list: the tensors are both read and written in place.
  std::vector<std::vector<at::Tensor>> tensor_lists;
  tensor_lists.push_back(tensors.vec());

  using opmath_t = at::opmath_type<T>;
  using InPlaceFunctor = BinaryOpScalarTensorFunctor<
      T,
      /* depth */ 1,
      /* r_args_depth */ 1,
      /* res_arg_index */ 0>;
  auto* scalar_ptr = scalar.data_ptr<T>();

  multi_tensor_apply<1>(
      tensor_lists, InPlaceFunctor(), Op<opmath_t>(), scalar_ptr);

  increment_version(tensors);
}
// TODO(crcrpar): Nest dispatch by looking up `scalar.scalar_type` for better
// coverage?
//
// FOREACH_BINARY_OP_SCALAR_TENSOR(FUNCTION, NAME, OP, DIVISION_OP) defines two
// CUDA entry points for `_foreach_<NAME>` with a tensor scalar:
//   foreach_tensor_<NAME>_tensor_kernel_cuda_  - in-place; fast path calls
//                                                FUNCTION_<OP>
//   foreach_tensor_<NAME>_tensor_kernel_cuda   - out-of-place; fast path calls
//                                                FUNCTION<OP>
// Both first run check_foreach_api_restrictions(tensors). They take the fast
// path only when can_use_fast_route(tensors, {}, DIVISION_OP) holds AND the
// scalar tensor's dtype equals tensors[0]'s dtype (checked in that order).
// Otherwise they defer to the matching
// at::native::foreach_tensor_<NAME>_tensor_kernel_slow[_] implementation.
#define FOREACH_BINARY_OP_SCALAR_TENSOR(FUNCTION, NAME, OP, DIVISION_OP)      \
  void foreach_tensor_##NAME##_tensor_kernel_cuda_(                          \
      TensorList tensors, const Tensor& scalar) {                             \
    check_foreach_api_restrictions(tensors);                                  \
    const bool takes_fast_path =                                              \
        can_use_fast_route(ArrayRef<TensorList>{tensors}, {}, DIVISION_OP) && \
        tensors[0].scalar_type() == scalar.scalar_type();                     \
    if (takes_fast_path) {                                                    \
      FUNCTION##_<OP>(tensors, scalar);                                       \
      return;                                                                 \
    }                                                                         \
    at::native::foreach_tensor_##NAME##_tensor_kernel_slow_(tensors, scalar); \
  }                                                                           \
                                                                              \
  std::vector<Tensor> foreach_tensor_##NAME##_tensor_kernel_cuda(             \
      TensorList tensors, const Tensor& scalar) {                             \
    check_foreach_api_restrictions(tensors);                                  \
    const bool takes_fast_path =                                              \
        can_use_fast_route(ArrayRef<TensorList>{tensors}, {}, DIVISION_OP) && \
        tensors[0].scalar_type() == scalar.scalar_type();                     \
    if (takes_fast_path) {                                                    \
      return FUNCTION<OP>(tensors, scalar);                                   \
    }                                                                         \
    return at::native::foreach_tensor_##NAME##_tensor_kernel_slow(            \
        tensors, scalar);                                                     \
  }
// Out-of-place dtype dispatcher.
//
// Picks `scalar_t` from tensors[0].scalar_type(). The covered dtypes are all
// standard and complex types plus Bool, Half, and BFloat16
// (AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3). It then forwards to
// foreach_binary_op<scalar_t, Op> and returns that function's result.
// Preconditions: `tensors` is non-empty.
template <template <class> class Op>
std::vector<Tensor> all_types_complex_bool_half_bfloat16(
    TensorList tensors,
    const Tensor& scalar) {
  const auto list_dtype = tensors[0].scalar_type();
  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kBool, kHalf, kBFloat16, list_dtype, "foreach_binary_op_scalar_cuda", [&] {
        return foreach_binary_op<scalar_t, Op>(tensors, scalar);
      });
}
// In-place dtype dispatcher.
//
// Picks `scalar_t` from tensors[0].scalar_type(). The covered dtypes are all
// standard and complex types plus Bool, Half, and BFloat16
// (AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3). It then forwards to
// foreach_binary_op_<scalar_t, Op>, which updates `tensors` in place.
// Preconditions: `tensors` is non-empty.
template <template <class> class Op>
void all_types_complex_bool_half_bfloat16_(
    TensorList tensors,
    const Tensor& scalar) {
  const auto list_dtype = tensors[0].scalar_type();
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kBool, kHalf, kBFloat16, list_dtype, "foreach_binary_op_scalar_cuda_", [&] {
        foreach_binary_op_<scalar_t, Op>(tensors, scalar);
      });
}
// Defines foreach_tensor_mul_tensor_kernel_cuda_ (in-place) and
// foreach_tensor_mul_tensor_kernel_cuda (out-of-place). Both implement
// `_foreach_mul` with a 0-dim tensor scalar using std::multiplies. Their fast
// paths dispatch through all_types_complex_bool_half_bfloat16[_]. The trailing
// `false` is the DIVISION_OP flag, which the macro passes to
// can_use_fast_route.
FOREACH_BINARY_OP_SCALAR_TENSOR(
all_types_complex_bool_half_bfloat16,
mul,
std::multiplies,
/* div_op */ false);
} // namespace at::native