diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index e6916e08fb8a..12bc3ea1023f 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -73,10 +73,14 @@ def evaluate(self, method_name, args): res = pytree.tree_unflatten(res, out_spec) return res - def get_stablehlo_bytecode(self, method_name): + def get_stablehlo_bytecode(self, method_name=None): + if method_name is None: + method_name = self._default_method return self._name_to_stablehlo[method_name].bytecode - def get_stablehlo_text(self, method_name): + def get_stablehlo_text(self, method_name=None): + if method_name is None: + method_name = self._default_method return self._name_to_stablehlo[method_name].text def save(self, directory_path):