diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py
index bd50d0d0598..1e89f0e6edb 100644
--- a/optimum/fx/parallelization/core.py
+++ b/optimum/fx/parallelization/core.py
@@ -117,7 +117,7 @@ class ParallelExecutionCtx:
     - example_inputs (`List[Any]`):
         A list of tensors which are used as example inputs for graphs captured by dynamo.
 
-    - parallel_layer_cache (`Dict[int, nn.Module]`):
+    - parallel_layer_cache (`Dict[str, nn.Module]`):
         Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts.
         Note that we will build the cache in the first compilation process, and for recompilations
         later on, we will directly replace the modules with their parallel counterparts in the cache,
@@ -135,7 +135,7 @@ class ParallelExecutionCtx:
     tp_group: dist.ProcessGroup
     current_device: torch.device
     example_inputs: List[Any] = field(default_factory=list)
-    parallel_layer_cache: Dict[int, nn.Module] = field(default_factory=dict)
+    parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
     weight_map: Dict[str, str] = field(default_factory=dict)
     compile_times: int = 0
 
diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py
index 6574f5e883d..d14abc6b6ad 100644
--- a/optimum/fx/parallelization/passes.py
+++ b/optimum/fx/parallelization/passes.py
@@ -388,7 +388,7 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None:
 
         field = node.target
         mod: nn.Linear = graph_module.get_submodule(node.target)
-        key, layer_cache = id(mod), ctx.parallel_layer_cache
+        key, layer_cache = node.target, ctx.parallel_layer_cache
         if key in layer_cache:
             new_mod = layer_cache[key]
         else:
@@ -422,7 +422,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
 
         field = node.target
         mod: nn.Embedding = graph_module.get_submodule(node.target)
-        key, layer_cache = id(mod), ctx.parallel_layer_cache
+        key, layer_cache = node.target, ctx.parallel_layer_cache
         if key in layer_cache:
            new_mod = layer_cache[key]
         else:
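
The diff changes the cache key for `parallel_layer_cache` from `id(mod)`, which is only unique while a given module object stays alive, to `node.target`, the submodule's qualified name, and updates the annotation from `Dict[int, nn.Module]` to `Dict[str, nn.Module]` accordingly. Below is a minimal sketch of that keyed-by-name caching pattern, not the actual optimum pass: the helper name `get_or_build_parallel_layer` is hypothetical and a plain `nn.Linear` stands in for the real parallel replacement.

```python
from typing import Dict

import torch.nn as nn


def get_or_build_parallel_layer(
    model: nn.Module,
    target: str,                  # qualified submodule name, e.g. "layers.0.mlp.up_proj"
    cache: Dict[str, nn.Module],  # plays the role of ctx.parallel_layer_cache
) -> nn.Module:
    """Return the cached replacement for `target`, building it on first use.

    Keying by the qualified name stays stable across recompilations, whereas
    id(mod) can change (or be reused) once the original module is recreated
    or garbage-collected.
    """
    if target in cache:
        return cache[target]

    mod = model.get_submodule(target)
    # Stand-in for the real column/row-parallel construction.
    new_mod = nn.Linear(mod.in_features, mod.out_features, bias=mod.bias is not None)
    cache[target] = new_mod
    return new_mod


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    cache: Dict[str, nn.Module] = {}
    first = get_or_build_parallel_layer(model, "0", cache)
    again = get_or_build_parallel_layer(model, "0", cache)
    assert first is again  # a later compilation reuses the cached replacement
```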