forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConvShared.h
89 lines (76 loc) · 2.83 KB
/
ConvShared.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#pragma once

#include <cstdint>
#include <iosfwd>
#include <string>
#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/native/ConvUtils.h>
namespace at { namespace native {

// ---------------------------------------------------------------------
//
// Helper classes
//
// ---------------------------------------------------------------------

// This POD struct is used to let us easily compute hashes of the
// parameters.
//
// Keep it a POD (trivial + standard-layout): no constructors, no default
// member initializers, no virtual functions, and only trivial member
// types. The static_assert right after the definition turns an accidental
// violation into a compile error instead of a silent behavior change.
//
// NB(review): padding bytes between members are part of the object
// representation. If hashing/equality operate on the raw bytes, the struct
// must be zero-filled before its fields are set -- confirm that
// setConvolutionParams does this.
struct ConvolutionParams
{
  c10::DeviceIndex device_id;
  cudnnDataType_t dataType;
  // 2 + max_dim slots: presumably batch and channel dims plus up to max_dim
  // spatial dims, with input_dim recording how many are in use (TODO confirm).
  int input_size[2 + max_dim];
  uint8_t input_dim;
  at::MemoryFormat memory_format;
  int weight_size[2 + max_dim];
  int padding[max_dim];
  int stride[max_dim];
  int dilation[max_dim];
  int64_t groups;
  bool deterministic;
  bool allow_tf32;
  // NB: transposed purposely omitted: transposed just swaps
  // forward and backward, so you can reuse the benchmark entry.
};

// Enforce the POD requirement documented above at compile time.
static_assert(
    std::is_trivial<ConvolutionParams>::value &&
        std::is_standard_layout<ConvolutionParams>::value,
    "ConvolutionParams must remain a POD (trivial and standard-layout)");

// Writes a human-readable description of `params` to `out`; defined out of
// line (presumably prints every field -- see the definition).
std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params);

// Fills in `*params` from the geometry of `input` and `weight` plus the
// remaining arguments, which mirror the struct's fields one-to-one.
//
// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need. (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing)
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool deterministic, bool allow_tf32);

// Returns a string built from `args`; presumably a snippet that reproduces
// the described convolution for bug reports (confirm in the definition).
std::string repro_from_args(const ConvolutionParams& args);

// ---------------------------------------------------------------------
//
// Raw functions
//
// ---------------------------------------------------------------------
//
// All functions below follow the `_out` convention: the first argument is
// presumably a preallocated tensor that receives the result. `benchmark`,
// `deterministic` and `allow_tf32` are presumably forwarded to cuDNN
// algorithm selection (TODO confirm in the implementation).

// Forward convolution of `input` with `weight`, written into `output`.
// Argument order: padding, stride, dilation.
void raw_cudnn_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32);

// Gradient w.r.t. the convolution input, written into `grad_input`.
// Argument order: padding, stride, dilation.
void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32);

// Gradient w.r.t. the convolution weight, written into `grad_weight`.
// Argument order: padding, stride, dilation.
void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32);

// Fused convolution + add + ReLU written into `output`; presumably computes
// relu(conv(input, weight) + alpha * z + bias) -- confirm in the definition.
//
// WARNING: unlike the raw_* functions above, this one takes `stride`
// BEFORE `padding`. Both are IntArrayRef, so swapping them at a call site
// compiles without any diagnostic. Check the order at every call site.
void raw_cudnn_convolution_add_relu_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

}} // namespace at::native