Skip to content

Commit

Permalink
Add a DynamoSyncInputExecuteTime counter (pytorch#6813)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored and yitongh committed Dec 11, 2024
1 parent d45968c commit 8cb3371
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 1 deletion.
2 changes: 2 additions & 0 deletions test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def test_simple_model(self):
torch_xla._XLAC._get_xla_tensor_debug_info(xla_xy))
self.assertNotIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))
# Dynamo has to sync the input since they are intermedate IR(xla_xy and xla_y3)
self.assertEqual(met.counter_value('DynamoSyncInputExecuteTime'), 1)

# Tests that the dynamo bridge automatically moves tensors to XLA device,
# then back to the original device.
Expand Down
12 changes: 12 additions & 0 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ def test_execute_time_metric(self):
# of `ExecuteComputation`, but the actual async time.
self.assertGreater(execute_time_ns, .5 * wall_time_ns)

def test_pybind_increment_counter(self):
met.clear_all()
xla_device = xm.xla_device()
t1 = torch.tensor(2077, device=xla_device)
self.assertEqual(met.counter_value('CreateXlaTensor'), 1)
torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3)
self.assertEqual(met.counter_value('CreateXlaTensor'), 4)

# try increment a counter that does not exist
torch_xla._XLAC._xla_increment_counter('FakeCounter', 2)
self.assertEqual(met.counter_value('FakeCounter'), 2)


if __name__ == '__main__':
test = unittest.main()
Expand Down
1 change: 1 addition & 0 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def optimized_mod(*args: tuple):
[a for a in args if isinstance(a, torch.Tensor)])) if x
]
if len(input_tensors_to_sync) > 0:
torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
torch_xla._XLAC._xla_sync_multi(
input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)

Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,12 @@ void InitXlaModuleBindings(py::module m) {
return xla_data != nullptr ? py::cast<int64_t>(xla_data->Value())
: py::none();
});
// TORCH_LAZY_COUNTER
m.def("_xla_increment_counter",
[](const std::string& name, uint64_t inc_val) {
torch::lazy::Counter* counter = new ::torch::lazy::Counter(name);
counter->AddValue(inc_val);
});
m.def("_xla_metric_names", []() {
auto metric_names = torch::lazy::GetMetricNames();
auto xla_metric_names = runtime::metrics::GetMetricNames();
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/debug/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None):
metric_names (list): The list of metric names whose data needs to be printed.
"""
if not counter_names:
counter_names = ['CachedCompile', 'MarkStep']
counter_names = ['CachedCompile', 'MarkStep', 'DynamoSyncInputExecuteTime']
if not metric_names:
metric_names = [
'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime',
Expand Down

0 comments on commit 8cb3371

Please sign in to comment.