Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skip partitioner when there is no fallback #5756

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 42 additions & 36 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
collector.run(*xla_args)
fallback_ops = collector.get_fallback_ops()
if (ptxla_debug or dynamo_debug) and len(fallback_ops) > 0:
print('Dynamo fallback ops are' + str(fallback_ops) +
print('pt-xla-profiler: Dynamo fallback ops are' + str(fallback_ops) +
'. Please open a GitHub issue with the above op lowering requests.')

# This logic, needed for supporting in-place operations, is a duplicate of
Expand All @@ -502,38 +502,44 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):

torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):

def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op in [
"call_function", "call_module", "call_method"
] and (node not in fallback_ops or node.target == operator.getitem)

# partition the model
supported_ops = XlaOperatorSupport()
partitioner = CapabilityBasedPartitioner(
xla_model, supported_ops, allows_single_node_partition=True)
partitions = partitioner.propose_partitions()

# propose_partitions() does not guarantee topolgical order, so sort it manually
for partition in partitions:
partition.nodes = topo_sort(partition.nodes)

# fuse partitions and exectue to collect inputs
partitioned_graph = partitioner.fuse_partitions(partitions)
InputCollector(partitioned_graph).run(*xla_args)

# compile each submodule and replace it with a call
for node in partitioned_graph.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
fused_module = getattr(partitioned_graph, node.name)
partitioned_graph.delete_submodule(node.target)
with partitioned_graph.graph.inserting_after(node):
new_node = partitioned_graph.graph.call_function(
extract_internal(fused_module), node.args, None)
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)

partitioned_graph.recompile()

return partitioned_graph
if len(fallback_ops) == 0:
# skip the partitioner if there is no fallback
xla_model.xla_args = xla_args
return extract_internal(xla_model)
else:

class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):

def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op in [
"call_function", "call_module", "call_method"
] and (node not in fallback_ops or node.target == operator.getitem)

# partition the model
supported_ops = XlaOperatorSupport()
partitioner = CapabilityBasedPartitioner(
xla_model, supported_ops, allows_single_node_partition=True)
partitions = partitioner.propose_partitions()

# propose_partitions() does not guarantee topolgical order, so sort it manually
for partition in partitions:
partition.nodes = topo_sort(partition.nodes)

# fuse partitions and exectue to collect inputs
partitioned_graph = partitioner.fuse_partitions(partitions)
InputCollector(partitioned_graph).run(*xla_args)

# compile each submodule and replace it with a call
for node in partitioned_graph.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
fused_module = getattr(partitioned_graph, node.name)
partitioned_graph.delete_submodule(node.target)
with partitioned_graph.graph.inserting_after(node):
new_node = partitioned_graph.graph.call_function(
extract_internal(fused_module), node.args, None)
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)

partitioned_graph.recompile()

return partitioned_graph
Loading