Add a DynamoSyncInputExecuteTime counter #6813

Merged: 4 commits, Apr 1, 2024

This PR adds a DynamoSyncInputExecuteTime counter, incremented whenever the Dynamo bridge has to sync input tensors before executing a compiled graph. It also exposes a generic _xla_increment_counter pybind and adds the new counter to the default short metrics report.

Changes from all commits
test/dynamo/test_dynamo.py (2 additions, 0 deletions)

@@ -132,6 +132,8 @@ def test_simple_model(self):
                      torch_xla._XLAC._get_xla_tensor_debug_info(xla_xy))
     self.assertNotIn('XLAData: None',
                      torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))
+    # Dynamo has to sync the inputs since they are intermediate IR (xla_xy and xla_y3)
+    self.assertEqual(met.counter_value('DynamoSyncInputExecuteTime'), 1)
 
     # Tests that the dynamo bridge automatically moves tensors to XLA device,
     # then back to the original device.
test/test_metrics.py (12 additions, 0 deletions)

@@ -193,6 +193,18 @@ def test_execute_time_metric(self):
     # of `ExecuteComputation`, but the actual async time.
     self.assertGreater(execute_time_ns, .5 * wall_time_ns)
 
+  def test_pybind_increment_counter(self):
+    met.clear_all()
+    xla_device = xm.xla_device()
+    t1 = torch.tensor(2077, device=xla_device)
+    self.assertEqual(met.counter_value('CreateXlaTensor'), 1)
+    torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3)
+    self.assertEqual(met.counter_value('CreateXlaTensor'), 4)
+
+    # Try incrementing a counter that does not exist yet; it gets created.
+    torch_xla._XLAC._xla_increment_counter('FakeCounter', 2)
+    self.assertEqual(met.counter_value('FakeCounter'), 2)
+
 
 if __name__ == '__main__':
   test = unittest.main()
torch_xla/core/dynamo_bridge.py (1 addition, 0 deletions)

@@ -464,6 +464,7 @@ def optimized_mod(*args: tuple):
             [a for a in args if isinstance(a, torch.Tensor)])) if x
     ]
     if len(input_tensors_to_sync) > 0:
+      torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
       torch_xla._XLAC._xla_sync_multi(
           input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
 
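For context, here is a minimal sketch of how the new counter surfaces from user code, assuming a torch_xla build that includes this change and the openxla Dynamo backend; the function f and the tensor names are illustrative only:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()


@torch.compile(backend='openxla')
def f(x):
  return x * 2


a = torch.ones(4, device=device)
b = a + 1  # `b` is pending lazy IR, not yet materialized on the device
f(b)  # the bridge syncs `b` before execution, bumping the counter
print(met.counter_value('DynamoSyncInputExecuteTime'))  # expect 1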
torch_xla/csrc/init_python_bindings.cpp (6 additions, 0 deletions)

@@ -1745,6 +1745,12 @@ void InitXlaModuleBindings(py::module m) {
     return xla_data != nullptr ? py::cast<int64_t>(xla_data->Value())
                                : py::none();
   });
+  // Increments a TORCH_LAZY_COUNTER by inc_val, creating it on first use.
+  m.def("_xla_increment_counter",
+        [](const std::string& name, uint64_t inc_val) {
+          torch::lazy::Counter* counter = new ::torch::lazy::Counter(name);
+          counter->AddValue(inc_val);
+        });
   m.def("_xla_metric_names", []() {
     auto metric_names = torch::lazy::GetMetricNames();
     auto xla_metric_names = runtime::metrics::GetMetricNames();
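Note that torch::lazy counters are kept in a global registry keyed by name, so constructing a Counter for an existing name attaches to the same underlying data and increments simply accumulate. A small sketch of the observable Python-side behavior (the counter name here is made up):

import torch_xla
import torch_xla.debug.metrics as met

# First call creates 'MyCustomCounter'; later calls add to the same counter.
torch_xla._XLAC._xla_increment_counter('MyCustomCounter', 1)
torch_xla._XLAC._xla_increment_counter('MyCustomCounter', 4)
assert met.counter_value('MyCustomCounter') == 5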
torch_xla/debug/metrics.py (1 addition, 1 deletion)

@@ -72,7 +72,7 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None):
     metric_names (list): The list of metric names whose data needs to be printed.
   """
   if not counter_names:
-    counter_names = ['CachedCompile', 'MarkStep']
+    counter_names = ['CachedCompile', 'MarkStep', 'DynamoSyncInputExecuteTime']
   if not metric_names:
     metric_names = [
         'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime',
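With the default counter list updated, the new counter appears in the short report without any extra arguments. A usage sketch:

import torch_xla.debug.metrics as met

# The default counter list now includes DynamoSyncInputExecuteTime.
print(met.short_metrics_report())

# Or select it explicitly:
print(met.short_metrics_report(counter_names=['DynamoSyncInputExecuteTime']))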