diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index d9c13c6ec69e..ec303ffccfe1 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -484,7 +484,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   collector.run(*xla_args)
   fallback_ops = collector.get_fallback_ops()
   if (ptxla_debug or dynamo_debug) and len(fallback_ops) > 0:
-    print('Dynamo fallback ops are' + str(fallback_ops) +
+    print('pt-xla-profiler: Dynamo fallback ops are ' + str(fallback_ops) +
          '. Please open a GitHub issue with the above op lowering requests.')
 
   # This logic, needed for supporting in-place operations, is a duplicate of
@@ -502,38 +502,44 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
 
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
-  class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
-
-    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-      return node.op in [
-          "call_function", "call_module", "call_method"
-      ] and (node not in fallback_ops or node.target == operator.getitem)
-
-  # partition the model
-  supported_ops = XlaOperatorSupport()
-  partitioner = CapabilityBasedPartitioner(
-      xla_model, supported_ops, allows_single_node_partition=True)
-  partitions = partitioner.propose_partitions()
-
-  # propose_partitions() does not guarantee topolgical order, so sort it manually
-  for partition in partitions:
-    partition.nodes = topo_sort(partition.nodes)
-
-  # fuse partitions and exectue to collect inputs
-  partitioned_graph = partitioner.fuse_partitions(partitions)
-  InputCollector(partitioned_graph).run(*xla_args)
-
-  # compile each submodule and replace it with a call
-  for node in partitioned_graph.graph.nodes:
-    if node.op == "call_module" and "fused_" in node.name:
-      fused_module = getattr(partitioned_graph, node.name)
-      partitioned_graph.delete_submodule(node.target)
-      with partitioned_graph.graph.inserting_after(node):
-        new_node = partitioned_graph.graph.call_function(
-            extract_internal(fused_module), node.args, None)
-        node.replace_all_uses_with(new_node)
-        partitioned_graph.graph.erase_node(node)
-
-  partitioned_graph.recompile()
-
-  return partitioned_graph
+  if len(fallback_ops) == 0:
+    # skip the partitioner if there are no fallback ops
+    xla_model.xla_args = xla_args
+    return extract_internal(xla_model)
+  else:
+
+    class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
+
+      def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        return node.op in [
+            "call_function", "call_module", "call_method"
+        ] and (node not in fallback_ops or node.target == operator.getitem)
+
+    # partition the model
+    supported_ops = XlaOperatorSupport()
+    partitioner = CapabilityBasedPartitioner(
+        xla_model, supported_ops, allows_single_node_partition=True)
+    partitions = partitioner.propose_partitions()
+
+    # propose_partitions() does not guarantee topological order, so sort it manually
+    for partition in partitions:
+      partition.nodes = topo_sort(partition.nodes)
+
+    # fuse partitions and execute to collect inputs
+    partitioned_graph = partitioner.fuse_partitions(partitions)
+    InputCollector(partitioned_graph).run(*xla_args)
+
+    # compile each submodule and replace it with a call
+    for node in partitioned_graph.graph.nodes:
+      if node.op == "call_module" and "fused_" in node.name:
+        fused_module = getattr(partitioned_graph, node.name)
+        partitioned_graph.delete_submodule(node.target)
+        with partitioned_graph.graph.inserting_after(node):
+          new_node = partitioned_graph.graph.call_function(
+              extract_internal(fused_module), node.args, None)
+          node.replace_all_uses_with(new_node)
+          partitioned_graph.graph.erase_node(node)
+
+    partitioned_graph.recompile()
+
+    return partitioned_graph
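For context on the branch that still goes through the partitioner, here is a minimal, self-contained sketch of the `CapabilityBasedPartitioner` + `OperatorSupport` flow on a plain FX graph (no XLA required). `ToyOperatorSupport`, `ToyModel`, and the choice of `torch.relu` as the "unsupported" op are illustrative assumptions, not code from this PR.

```python
# Standalone sketch, NOT part of this diff: a toy demonstration of the
# partitioner flow that the fallback branch above keeps.
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport


class ToyOperatorSupport(OperatorSupport):
  # Pretend torch.relu is a "fallback" op with no lowering, so the partitioner
  # must split the graph around it (analogous to excluding fallback_ops above).
  def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
    return node.op == "call_function" and node.target is not torch.relu


class ToyModel(torch.nn.Module):
  def forward(self, x):
    y = torch.sin(x)
    y = torch.relu(y)  # the unsupported op
    return torch.cos(y)


gm = torch.fx.symbolic_trace(ToyModel())
partitioner = CapabilityBasedPartitioner(
    gm, ToyOperatorSupport(), allows_single_node_partition=True)
partitions = partitioner.propose_partitions()
fused = partitioner.fuse_partitions(partitions)

# The supported regions are now `fused_*` call_module nodes, which is the
# shape the loop in the diff rewrites via extract_internal().
print(fused.graph)

# The fused module stays functionally equivalent to the original one.
x = torch.randn(8)
torch.testing.assert_close(fused(x), ToyModel()(x))
```

The new fast path (`len(fallback_ops) == 0`) simply bypasses all of this: when the whole graph is lowerable, proposing, sorting, and fusing partitions only adds overhead, so the model is handed to `extract_internal` directly.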