forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgen_autograd.py
318 lines (265 loc) · 12.1 KB
/
gen_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
To run this file by hand from the root of the PyTorch
repository, run:
python -m tools.autograd.gen_autograd \
build/aten/src/ATen/Declarations.yaml \
$OUTPUT_DIR \
tools/autograd
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""
# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
# gen_autograd_functions.py: generates subclasses of torch::autograd::Node
# gen_variable_type.py: generates VariableType.h which contains all tensor methods
# gen_python_functions.py: generates Python bindings to THPVariable
#
import argparse
import copy
import os
import yaml
import re
from collections import defaultdict
from .utils import YamlLoader, split_name_params, op_name_without_overload
# See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
# you **MUST** also update the public list of view ops accordingly in
# docs/source/tensor_view.rst. Note not all ATen functions are exposed to public,
# e.g alias & sparse_coo_tensor_with_dims_and_tensors.
#
# A map: function name => name of the argument that all outputs are view of

# View ops whose outputs carry different metadata than the base tensor
# (complex <-> real reinterpretation); they are still views of `self`
# and are merged into VIEW_FUNCTIONS below.
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_real', 'view_as_complex']

VIEW_FUNCTIONS = {
    'numpy_T': 'self',
    'alias': 'self',
    'as_strided': 'self',
    'diagonal': 'self',
    'expand': 'self',
    'permute': 'self',
    'select': 'self',
    'slice': 'self',
    'split': 'self',
    'split_with_sizes': 'self',
    'squeeze': 'self',
    't': 'self',
    'transpose': 'self',
    'unfold': 'self',
    'unsqueeze': 'self',
    'flatten': 'self',
    'view': 'self',
    'unbind': 'self',
    '_indices': 'self',
    '_values': 'self',
    'indices': 'self',
    'values': 'self',
    # sparse_coo ctor output should really be views of both indices and values,
    # but we only supports making as view of a single variable, and indices is
    # discrete anyways.
    # FIXME: clone indices on construction.
    'sparse_coo_tensor_with_dims_and_tensors': 'values',
}

# Fold the metadata-changing view ops into the main map: they view `self` too.
for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
    VIEW_FUNCTIONS[key] = 'self'

# Functions for which we use CreationMeta::MULTI_OUTPUT_SAFE. I.e., the ones for
# which inplace modification of outputs is being gradually deprecated.
MULTI_OUTPUT_SAFE_FUNCTIONS = {
    'split',
    'split_with_sizes',
}

# note: some VIEW_FUNCTIONS are just compositions of the view functions above
# this list contains both the root view functions and any that are purely composed
# of viewing functions, and is used by the JIT to determine when an operator
# may return a view of its inputs; however they may sometimes return a copy.
# (e.g. `contiguous`)
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({
    'chunk', 'detach', 'contiguous', 'reshape', 'reshape_as',
    'expand_as', 'view_as', 'real', 'imag', 'narrow', 'movedim',
})
def format_return_type(returns):
    """Render the C++ return type for a list of return descriptors.

    No returns -> 'void'; a single return -> its bare type; multiple
    returns -> a std::tuple of all the types, comma-joined.
    """
    if not returns:
        return 'void'
    if len(returns) == 1:
        return returns[0]['type']
    return 'std::tuple<{}>'.format(','.join(ret['type'] for ret in returns))
def get_simple_type(arg):
    """Reduce a declaration's C++ type to a simplified schema-like type.

    Strips reference/const decorations and the Generator pointer, and
    rewrites ``c10::optional<T>`` as ``T?``.
    """
    stripped = (arg['type']
                .replace(' &', '')
                .replace('const ', '')
                .replace('Generator *', 'Generator'))
    optional = re.match(r'c10::optional<(.+)>', stripped)
    return '{}?'.format(optional.group(1)) if optional else stripped
def has_tensoroptions_argument(declaration):
    """Return True iff any argument of the declaration is a TensorOptions."""
    return any(argument['dynamic_type'] == 'TensorOptions'
               for argument in declaration['arguments'])
