forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
common_gpu.h
525 lines (470 loc) · 22.4 KB
/
common_gpu.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
#ifndef CAFFE2_CORE_COMMON_GPU_H_
#define CAFFE2_CORE_COMMON_GPU_H_
#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Disable strict aliasing errors for CUDA 9.
// The cuda_fp16.h header in CUDA 9 RC triggers this diagnostic.
// It is included by cusparse.h as well, so guarding the
// inclusion of that header here is not enough.
#if CUDA_VERSION >= 9000
#ifdef __GNUC__
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
#pragma GCC diagnostic push
#endif
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#endif // CUDA_VERSION >= 9000
#include <cublas_v2.h>
#include <curand.h>
#include <driver_types.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "c10/cuda/CUDAMacros.h"
#include "c10/cuda/CUDAMathCompat.h"
// Defines CAFFE2_CUDA_EXPORT and CAFFE2_CUDA_IMPORT.
// - MSVC with CAFFE2_BUILD_SHARED_LIBS: they expand to __declspec(dllexport)
//   and __declspec(dllimport) respectively.
// - MSVC without CAFFE2_BUILD_SHARED_LIBS: both expand to nothing.
// - Other compilers: CAFFE2_CUDA_EXPORT is the GCC "default visibility"
//   attribute when __GNUC__ is defined (empty otherwise), and
//   CAFFE2_CUDA_IMPORT is the same as CAFFE2_CUDA_EXPORT.
#if defined(_MSC_VER)
#if defined(CAFFE2_BUILD_SHARED_LIBS)
#define CAFFE2_CUDA_EXPORT __declspec(dllexport)
#define CAFFE2_CUDA_IMPORT __declspec(dllimport)
#else // !CAFFE2_BUILD_SHARED_LIBS
#define CAFFE2_CUDA_EXPORT
#define CAFFE2_CUDA_IMPORT
#endif // CAFFE2_BUILD_SHARED_LIBS
#else // !_MSC_VER
#if defined(__GNUC__)
#define CAFFE2_CUDA_EXPORT __attribute__((__visibility__("default")))
#else // !__GNUC__
#define CAFFE2_CUDA_EXPORT
#endif // __GNUC__
#define CAFFE2_CUDA_IMPORT CAFFE2_CUDA_EXPORT
#endif // _MSC_VER
// CAFFE2_CUDA_API resolves to CAFFE2_CUDA_EXPORT when
// CAFFE2_CUDA_BUILD_MAIN_LIB is defined (i.e. while building the main caffe2
// library), and to CAFFE2_CUDA_IMPORT otherwise.
//
// This is used in e.g. Caffe2's protobuf files: when building the main library,
// it is defined as CAFFE2_CUDA_EXPORT to fix a Windows global-variable-in-dll
// issue, and for anyone dependent on Caffe2 it will be defined as
// CAFFE2_CUDA_IMPORT.
#ifdef CAFFE2_CUDA_BUILD_MAIN_LIB
#define CAFFE2_CUDA_API CAFFE2_CUDA_EXPORT
#else // !CAFFE2_CUDA_BUILD_MAIN_LIB
#define CAFFE2_CUDA_API CAFFE2_CUDA_IMPORT
#endif // CAFFE2_CUDA_BUILD_MAIN_LIB
// CAFFE_HAS_CUDA_FP16 signals CUDA fp16 support. By default, cuda fp16 is
// supported by NVCC 7.5, but it is also included in the Tegra X1 platform with
// a (custom?) NVCC 7.0. As a result, we normally just check the cuda version
// here, but also allow a user to pass in the flag CAFFE_HAS_CUDA_FP16
// manually.
// When not predefined, it is defined automatically if CUDA_VERSION >= 7050 or
// when building for HIP (__HIP_PLATFORM_HCC__). Whenever it ends up defined,
// <cuda_fp16.h> is included.
#ifndef CAFFE_HAS_CUDA_FP16
#if CUDA_VERSION >= 7050 || defined(__HIP_PLATFORM_HCC__)
#define CAFFE_HAS_CUDA_FP16
#endif // CUDA_VERSION >= 7050 || defined(__HIP_PLATFORM_HCC__)
#endif // CAFFE_HAS_CUDA_FP16
#ifdef CAFFE_HAS_CUDA_FP16
#include <cuda_fp16.h>
#endif // CAFFE_HAS_CUDA_FP16
// Device-property major revision below which fp16 compute is treated as
// unsupported. ROCm/HIP builds (__HIP_PLATFORM_HCC__) use 3; CUDA builds use 6.
#ifdef __HIP_PLATFORM_HCC__
constexpr int kFp16CUDADevicePropMajor = 3;
#else // CUDA
constexpr int kFp16CUDADevicePropMajor = 6;
#endif // __HIP_PLATFORM_HCC__
// Re-enable strict aliasing diagnostic if it was disabled.
#if CUDA_VERSION >= 9000
#ifdef __GNUC__
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
#pragma GCC diagnostic pop
#endif
#endif // __GNUC__
#endif // CUDA_VERSION >= 9000
/**
 * The maximum number of peers that each GPU can have when doing p2p setup.
 * Currently, according to NVidia documentation, each device can support a
 * system-wide maximum of eight peer connections.
 * When Caffe2 sets up peer access resources, if we have more than 8 GPUs,
 * we will enable peer access in groups of 8.
 */
