From 8cb3371091435d25a2db322a8517be3631cad5e0 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Mon, 1 Apr 2024 12:19:39 -0700 Subject: [PATCH] Add a `DynamoSyncInputExecuteTime` counter (#6813) --- test/dynamo/test_dynamo.py | 2 ++ test/test_metrics.py | 12 ++++++++++++ torch_xla/core/dynamo_bridge.py | 1 + torch_xla/csrc/init_python_bindings.cpp | 6 ++++++ torch_xla/debug/metrics.py | 2 +- 5 files changed, 22 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 9f93f9b803b..d8bb01f13fa 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -102,6 +102,8 @@ def test_simple_model(self): torch_xla._XLAC._get_xla_tensor_debug_info(xla_xy)) self.assertNotIn('XLAData: None', torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3)) + # Dynamo has to sync the input since they are intermedate IR(xla_xy and xla_y3) + self.assertEqual(met.counter_value('DynamoSyncInputExecuteTime'), 1) # Tests that the dynamo bridge automatically moves tensors to XLA device, # then back to the original device. diff --git a/test/test_metrics.py b/test/test_metrics.py index 734391f91bc..7e659980e8f 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -192,6 +192,18 @@ def test_execute_time_metric(self): # of `ExecuteComputation`, but the actual async time. self.assertGreater(execute_time_ns, .5 * wall_time_ns) + def test_pybind_increment_counter(self): + met.clear_all() + xla_device = xm.xla_device() + t1 = torch.tensor(2077, device=xla_device) + self.assertEqual(met.counter_value('CreateXlaTensor'), 1) + torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3) + self.assertEqual(met.counter_value('CreateXlaTensor'), 4) + + # try increment a counter that does not exist + torch_xla._XLAC._xla_increment_counter('FakeCounter', 2) + self.assertEqual(met.counter_value('FakeCounter'), 2) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index 5a20aa2389d..379416ec73f 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -464,6 +464,7 @@ def optimized_mod(*args: tuple): [a for a in args if isinstance(a, torch.Tensor)])) if x ] if len(input_tensors_to_sync) > 0: + torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1) torch_xla._XLAC._xla_sync_multi( input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d4c5909ee85..9870fdb72ac 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1769,6 +1769,12 @@ void InitXlaModuleBindings(py::module m) { return xla_data != nullptr ? py::cast(xla_data->Value()) : py::none(); }); + // TORCH_LAZY_COUNTER + m.def("_xla_increment_counter", + [](const std::string& name, uint64_t inc_val) { + torch::lazy::Counter* counter = new ::torch::lazy::Counter(name); + counter->AddValue(inc_val); + }); m.def("_xla_metric_names", []() { auto metric_names = torch::lazy::GetMetricNames(); auto xla_metric_names = runtime::metrics::GetMetricNames(); diff --git a/torch_xla/debug/metrics.py b/torch_xla/debug/metrics.py index 108b0d61bf5..363c52a80da 100644 --- a/torch_xla/debug/metrics.py +++ b/torch_xla/debug/metrics.py @@ -72,7 +72,7 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None): metric_names (list): The list of metric names whose data needs to be printed. """ if not counter_names: - counter_names = ['CachedCompile', 'MarkStep'] + counter_names = ['CachedCompile', 'MarkStep', 'DynamoSyncInputExecuteTime'] if not metric_names: metric_names = [ 'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime',