forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgen_autograd.py
142 lines (119 loc) · 4.36 KB
/
gen_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
To run this file by hand from the root of the PyTorch
repository, run:

    python -m tools.autograd.gen_autograd \
        aten/src/ATen/native/native_functions.yaml \
        aten/src/ATen/native/tags.yaml \
        $OUTPUT_DIR \
        tools/autograd

Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""
# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
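# gen_autograd_functions.py: generates subclasses of torch::autograd::Node
# gen_variable_type.py: generates VariableType.h which contains all tensor methods
# gen_python_functions.py: generates Python bindings to THPVariable
# gen_variable_factories.py: generates variable_factories.h
#
# gen_inplace_or_view_type.py and gen_trace_type.py are also invoked from
# gen_autograd() below.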
#
import argparse
import os
from typing import List
from torchgen.api import cpp
from torchgen.api.autograd import (
match_differentiability_info,
NativeFunctionWithDifferentiabilityInfo,
)
from torchgen.gen import parse_native_yaml
from torchgen.selective_build.selector import SelectiveBuilder
from . import gen_python_functions
from .gen_autograd_functions import (
gen_autograd_functions_lib,
gen_autograd_functions_python,
)
from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_trace_type import gen_trace_type
from .gen_variable_factories import gen_variable_factories
from .gen_variable_type import gen_variable_type
from .load_derivatives import load_derivatives
def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Run the C++ autograd code generators, writing their output to ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: output directory handed to every generator.
        autograd_dir: directory containing ``derivatives.yaml`` and the
            ``templates`` subdirectory.
        operator_selector: selective-build filter; only the native functions
            it selects for training are matched against the
            differentiability info.
        disable_autograd: when true, the VariableType, inplace-or-view and
            trace type generators are skipped; Functions and
            variable_factories are still generated.
    """
    # Parse and load derivatives.yaml
    derivatives_yaml = os.path.join(autograd_dir, "derivatives.yaml")
    differentiability_infos, used_dispatch_keys = load_derivatives(
        derivatives_yaml, native_functions_path, tags_path
    )
    template_path = os.path.join(autograd_dir, "templates")

    parsed_yaml = parse_native_yaml(native_functions_path, tags_path)
    native_funcs = parsed_yaml.native_functions

    # Keep only the functions selected for training, ordered by C++ name.
    selected_fns = [
        native_fn
        for native_fn in native_funcs
        if operator_selector.is_native_function_selected_for_training(native_fn)
    ]
    selected_fns.sort(key=lambda native_fn: cpp.name(native_fn.func))

    fns_with_diff_infos: List[
        NativeFunctionWithDifferentiabilityInfo
    ] = match_differentiability_info(selected_fns, differentiability_infos)

    if not disable_autograd:
        # Both generators below take the same leading arguments.
        common_args = (
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            template_path,
        )

        # Generate VariableType.h/cpp
        gen_variable_type(*common_args, used_dispatch_keys)

        gen_inplace_or_view_type(*common_args)

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)

    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, template_path)
def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Run the Python-facing autograd generators, writing their output to ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: output directory handed to both generators.
        autograd_dir: directory containing ``derivatives.yaml``,
            ``deprecated.yaml`` and the ``templates`` subdirectory.
    """
    derivatives_yaml = os.path.join(autograd_dir, "derivatives.yaml")
    # The second element returned by load_derivatives is not needed here.
    differentiability_infos, _unused_dispatch_keys = load_derivatives(
        derivatives_yaml, native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")
    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")

    # Generate Functions.h/cpp
    gen_autograd_functions_python(out, differentiability_infos, template_path)

    # Generate Python bindings
    gen_python_functions.gen(
        out,
        native_functions_path,
        tags_path,
        deprecated_path,
        template_path,
    )
def main() -> None:
    """Command-line entry point: parse four positional paths and run gen_autograd.

    Positional arguments, in order: native_functions.yaml, tags.yaml, the
    output directory, and the autograd directory.  ``gen_autograd`` is called
    with ``SelectiveBuilder.get_nop_selector()`` as the operator selector and
    ``disable_autograd`` left at its default.
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # Use a distinct metavar so usage reads "NATIVE TAGS OUT AUTOGRAD" rather
    # than showing "NATIVE" twice.
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )
# Run the generator only when this module is executed as a script
# (e.g. ``python -m tools.autograd.gen_autograd``), not when it is imported.
if __name__ == "__main__":
    main()