#define CAFFE2_CUDA_MAX_PEER_SIZE 8
namespace caffe2 {
#if CUDA_VERSION >= 9000
/**
 * Empty class to identify TensorCore-based math.
 *
 * Only declared when compiling against CUDA 9.0 or newer
 * (CUDA_VERSION >= 9000).
 * NOTE(review): presumably used as a tag/engine type argument to select
 * TensorCore code paths; confirm at the use sites.
 */
class TensorCoreEngine {};
#endif // CUDA_VERSION >= 9000
// Name of the memory-type member of a pointer-attributes struct: `type` when
// compiling against CUDA 10.0 or newer (CUDA_VERSION >= 10000), `memoryType`
// otherwise.
// NOTE(review): presumably applied to cudaPointerAttributes (e.g. by the
// GetGPUIDForPointer implementation); confirm at the use sites.
#if CUDA_VERSION >= 10000
#define CAFFE2_CUDA_PTRATTR_MEMTYPE type
#else // CUDA_VERSION < 10000
#define CAFFE2_CUDA_PTRATTR_MEMTYPE memoryType
#endif // CUDA_VERSION >= 10000
/**
 * Returns, at runtime, the CUDA_VERSION value this code was compiled against.
 */
inline int CudaVersion() {
  constexpr int kCompiledCudaVersion = CUDA_VERSION;
  return kCompiledCudaVersion;
}

/**
 * Returns the number of CUDA devices (defined out of line).
 */
CAFFE2_CUDA_API int NumCudaDevices();

/**
 * Reports whether the current process has a usable CUDA GPU, i.e. whether
 * NumCudaDevices() returns a positive count.
 *
 * This is not the same as Caffe2 having been built with CUDA: a CUDA build only
 * guarantees that this function exists. With no CUDA GPU in the machine, or
 * with a broken setup such as an insufficient driver (presumably surfaced as a
 * zero device count by NumCudaDevices()), the answer is false.
 *
 * In the open source build Caffe2's GPU code may be loaded dynamically, so a
 * library linked only against the CPU code that wants to check for CUDA later
 * should use HasCudaRuntime() from common.h instead.
 */
inline bool HasCudaGPU() {
  const int device_count = NumCudaDevices();
  return device_count >= 1;
}
/**
 * Gets the current GPU id. This is a simple wrapper around cudaGetDevice().
 */
CAFFE2_CUDA_API int CaffeCudaGetDevice();
/**
 * Sets the current GPU id to `id`.
 * NOTE(review): presumably a thin wrapper around cudaSetDevice(); the
 * implementation is not in this header.
 */
CAFFE2_CUDA_API void CaffeCudaSetDevice(const int id);
/**
 * Gets the id of the GPU on which the memory pointed to by `ptr` is located.
 */
CAFFE2_CUDA_API int GetGPUIDForPointer(const void* ptr);
/**
 * Gets the device property for the given device. This function is thread safe.
 * The result is returned by const reference, not by copy.
 */
CAFFE2_CUDA_API const cudaDeviceProp& GetDeviceProperty(const int device);
/**
 * Runs a device query function and prints out the results to LOG(INFO).
 */
CAFFE2_CUDA_API void DeviceQuery(const int deviceid);
/**
 * Writes a peer access pattern into `*pattern`: a matrix (in the format of a
 * nested vector) of boolean values specifying whether peer access is possible.
 * NOTE(review): presumably (*pattern)[i][j] tells whether device i can access
 * device j; confirm against the implementation.
 *
 * This function returns false if anything goes wrong during the query of the
 * GPU access pattern.
 */
CAFFE2_CUDA_API bool GetCudaPeerAccessPattern(vector<vector<bool>>* pattern);
/**
 * Return the availability of TensorCores for math.
 */
CAFFE2_CUDA_API bool TensorCoreAvailable();
/**
 * Return a human readable cublas error string.
 */
CAFFE2_CUDA_API const char* cublasGetErrorString(cublasStatus_t error);
/**
 * Return a human readable curand error string.
 */
CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error);
// CUDA runtime-API error checks. Both macros evaluate `condition` once into a
// cudaError_t local named `error`:
// - CUDA_ENFORCE fails through CAFFE_ENFORCE_EQ(error, cudaSuccess, ...),
//   reporting the call site, cudaGetErrorString(error), and any extra
//   arguments the caller passed.
// - CUDA_CHECK fails through CHECK(error == cudaSuccess), streaming
//   cudaGetErrorString(error).
// NOTE(review): the local is in scope inside its own initializer, so a caller
// variable named `error` used within `condition` would refer to this
// uninitialized local instead; avoid that name at call sites.
#define CUDA_ENFORCE(condition, ...)                                     \
  do {                                                                   \
    cudaError_t error = condition;                                       \
    CAFFE_ENFORCE_EQ(error, cudaSuccess, "Error at: ", __FILE__, ":",    \
                     __LINE__, ": ", cudaGetErrorString(error),          \
                     ##__VA_ARGS__);                                     \
  } while (0)

#define CUDA_CHECK(condition)                                            \
  do {                                                                   \
    cudaError_t error = condition;                                       \
    CHECK(error == cudaSuccess) << cudaGetErrorString(error);            \
  } while (0)
// CUDA driver-API (CUresult) error checks. `condition` is evaluated once.
// On anything other than CUDA_SUCCESS:
// - CUDA_DRIVERAPI_ENFORCE reports the failure through CAFFE_THROW;
// - CUDA_DRIVERAPI_CHECK reports it through LOG(FATAL).
// Both include the call site and the error name from cuGetErrorName().
//
// cuGetErrorName() can itself fail for a CUresult it does not recognize, and
// then it does not hand back a usable name. `msg` therefore starts as nullptr
// and its return value is checked. If the lookup fails or yields nullptr, the
// numeric code is reported instead of streaming a null or uninitialized
// `const char*`.
#define CUDA_DRIVERAPI_ENFORCE(condition)                                   \
  do {                                                                      \
    CUresult result = (condition);                                          \
    if (result != CUDA_SUCCESS) {                                           \
      const char* msg = nullptr;                                            \
      if (cuGetErrorName(result, &msg) != CUDA_SUCCESS || msg == nullptr) { \
        CAFFE_THROW(                                                        \
            "Error at: ",                                                   \
            __FILE__,                                                       \
            ":",                                                            \
            __LINE__,                                                       \
            ": unrecognized CUDA driver error ",                            \
            static_cast<int>(result));                                      \
      } else {                                                              \
        CAFFE_THROW("Error at: ", __FILE__, ":", __LINE__, ": ", msg);      \
      }                                                                     \
    }                                                                       \
  } while (0)

#define CUDA_DRIVERAPI_CHECK(condition)                                     \
  do {                                                                      \
    CUresult result = (condition);                                          \
    if (result != CUDA_SUCCESS) {                                           \
      const char* msg = nullptr;                                            \
      if (cuGetErrorName(result, &msg) != CUDA_SUCCESS || msg == nullptr) { \
        LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__           \
                   << ": unrecognized CUDA driver error "                   \
                   << static_cast<int>(result);                             \
      } else {                                                              \
        LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": "   \
                   << msg;                                                  \
      }                                                                     \
    }                                                                       \
  } while (0)
// cuBLAS status checks. Both macros evaluate `condition` once into a
// cublasStatus_t local named `status`. Anything other than
// CUBLAS_STATUS_SUCCESS fails through CAFFE_ENFORCE_EQ (CUBLAS_ENFORCE, with
// the call site in the message) or CHECK (CUBLAS_CHECK). The readable text
// comes from ::caffe2::cublasGetErrorString(status). CUBLAS_ENFORCE takes no
// extra message arguments.
// NOTE(review): a caller variable named `status` referenced inside
// `condition` would bind to this macro's uninitialized local.
#define CUBLAS_ENFORCE(condition)                                          \
  do {                                                                     \
    cublasStatus_t status = condition;                                     \
    CAFFE_ENFORCE_EQ(status, CUBLAS_STATUS_SUCCESS, "Error at: ", __FILE__, \
                     ":", __LINE__, ": ",                                  \
                     ::caffe2::cublasGetErrorString(status));              \
  } while (0)

#define CUBLAS_CHECK(condition)                                            \
  do {                                                                     \
    cublasStatus_t status = condition;                                     \
    CHECK(status == CUBLAS_STATUS_SUCCESS)                                 \
        << ::caffe2::cublasGetErrorString(status);                         \
  } while (0)
// cuRAND status checks. Both macros evaluate `condition` once into a
// curandStatus_t local named `status`. Anything other than
// CURAND_STATUS_SUCCESS fails through CAFFE_ENFORCE_EQ (CURAND_ENFORCE, with
// the call site in the message) or CHECK (CURAND_CHECK). The readable text
// comes from ::caffe2::curandGetErrorString(status).
// NOTE(review): a caller variable named `status` referenced inside
// `condition` would bind to this macro's uninitialized local.
#define CURAND_ENFORCE(condition)                                          \
  do {                                                                     \
    curandStatus_t status = condition;                                     \
    CAFFE_ENFORCE_EQ(status, CURAND_STATUS_SUCCESS, "Error at: ", __FILE__, \
                     ":", __LINE__, ": ",                                  \
                     ::caffe2::curandGetErrorString(status));              \
  } while (0)

#define CURAND_CHECK(condition)                                            \
  do {                                                                     \
    curandStatus_t status = condition;                                     \
    CHECK(status == CURAND_STATUS_SUCCESS)                                 \
        << ::caffe2::curandGetErrorString(status);                         \
  } while (0)
// Grid-stride loops. Each thread starts at its global index and advances by
// the total number of threads in the grid along that dimension. As a result
// `i` covers [0, n) (and `j` covers [0, m)) for any launch configuration.
//
// The start index and the stride are computed in size_t. Otherwise the
// products blockIdx * blockDim and blockDim * gridDim are evaluated in 32-bit
// unsigned int and wrap for large grids (gridDim.x may be up to 2^31 - 1),
// which would make threads skip or revisit elements.
// `n` and `m` are compared against a size_t index, so they should be
// non-negative.
#define CUDA_1D_KERNEL_LOOP(i, n)                                  \
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x +   \
           threadIdx.x;                                            \
       i < (n);                                                    \
       i += static_cast<size_t>(blockDim.x) * gridDim.x)