def process_schema_order_arg(schema_order_arg):
    """Map a TensorOptions-related argument name to its C++ accessor expression.

    Arguments that are not part of the scattered TensorOptions fields are
    returned unchanged.
    """
    replacements = {
        'dtype': 'optTypeMetaToScalarType(options.dtype_opt())',
        'layout': 'options.layout_opt()',
        'device': 'options.device_opt()',
        'pin_memory': 'options.pinned_memory_opt()',
        'memory_format': 'c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)',
    }
    return replacements.get(schema_order_arg, schema_order_arg)
def load_aten_declarations(path):
    """Load Declarations.yaml from *path* and enrich each declaration in place.

    Skips deprecated declarations. Each surviving declaration dict gains
    derived keys (simple types, formal/actual argument lists, wrapper and
    overload names, return type) that the downstream generators consume.
    """
    with open(path, 'r') as f:
        declarations = yaml.load(f, Loader=YamlLoader)

    # enrich declarations with additional information
    selected_declarations = []
    for declaration in declarations:
        if declaration.get('deprecated'):
            continue

        # Simplified (schema-style) types for every argument and return.
        for arg in declaration['arguments']:
            arg['simple_type'] = get_simple_type(arg)
        for ret in declaration['returns']:
            ret['simple_type'] = get_simple_type(ret)

        # "type name" formal parameter strings, in both C++ and schema order.
        declaration['formals'] = [arg['type'] + ' ' + arg['name']
                                  for arg in declaration['arguments']]
        declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name']
                                               for arg in declaration['schema_order_arguments']]
        # Bare argument names (the actuals used at call sites).
        declaration['args'] = [arg['name'] for arg in declaration['arguments']]
        declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']]
        # When the C++ signature gathers dtype/layout/device/pin_memory into a
        # TensorOptions, the schema-order call must scatter them back out.
        if has_tensoroptions_argument(declaration):
            declaration['schema_order_args'] = [process_schema_order_arg(arg) for arg in declaration['schema_order_args']]
        declaration['api_name'] = declaration['name']
        # NB: keep this in sync with common_with_cwrap.py
        if declaration.get('overload_name'):
            declaration['type_wrapper_name'] = "{}_{}".format(
                declaration['name'], declaration['overload_name'])
        else:
            declaration['type_wrapper_name'] = declaration['name']
        # e.g. 'aten::add.Tensor' -> qualified; 'add.Tensor' -> unqualified.
        declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
        declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
        declaration['return_type'] = format_return_type(declaration['returns'])

        declaration['base_name'] = declaration['name']
        selected_declarations.append(declaration)

    return selected_declarations
def load_deprecated_signatures(aten_decls, deprecated_path):
    """Load deprecated.yaml and pair each entry with matching ATen declarations.

    Returns a list of deep-copied declarations marked ``deprecated=True``,
    with their argument lists rewritten to the deprecated overload's
    ordering and parameter names.
    """
    def group_declarations_by_signature():
        # Index the live ATen declarations by "base_name(simple_type, ...)"
        # so a deprecated entry can find every declaration it aliases.
        d = defaultdict(list)
        for declaration in aten_decls:
            name = declaration['name']
            # Strip the trailing underscore of in-place variants (e.g. add_).
            base_name = name[:-1] if declaration['inplace'] else name
            simple_types = [arg['simple_type'] for arg in declaration['arguments']]
            signature = '{}({})'.format(base_name, ', '.join(simple_types))
            d[signature].append(declaration)
        return d

    with open(deprecated_path, 'r') as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)
    declarations = []
    declarations_by_signature = group_declarations_by_signature()

    def get_signature(name, params, call_args):
        # create a mapping of parameter name to parameter type
        # (each param is a "Type name" string; '*' marks kwarg-only start)
        types = dict([param.split(' ')[::-1] for param in params if param != '*'])
        # if the name in the call is not in the parameter list, assume it's
        # a literal Scalar
        rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
        return '{}({})'.format(name, ', '.join(rearranged_types))

    for deprecated in deprecated_defs:
        aten_name, call_args = split_name_params(deprecated['aten'])
        name, params = split_name_params(deprecated['name'])
        signature = get_signature(aten_name, params, call_args)

        for declaration in declarations_by_signature[signature]:
            # Deep copy: the enriched base declaration must stay untouched.
            declaration = copy.deepcopy(declaration)
            declaration['deprecated'] = True
            declaration['call_args'] = call_args

            call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)}
            original_args = declaration['arguments']

            # Create an arguments list that uses the types from the original
            # ATen declaration, but the ordering and parameter names from
            # the deprecated overload. Any default parameter values from the
            # original ATen declaration are ignored.
            arguments = []
            kwarg_only = False
            for param in params:
                if param == '*':
                    # Everything after the bare '*' is keyword-only.
                    kwarg_only = True
                    continue
                _, param_name = param.split(' ')
                original = original_args[call_arg_to_idx[param_name]]
                arguments.append({
                    'name': param_name,
                    'kwarg_only': kwarg_only,
                    'type': original['type'],
                    'simple_type': original['simple_type'],
                    'dynamic_type': original['dynamic_type'],
                    'output': original.get('output', False),
                })
            declaration['arguments'] = arguments
            declarations.append(declaration)
    return declarations
