forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCTensorMasked.cuh
58 lines (49 loc) · 1.41 KB
/
THCTensorMasked.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#ifndef THC_TENSOR_MASKED_CUH
#define THC_TENSOR_MASKED_CUH
#include <THC/THCTensorMath.h>
#include <THC/THCGeneral.h>
#include <THC/THCTensorCopy.h>
#include <THC/THCApply.cuh>
#include <THC/THCReduce.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/device_ptr.h>
#include <thrust/scan.h>
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
#include <thrust/system/cuda/execution_policy.h>
#endif
/**
 * Pointwise functor that overwrites an element with a fixed value wherever
 * the matching mask element is non-zero.
 *
 * T     - element type of the tensor being filled.
 * MaskT - element type of the mask; it is only ever tested for truthiness.
 *
 * NOTE(review): presumably driven element-wise by the pointwise-apply helpers
 * in THCApply.cuh -- confirm at the call sites.
 */
template <typename T, typename MaskT>
struct TensorMaskedFillOp {
  // Fill value written to every selected element.
  T value;

  TensorMaskedFillOp(T v) : value(v) {}

  // t:    element that may be overwritten.
  // mask: mask element paired with `t`; a zero/false mask leaves `*t` as is.
  __device__ inline void operator()(T* t, MaskT* mask) {
    if (!*mask) {
      return;
    }
    *t = value;
  }
};
/**
 * Pointwise functor that, for every element whose mask is non-zero, replaces
 * the element with a value gathered from a separate source buffer.
 *
 * T              - element type of both the destination and the source.
 * MaskT          - element type of the mask; only tested for truthiness.
 * MaskPrefixSumT - type of the per-element index used to address the source.
 *
 * NOTE(review): the index is presumably an exclusive prefix sum of the mask,
 * so selected elements consume the source in order -- confirm against the
 * caller. No bounds check is performed on the index here.
 */
template <typename T, typename MaskT, typename MaskPrefixSumT>
struct TensorMaskedCopyOp {
  // Where we are copying from.
  T* in;

  TensorMaskedCopyOp(T* s) : in(s) {}

  // out:           destination element, written only when selected.
  // mask:          mask element paired with `out`.
  // maskPrefixSum: points at the index into `in` to read for this element.
  __device__ inline void operator()(T* out,
                                    MaskT* mask,
                                    MaskPrefixSumT* maskPrefixSum) {
    if (!*mask) {
      return;  // unselected elements keep their current value
    }
    const MaskPrefixSumT srcIndex = *maskPrefixSum;
    *out = in[srcIndex];
  }
};
/**
 * Pointwise functor that, for every element whose mask is non-zero, stores
 * the element into a separate output buffer at a supplied index.
 *
 * T              - element type of both the input and the output buffer.
 * MaskT          - element type of the mask; only tested for truthiness.
 * MaskPrefixSumT - type of the per-element index used to address the output.
 *
 * NOTE(review): the index is presumably an exclusive prefix sum of the mask,
 * which would pack the selected elements densely -- confirm against the
 * caller. No bounds check is performed on the index here.
 */
template <typename T, typename MaskT, typename MaskPrefixSumT>
struct TensorMaskedSelectOp {
  // Buffer that receives the selected elements.
  T* out;

  TensorMaskedSelectOp(T* t) : out(t) {}

  // mask:          mask element paired with `in`.
  // maskPrefixSum: points at the slot in `out` to write for this element.
  // in:            input element, read only when selected.
  __device__ inline void operator()(MaskT* mask,
                                    MaskPrefixSumT* maskPrefixSum,
                                    T* in) {
    if (!*mask) {
      return;  // unselected elements contribute nothing to the output
    }
    const MaskPrefixSumT dstIndex = *maskPrefixSum;
    out[dstIndex] = *in;
  }
};
#endif // THC_TENSOR_MASKED_CUH