#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                            \
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x +   \
           threadIdx.x;                                            \
       i < (n);                                                    \
       i += static_cast<size_t>(blockDim.x) * gridDim.x)           \
    for (size_t j = static_cast<size_t>(blockIdx.y) * blockDim.y + \
             threadIdx.y;                                          \
         j < (m);                                                  \
         j += static_cast<size_t>(blockDim.y) * gridDim.y)
// CUDA_KERNEL_ASSERT is a macro that wraps an assert() call inside cuda
// kernels. This is not supported by Apple platforms so we special case it.
// See http://docs.nvidia.com/cuda/cuda-c-programming-guide/#assertion
// On Apple and ROCm/HIP (__HIP_PLATFORM_HCC__) builds it expands to nothing,
// so its arguments are not evaluated there; do not put side effects in them.
// Elsewhere it is plain assert(), which is also compiled out under NDEBUG.
#if defined(__APPLE__) || defined(__HIP_PLATFORM_HCC__)
#define CUDA_KERNEL_ASSERT(...)
#else // !(__APPLE__ || __HIP_PLATFORM_HCC__)
#define CUDA_KERNEL_ASSERT(...) assert(__VA_ARGS__)
#endif // __APPLE__ || __HIP_PLATFORM_HCC__
// The following helper functions are here so that you can write a kernel call
// when you are not particularly interested in maxing out the kernels'
// performance. Usually, this will give you a reasonable speed, but if you
// really want to find the best performance, it is advised that you tune the
// size of the blocks and grids more carefully.
// A legacy note: this is derived from the good old Caffe days, when I simply
// hard-coded the number of threads and wanted to keep backward compatibility
// for different computation capabilities.
// For more info on CUDA compute capabilities, visit the NVidia website at:
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities

