This repository has been archived by the owner on Nov 15, 2024. It is now read-only.
forked from violethaze74/pumpkin-py
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path record_function_ops.cpp
169 lines (151 loc) · 6.23 KB
/
record_function_ops.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include <ATen/ThreadLocalState.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/record_function.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/library.h>
namespace caffe2 {
// Makes at::RecordFunction a CAFFE_KNOWN_TYPE. This is required for
// cpp_custom_type_hack to work: the legacy record_function ops below pack an
// at::RecordFunction into a Tensor with cpp_custom_type_hack::create and
// unpack it again with cpp_custom_type_hack::cast.
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
} // namespace caffe2
namespace torch {
namespace autograd {
namespace profiler {

// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
//
// If `rec` is not active this is a no-op. Otherwise the scope is started with
// `name`. The optional `args` string is forwarded as a single IValue input,
// but only when the callbacks asked for inputs (`rec.needsInputs()`) and
// `args` is set.
void record_function_enter(
    const std::string& name,
    const c10::optional<std::string>& args,
    at::RecordFunction& rec) {
  if (!rec.isActive()) {
    return;
  }
  if (rec.needsInputs() && args.has_value()) {
    // NOTE(review): the ArrayRef is a non-owning view of a temporary IValue
    // that is destroyed at the end of this statement. This relies on before()
    // copying any inputs it keeps -- confirm if before() changes.
    rec.before(
        name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}});
  } else {
    rec.before(name);
  }
}

// Legacy signature using cpp_custom_type_hack.
//
// Creates a USER_SCOPE RecordFunction, opens its scope, and returns it packed
// into a Tensor. Ownership of the RecordFunction moves into the returned
// Tensor, which the caller must keep alive until the scope is ended.
at::Tensor record_function_enter_legacy(
    const std::string& name,
    const c10::optional<std::string>& args) {
  auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
  record_function_enter(name, args, *rec);
  return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
}

// New signature using custom_class.
//
// Same as record_function_enter_legacy, except that the USER_SCOPE
// RecordFunction lives in the `record` member of a ref-counted
// PythonRecordFunction instead of inside a Tensor.
c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
    const std::string& name,
    const c10::optional<std::string>& args) {
  auto rec =
      c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
  record_function_enter(name, args, rec->record);
  return rec;
}

// Unpacks the RecordFunction stored in a handle produced by
// record_function_enter_legacy. The RecordFunction is owned by `handle`, so the
// returned reference must not outlive it.
at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
  return at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
}

// Ends the profiling scope created with record_function_enter.
void record_function_exit(at::RecordFunction& rec) {
  rec.end();
}

// Legacy signature using cpp_custom_type_hack.
//
// `handle` is the Tensor returned by record_function_enter_legacy. It owns the
// RecordFunction, which is why callers keep it alive until this call.
void record_function_exit_legacy(const at::Tensor& handle) {
  auto& rec = getRecordFunctionFromTensor(handle);
  record_function_exit(rec);
}

// New signature using custom_class.
void record_function_exit_new(
    const c10::intrusive_ptr<PythonRecordFunction>& record) {
  record_function_exit(record->record);
}

// Chains a callback onto `fut` and returns the resulting future.
//
// When `fut` completes, the callback ends the RecordFunction returned by
// `get_record()` and then forwards `fut`'s value, or rethrows its error, to the
// returned future.
//
// `get_record` is a nullary callable returning `at::RecordFunction&`. It is
// moved into the callback, so whatever it captures must keep the
// RecordFunction alive until `fut` completes.
template <typename Func>
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
    Func get_record,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  std::function<c10::IValue(c10::ivalue::Future&)> futureProfilingFunc =
      [get_record = std::move(get_record)](c10::ivalue::Future& completed) {
        auto& rec = get_record();
        rec.end();
        // Note: this future is returned to the user to ensure that a call to
        // wait() ensures that profiling callbacks have ran. To ensure that this
        // is transparent, we must make this future propagate the value of the
        // RPC future. Use value() here instead of constValue() to ensure we
        // propagate errors.
        return completed.value();
      };
  // The callback is wrapped with at::wrapPropagateTLSState, presumably so
  // that it sees the thread-local state captured here even if it runs on a
  // different thread. The returned future completes only after the
  // profiling callback has run.
  return fut->then(
      at::wrapPropagateTLSState(std::move(futureProfilingFunc)),
      fut->elementType());
}

// Legacy signature using cpp_custom_type_hack.
//
// `handle` is captured by value so the RecordFunction it owns survives until
// `fut` completes.
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
    const at::Tensor& handle,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  return _call_end_callbacks_on_fut(
      [handle]() -> at::RecordFunction& {
        TORCH_INTERNAL_ASSERT(
            handle.defined(),
            "Undefined RecordFunction handle. This can happen if the handle is "
            "not correctly persisted and is destroyed before the future is "
            "realized.");
        return getRecordFunctionFromTensor(handle);
      },
      fut);
}

// New signature using custom_class.
//
// `record` is captured by value, holding a reference count, so the
// RecordFunction survives until `fut` completes.
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
    const c10::intrusive_ptr<PythonRecordFunction>& record,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  return _call_end_callbacks_on_fut(
      [record]() -> at::RecordFunction& { return record->record; }, fut);
}

// Internal only, do not use directly, use Python's record_function()
TORCH_LIBRARY_FRAGMENT(profiler, m) {
  m.class_<PythonRecordFunction>("_RecordFunction");

  m.def(
      "_record_function_enter(str name, str? args=None) -> Tensor",
      &record_function_enter_legacy);
  m.def(
      "_record_function_enter_new(str name, str? args=None) -> "
      "__torch__.torch.classes.profiler._RecordFunction",
      &record_function_enter_new);
  m.def("_record_function_exit", &record_function_exit_legacy);
  m.def("_record_function_exit._RecordFunction", &record_function_exit_new);

  // Legacy overload: the RecordFunction handle is a Tensor.
  torch::jit::registerOperator(torch::jit::Operator(
      "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
      [](torch::jit::Stack& stack) {
        // Inputs come off the stack in reverse order: the future (y) first,
        // then the RecordFunction handle tensor (x).
        auto fut = torch::jit::pop(stack).toFuture();
        auto handle = torch::jit::pop(stack).toTensor();
        auto profiledFut = _call_end_callbacks_on_fut_legacy(handle, fut);
        // Return a future that completes when profiling callbacks have run.
        torch::jit::push(stack, std::move(profiledFut));
      },
      c10::AliasAnalysisKind::FROM_SCHEMA));

  // Custom-class overload: the RecordFunction lives in a PythonRecordFunction.
  torch::jit::registerOperator(torch::jit::Operator(
      "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
      "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
      [](torch::jit::Stack& stack) {
        // Inputs come off the stack in reverse order: the future (y) first,
        // then the PythonRecordFunction (x).
        auto fut = torch::jit::pop(stack).toFuture();
        auto record =
            torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
        auto profiledFut = _call_end_callbacks_on_fut_new(record, fut);
        // Return a future that completes when profiling callbacks have run.
        torch::jit::push(stack, std::move(profiledFut));
      },
      c10::AliasAnalysisKind::FROM_SCHEMA));
}

} // namespace profiler
} // namespace autograd
} // namespace torch