def gen_autograd(aten_path, out, autograd_dir, disable_autograd=False, selected_op_list=None):
    """Generate the C++ autograd sources into *out*.

    Args:
        aten_path: path to Declarations.yaml.
        out: output directory for generated files.
        autograd_dir: directory containing derivatives.yaml and templates/.
        disable_autograd: when True, skip generating VariableType.
        selected_op_list: optional collection of op names; when given, only
            matching declarations are used for VariableType (factories and
            derivatives still see the full list).
    """
    full_aten_decls = load_aten_declarations(aten_path)

    def filter_decls(aten_decls, selected_op_list):
        # None means no selective build: keep every declaration.
        if selected_op_list is None:
            return aten_decls
        return [decl for decl in aten_decls if op_name_without_overload(decl) in selected_op_list]

    aten_decls = filter_decls(full_aten_decls, selected_op_list)

    # Parse and load derivatives.yaml
    from .load_derivatives import load_derivatives
    autograd_functions = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), full_aten_decls)

    template_path = os.path.join(autograd_dir, 'templates')

    # Generate VariableType.h/cpp
    if not disable_autograd:
        from .gen_variable_type import gen_variable_type
        gen_variable_type(out, aten_decls, template_path)

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_lib
    gen_autograd_functions_lib(
        out, autograd_functions, template_path)

    # Generate variable_factories.h
    from .gen_variable_factories import gen_variable_factories
    # Some non-selectable ops (e.g. prim ops) need factory methods so we pass in `full_aten_decls` here.
    gen_variable_factories(out, full_aten_decls, template_path)
def gen_autograd_python(aten_path, out, autograd_dir):
    """Generate the Python-binding side of autograd into *out*.

    Loads declarations and derivatives, then emits python_functions and the
    THPVariable/torch/nn/fft/linalg binding files. Deprecated signatures are
    appended only for variable methods and torch functions.
    """
    # TODO Deduplicate these four variable assignments
    aten_decls = load_aten_declarations(aten_path)

    # Parse and load derivatives.yaml
    from .load_derivatives import load_derivatives
    autograd_functions = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)

    template_path = os.path.join(autograd_dir, 'templates')

    # Load deprecated signatures
    deprecated = load_deprecated_signatures(
        aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_python
    gen_autograd_functions_python(
        out, autograd_functions, template_path)

    # Generate Python bindings
    from . import gen_python_functions
    gen_python_functions.gen_py_variable_methods(
        out, aten_decls + deprecated, template_path)
    gen_python_functions.gen_py_torch_functions(
        out, aten_decls + deprecated, template_path)
    gen_python_functions.gen_py_nn_functions(
        out, aten_decls, template_path)
    gen_python_functions.gen_py_fft_functions(
        out, aten_decls, template_path)
    gen_python_functions.gen_py_linalg_functions(
        out, aten_decls, template_path)
def main():
    """Command-line entry point: parse the three path arguments and run codegen."""
    parser = argparse.ArgumentParser(
        description='Generate autograd C++ files script')
    positional_args = (
        ('declarations', 'DECL', 'path to Declarations.yaml'),
        ('out', 'OUT', 'path to output directory'),
        ('autograd', 'AUTOGRAD', 'path to autograd directory'),
    )
    for arg_name, metavar, help_text in positional_args:
        parser.add_argument(arg_name, metavar=metavar, help=help_text)
    args = parser.parse_args()
    gen_autograd(args.declarations, args.out, args.autograd)


if __name__ == '__main__':
    main()