// The number of cuda threads to use. Since work is assigned to SMs at the
// granularity of a block, 128 is chosen to allow utilizing more SMs for
// smaller input sizes.
// 1D grid
constexpr int CAFFE_CUDA_NUM_THREADS = 128;
// 2D grid
constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMX = 16;
constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMY = 16;

// The maximum number of blocks to use in the default kernel call. We set it to
// 4096 which would work for compute capability 2.x (where 65535 is the limit).
// This number is very carelessly chosen. Ideally, one would like to look at
// the hardware at runtime, and pick the number of blocks that makes most
// sense for the specific runtime environment. This is a todo item.
// 1D grid
constexpr int CAFFE_MAXIMUM_NUM_BLOCKS = 4096;
// 2D grid
constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX = 128;
constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY = 128;

// Maximum grid dimensions along x / y / z.
constexpr int kCUDAGridDimMaxX = 2147483647;
constexpr int kCUDAGridDimMaxY = 65535;
constexpr int kCUDAGridDimMaxZ = 65535;

/**
 * @brief Compute the number of blocks needed to run N threads.
 *
 * Returns ceil(N / CAFFE_CUDA_NUM_THREADS) clamped to
 * [1, CAFFE_MAXIMUM_NUM_BLOCKS]. Non-positive N yields 1. Because the result
 * is capped, for large N the kernel must loop over its range;
 * CUDA_1D_KERNEL_LOOP does this.
 */
