forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCDeviceTensorUtils-inl.cuh
118 lines (105 loc) · 4.39 KB
/
THCDeviceTensorUtils-inl.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
namespace detail {

// ---------------------------------------------------------------------------
// Casting to a larger number of dimensions
// ---------------------------------------------------------------------------
//
// Every helper below takes a trailing bool flag. The `true` specialization
// defines make() and contains a compile-time assertion; the `false`
// specialization inherits only the declaration of make() from the root.
// Choosing the specialization through the flag keeps the assertion from
// being compiled for dimension pairs it would reject (the cast macro in this
// file passes the very comparison that the assertion checks as the flag).

// Base of all UpcastTHC variants: declares make() without defining it.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim, bool Enabled>
struct UpcastTHCRoot {
  static THCDeviceTensor<T, DstDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t);
};

// Primary template; both values of the flag are specialized below.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim, bool Enabled>
struct UpcastTHC
    : UpcastTHCRoot<T, SrcDim, IndexT, PtrTraits, DstDim, Enabled> {
};

// Flag == false: placeholder that adds nothing to the root, so make() is
// declared but has no definition here. Exists only so the name resolves.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim>
struct UpcastTHC<T, SrcDim, IndexT, PtrTraits, DstDim, false>
    : UpcastTHCRoot<T, SrcDim, IndexT, PtrTraits, DstDim, false> {
};

// Flag == true: the real implementation.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim>
struct UpcastTHC<T, SrcDim, IndexT, PtrTraits, DstDim, true>
    : UpcastTHCRoot<T, SrcDim, IndexT, PtrTraits, DstDim, true> {
  // Converts `t` to a SrcDim-dimensional device tensor and then calls
  // upcastOuter<DstDim>() on it. Requires DstDim > SrcDim at compile time.
  static THCDeviceTensor<T, DstDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t) {
    thc_static_assert(DstDim > SrcDim);
    return toDeviceTensor<T, SrcDim, IndexT, PtrTraits>(state, t)
        .template upcastOuter<DstDim>();
  }
};

// ---------------------------------------------------------------------------
// Casting to a smaller number of dimensions
// ---------------------------------------------------------------------------
//
// Mirrors the upcast helpers above, with the same flag-based dispatch.

// Base of all DowncastTHC variants: declares make() without defining it.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim, bool Enabled>
struct DowncastTHCRoot {
  static THCDeviceTensor<T, DstDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t);
};

// Primary template; both values of the flag are specialized below.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim, bool Enabled>
struct DowncastTHC
    : DowncastTHCRoot<T, SrcDim, IndexT, PtrTraits, DstDim, Enabled> {
};

// Flag == false: placeholder that adds nothing to the root, so make() is
// declared but has no definition here. Exists only so the name resolves.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim>
struct DowncastTHC<T, SrcDim, IndexT, PtrTraits, DstDim, false>
    : DowncastTHCRoot<T, SrcDim, IndexT, PtrTraits, DstDim, false> {
};

// Flag == true: the real implementation.
template <typename T, int SrcDim, typename IndexT,
          template <typename> class PtrTraits,
          int DstDim>
struct DowncastTHC<T, SrcDim, IndexT, PtrTraits, DstDim, true>
    : DowncastTHCRoot<T, SrcDim, IndexT, PtrTraits, DstDim, true> {
  // Converts `t` to a SrcDim-dimensional device tensor and then calls
  // downcastOuter<DstDim>() on it. Requires DstDim < SrcDim at compile time.
  static THCDeviceTensor<T, DstDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t) {
    thc_static_assert(DstDim < SrcDim);
    return toDeviceTensor<T, SrcDim, IndexT, PtrTraits>(state, t)
        .template downcastOuter<DstDim>();
  }
};

} // namespace detail
// Expands to a single `case SRC_DIM:` of the switch in toDeviceTensorCast.
//
// The expansion site must provide the names T, NewDim, IndexT, PtrTraits,
// state and t. Every path through the case returns, so there is no
// fall-through into the next case:
//   NewDim == SRC_DIM  -> plain toDeviceTensor conversion
//   NewDim  > SRC_DIM  -> detail::UpcastTHC<...>::make
//   NewDim  < SRC_DIM  -> detail::DowncastTHC<...>::make
// The comparison is also passed as the last template argument of the detail
// helper, which selects the specialization that actually defines make().
#define SWITCH_UNROLL_CUDA_CAST_FACTORY(SRC_DIM)                          \
  case SRC_DIM: {                                                         \
    if (NewDim == SRC_DIM) {                                              \
      return toDeviceTensor<T, NewDim, IndexT, PtrTraits>(state, t);      \
    }                                                                     \
    if (NewDim > SRC_DIM) {                                               \
      return detail::UpcastTHC<T, SRC_DIM, IndexT,                        \
                               PtrTraits, NewDim, (NewDim > SRC_DIM)>::   \
          make(state, t);                                                 \
    }                                                                     \
    return detail::DowncastTHC<T, SRC_DIM, IndexT,                        \
                               PtrTraits, NewDim, (NewDim < SRC_DIM)>::   \
        make(state, t);                                                   \
  }
// Converts a THCudaTensor, whose number of dimensions is only known at run
// time, into a THCDeviceTensor with the compile-time dimensionality NewDim.
//
// The run-time dimension count reported by THCudaTensor_nDimensionLegacyAll
// selects one of the cases 1..10 generated by SWITCH_UNROLL_CUDA_CAST_FACTORY.
// Each case returns directly, handing back the tensor upcast, unchanged or
// downcast so that it has NewDim dimensions.
//
// Template parameters:
//   T         - element type of the resulting device tensor
//   NewDim    - dimensionality of the resulting device tensor
//   IndexT    - index type of the resulting device tensor
//   PtrTraits - pointer-traits template of the resulting device tensor
// Parameters:
//   state - THC state forwarded to the THC calls
//   t     - tensor to convert
//
// Any dimension count outside 1..10 is reported through THError.
template <typename T, int NewDim,
          typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
toDeviceTensorCast(THCState* state, THCudaTensor* t) {
  switch (THCudaTensor_nDimensionLegacyAll(state, t)) {
    SWITCH_UNROLL_CUDA_CAST_FACTORY(1);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(2);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(3);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(4);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(5);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(6);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(7);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(8);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(9);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(10);
    default:
      break;
  }

  // Not implemented
  THError("THCDeviceTensor dimension size not supported");

  // Control is not expected to get here; the return only silences
  // missing-return warnings. It hands back a value-initialized tensor of the
  // declared return type rather than a null-pointer constant, which is not a
  // tensor value and would need an implicit converting constructor.
  // NOTE(review): assumes THCDeviceTensor is default-constructible - confirm.
  return THCDeviceTensor<T, NewDim, IndexT, PtrTraits>();
}

#undef SWITCH_UNROLL_CUDA_CAST_FACTORY