// LerpKernel.cpp (from pytorch/pytorch)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/Lerp.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>

namespace at {
namespace native {
namespace {

template <typename scalar_t>
Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
  static_assert(!c10::is_complex<scalar_t>::value, "");
  return weight.abs() < Vectorized<scalar_t>(0.5);
}
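
// Why the magnitude of the weight matters: the scalar lerp in
// ATen/native/Lerp.h evaluates one of two algebraically equivalent forms,
//   start + weight * (end - start)        when |weight| < 0.5
//   end - (end - start) * (1 - weight)    otherwise,
// which keeps the endpoints exact and avoids catastrophic cancellation when
// weight is close to 1. is_lerp_weight_small computes that branch condition
// lane-wise so the vector path below can blend between the two forms.
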
// is_lerp_weight_small doesn't work for complex because z.abs() returns a
// complex vector which can't be compared. Either implement it with z.abs_2_(),
// or fall back to the scalar function.
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER))
template <typename value_t>
Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
  using vec_reg_t = decltype(weight.abs_2_());
  // abs_2_() returns |z|^2 per lane, so the comparison uses 0.25 = 0.5^2.
  vec_reg_t mask = Vectorized<value_t>(weight.abs_2_()) < Vectorized<value_t>(0.25);
  return Vectorized<c10::complex<value_t>>(mask);
}
#else
// Fallback for default/MSVC builds: spill the vectors to stack arrays and
// apply the scalar lerp lane by lane.
template <typename scalar_t>
Vectorized<scalar_t> lerp_vec_map(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
  using vec_t = Vectorized<scalar_t>;
  __at_align__ scalar_t start_arr[vec_t::size()];
  __at_align__ scalar_t end_arr[vec_t::size()];
  __at_align__ scalar_t weight_arr[vec_t::size()];
  __at_align__ scalar_t result_arr[vec_t::size()];
  start.store(start_arr);
  end.store(end_arr);
  weight.store(weight_arr);
  for (auto i : c10::irange(vec_t::size())) {
    result_arr[i] = lerp(start_arr[i], end_arr[i], weight_arr[i]);
  }
  return vec_t::loadu(result_arr);
}

template <typename value_t>
Vectorized<c10::complex<value_t>> lerp_vec(Vectorized<c10::complex<value_t>> start, Vectorized<c10::complex<value_t>> end, Vectorized<c10::complex<value_t>> weight) {
  return lerp_vec_map(start, end, weight);
}
#endif

template <typename scalar_t>
Vectorized<scalar_t> lerp_vec(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
  using vec_t = Vectorized<scalar_t>;
  auto mask = is_lerp_weight_small(weight);
  auto coeff = vec_t::blendv(weight - vec_t(1), weight, mask);
  auto base = vec_t::blendv(end, start, mask);
  return vec::fmadd(coeff, end - start, base);
}
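
// Worked example of the blend above: in a lane where the mask is set
// (|weight| < 0.5), base = start and coeff = weight, so the fmadd computes
//   start + weight * (end - start);
// in a lane where it is clear, base = end and coeff = weight - 1, giving
//   end + (weight - 1) * (end - start) = end - (1 - weight) * (end - start).
// Both match the two scalar forms noted earlier.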

void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
  if (iter.common_dtype() == kBFloat16) {
    using bVec = Vectorized<BFloat16>;
    using fVec = Vectorized<float>;
    float weight_val = weight.to<float>();
    auto weight_vec = fVec(weight_val);
    // BFloat16 math is done in float: each bf16 vector widens to two float
    // vectors, which are lerped and narrowed back.
    at::native::cpu_kernel_vec(
        iter,
        [weight_val](BFloat16 self_val, BFloat16 end_val) -> BFloat16 {
          return lerp(self_val, end_val, weight_val);
        },
        [=](bVec self_vec, bVec end_vec) -> bVec {
          fVec self_vec0, self_vec1, end_vec0, end_vec1;
          std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
          std::tie(end_vec0, end_vec1) = convert_bfloat16_float(end_vec);
          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
          return convert_float_bfloat16(result0, result1);
        });
  } else if (iter.common_dtype() == kHalf) {
    using hVec = Vectorized<Half>;
    using fVec = Vectorized<float>;
    float weight_val = weight.to<float>();
    auto weight_vec = fVec(weight_val);
    // Half follows the same widen-compute-narrow pattern.
    at::native::cpu_kernel_vec(
        iter,
        [weight_val](Half self_val, Half end_val) -> Half {
          return lerp(self_val, end_val, weight_val);
        },
        [=](hVec self_vec, hVec end_vec) -> hVec {
          fVec self_vec0, self_vec1, end_vec0, end_vec1;
          std::tie(self_vec0, self_vec1) = convert_half_float(self_vec);
          std::tie(end_vec0, end_vec1) = convert_half_float(end_vec);
          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
          return convert_float_half(result0, result1);
        });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
      auto weight_val = weight.to<scalar_t>();
      at::native::cpu_kernel_vec(
          iter,
          [weight_val](scalar_t self_val, scalar_t end_val) {
            return lerp(self_val, end_val, weight_val);
          },
          [weight_val](Vectorized<scalar_t> self, Vectorized<scalar_t> end) {
            const Vectorized<scalar_t> weight(weight_val);
            return lerp_vec(self, end, weight);
          });
    });
  }
}
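
// Usage sketch (illustrative, not part of this file): a CPU call such as
//   at::lerp(self, end, /*weight=*/0.3)
// should reach lerp_scalar_kernel through the lerp_kernel_scalar_weight
// stub registered at the bottom of this file.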

void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
  if (iter.common_dtype() == kBFloat16) {
    using bVec = Vectorized<BFloat16>;
    using fVec = Vectorized<float>;
    // Same structure as lerp_scalar_kernel, except the weight arrives as a
    // third tensor operand and must be widened alongside self and end.
    at::native::cpu_kernel_vec(
        iter,
        [=](BFloat16 self_val, BFloat16 end_val, BFloat16 weight_val) -> BFloat16 {
          return lerp(self_val, end_val, weight_val);
        },
        [=](bVec self_vec, bVec end_vec, bVec weight_vec) -> bVec {
          fVec self_vec0, self_vec1, end_vec0, end_vec1, weight_vec0, weight_vec1;
          std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
          std::tie(end_vec0, end_vec1) = convert_bfloat16_float(end_vec);
          std::tie(weight_vec0, weight_vec1) = convert_bfloat16_float(weight_vec);
          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
          return convert_float_bfloat16(result0, result1);
        });
  } else if (iter.common_dtype() == kHalf) {
    using hVec = Vectorized<Half>;
    using fVec = Vectorized<float>;
    at::native::cpu_kernel_vec(
        iter,
        [=](Half self_val, Half end_val, Half weight_val) -> Half {
          return lerp(self_val, end_val, weight_val);
        },
        [=](hVec self_vec, hVec end_vec, hVec weight_vec) -> hVec {
          fVec self_vec0, self_vec1, end_vec0, end_vec1, weight_vec0, weight_vec1;
          std::tie(self_vec0, self_vec1) = convert_half_float(self_vec);
          std::tie(end_vec0, end_vec1) = convert_half_float(end_vec);
          std::tie(weight_vec0, weight_vec1) = convert_half_float(weight_vec);
          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
          return convert_float_half(result0, result1);
        });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
      at::native::cpu_kernel_vec(
          iter,
          [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
            return lerp(self_val, end_val, weight_val);
          },
          [](Vectorized<scalar_t> self_val, Vectorized<scalar_t> end_val, Vectorized<scalar_t> weight_val) {
            return lerp_vec(self_val, end_val, weight_val);
          });
    });
  }
}

} // anonymous namespace

REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel);
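
// These registrations bind the kernels above to the DispatchStub declarations
// (lerp_kernel_scalar_weight, lerp_kernel_tensor_weight) in
// ATen/native/Lerp.h, selecting this translation unit's CPU_CAPABILITY
// build at runtime.
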
} // namespace native
} // namespace at