inline int CAFFE_GET_BLOCKS(const int N) {
  // Ceil-div as quotient + (remainder > 0): the usual (N + T - 1) / T
  // overflows int (undefined behavior) when N is close to INT_MAX.
  const int blocks = N / CAFFE_CUDA_NUM_THREADS +
      (N % CAFFE_CUDA_NUM_THREADS > 0 ? 1 : 0);
  return std::max(
      std::min(blocks, CAFFE_MAXIMUM_NUM_BLOCKS),
      // Use at least 1 block, since CUDA does not allow empty block
      1);
}
/**
 * @brief Compute the grid for a 2D launch covering an N x M iteration space.
 *
 * grid.x = ceil(N / CAFFE_CUDA_NUM_THREADS_2D_DIMX), clamped to
 *          [1, CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX];
 * grid.y = ceil(M / CAFFE_CUDA_NUM_THREADS_2D_DIMY), clamped to
 *          [1, CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY];
 * grid.z keeps dim3's default.
 * The x extent pairs with the `i`/`n` range and the y extent with the `j`/`m`
 * range of CUDA_2D_KERNEL_LOOP. Previously grid.y was sized from N and M was
 * ignored. Both dimensions are capped, so kernels must loop over their ranges
 * (CUDA_2D_KERNEL_LOOP does).
 */
inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int M) {
  dim3 grid;
  // Not calling the 1D version for each dim to keep all constants as literals.
  // Ceil-div as quotient + (remainder > 0) so N or M near INT_MAX cannot
  // overflow int.
  const int blocks_x = N / CAFFE_CUDA_NUM_THREADS_2D_DIMX +
      (N % CAFFE_CUDA_NUM_THREADS_2D_DIMX > 0 ? 1 : 0);
  const int blocks_y = M / CAFFE_CUDA_NUM_THREADS_2D_DIMY +
      (M % CAFFE_CUDA_NUM_THREADS_2D_DIMY > 0 ? 1 : 0);
  // Use at least 1 block per dimension, since CUDA does not allow empty block
  grid.x = std::max(std::min(blocks_x, CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX), 1);
  grid.y = std::max(std::min(blocks_y, CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY), 1);
  return grid;
}
/**
 * Scoped switch of the current CUDA device.
 *
 * The constructor records the device reported by CaffeCudaGetDevice(). It
 * calls CaffeCudaSetDevice(newDevice) only when newDevice differs from that
 * device. The destructor always calls CaffeCudaSetDevice() with the recorded
 * device, which also undoes any device change made inside the guarded scope.
 * NOTE(review): the destructor is noexcept. If CaffeCudaSetDevice() can throw
 * (its definition is not in this header), a failure there terminates.
 */
class DeviceGuard {
 public:
  explicit DeviceGuard(int newDevice) : saved_device_(CaffeCudaGetDevice()) {
    if (newDevice == saved_device_) {
      return;  // Already on the requested device; nothing to switch.
    }
    CaffeCudaSetDevice(newDevice);
  }

  ~DeviceGuard() noexcept {
    CaffeCudaSetDevice(saved_device_);
  }

 private:
  // Device that was current when this guard was constructed.
  int saved_device_;
};
/**
 * Plain aggregate wrapping a fixed-size array of N elements of T by value.
 * It has no constructors or padding of its own, so it stays trivially
 * copyable whenever T is.
 * NOTE(review): presumably used to pass small fixed-size arrays by value
 * (e.g. as kernel arguments); confirm at the use sites.
 */
template <class T, int N>
struct SimpleArray {
  T data[N];
};
// Upper bound enforced on `val` by the DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_*
// macros; their switches enumerate the values 1..8.
constexpr int kCUDATensorMaxDims = 8;
// Turns the runtime value `val` into a compile-time template argument and
// calls Func<T, val>(__VA_ARGS__).
// - Enforces val <= kCUDATensorMaxDims via CAFFE_ENFORCE_LE.
// - NOTE(review): values below 1 (0 or negative) reach the `default` case and
//   silently do nothing: Func is not called and no error is raised.
// - `val` is expanded twice (in the enforce and in the switch), so it should
//   be a side-effect-free expression.
#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1(val, Func, T, ...) \
do { \
CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \
switch (val) { \
case 1: { \
Func<T, 1>(__VA_ARGS__); \
break; \
} \
case 2: { \
Func<T, 2>(__VA_ARGS__); \
break; \
} \
case 3: { \
Func<T, 3>(__VA_ARGS__); \
break; \
} \
case 4: { \
Func<T, 4>(__VA_ARGS__); \
break; \
} \
case 5: { \
Func<T, 5>(__VA_ARGS__); \
break; \
} \
case 6: { \
Func<T, 6>(__VA_ARGS__); \
break; \
} \
case 7: { \
Func<T, 7>(__VA_ARGS__); \
break; \
} \
case 8: { \
Func<T, 8>(__VA_ARGS__); \
break; \
} \
default: { \
break; \
} \
} \
} while (false)
// Turns the runtime value `val` into a compile-time template argument and
// calls Func<T1, T2, val>(__VA_ARGS__) for val in 1..8.
// - Enforces val <= kCUDATensorMaxDims via CAFFE_ENFORCE_LE.
// - NOTE(review): values below 1 (0 or negative) reach the `default` case and
//   silently do nothing: Func is not called and no error is raised.
// - `val` is expanded twice (in the enforce and in the switch), so it should
//   be a side-effect-free expression.
#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2(val, Func, T1, T2, ...) \
do { \
CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \
switch (val) { \
case 1: { \
Func<T1, T2, 1>(__VA_ARGS__); \
break; \
} \
case 2: { \
Func<T1, T2, 2>(__VA_ARGS__); \
break; \
} \
case 3: { \
Func<T1, T2, 3>(__VA_ARGS__); \
break; \
} \
case 4: { \
Func<T1, T2, 4>(__VA_ARGS__); \
break; \
} \
case 5: { \
Func<T1, T2, 5>(__VA_ARGS__); \
break; \
} \
case 6: { \
Func<T1, T2, 6>(__VA_ARGS__); \
break; \
} \
case 7: { \
Func<T1, T2, 7>(__VA_ARGS__); \
break; \
} \
case 8: { \
Func<T1, T2, 8>(__VA_ARGS__); \
break; \
} \
default: { \
break; \
} \
} \
} while (false)
// Turns the runtime value `val` into a compile-time template argument and
// calls Func<T1, T2, T3, val>(__VA_ARGS__) for val in 1..8.
// - Enforces val <= kCUDATensorMaxDims via CAFFE_ENFORCE_LE.
// - NOTE(review): values below 1 (0 or negative) reach the `default` case and
//   silently do nothing: Func is not called and no error is raised.
// - `val` is expanded twice (in the enforce and in the switch), so it should
//   be a side-effect-free expression.
#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3(val, Func, T1, T2, T3, ...) \
do { \
CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \
switch (val) { \
case 1: { \
Func<T1, T2, T3, 1>(__VA_ARGS__); \
break; \
} \
case 2: { \
Func<T1, T2, T3, 2>(__VA_ARGS__); \
break; \
} \
case 3: { \
Func<T1, T2, T3, 3>(__VA_ARGS__); \
break; \
} \
case 4: { \
Func<T1, T2, T3, 4>(__VA_ARGS__); \
break; \
} \
case 5: { \
Func<T1, T2, T3, 5>(__VA_ARGS__); \
break; \
} \
case 6: { \
Func<T1, T2, T3, 6>(__VA_ARGS__); \
break; \
} \
case 7: { \
Func<T1, T2, T3, 7>(__VA_ARGS__); \
break; \
} \
case 8: { \
Func<T1, T2, T3, 8>(__VA_ARGS__); \
break; \
} \
default: { \
break; \
} \
} \
} while (false)
} // namespace caffe2
#endif // CAFFE2_CORE_COMMON_GPU_H_