diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py
index bdfc56cd214..d4d563d5b6f 100644
--- a/optimum/fx/parallelization/passes.py
+++ b/optimum/fx/parallelization/passes.py
@@ -372,7 +372,7 @@ class ParallelLayerReplacePass(PassBase):
     """
 
     @staticmethod
-    def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None:
+    def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None:
         graph_module = node.graph.owning_module
         axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
         if axis is None:
@@ -396,17 +396,17 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None
                 gather_output = ParallelLayerAnnotatePass.get_stored_field_info(
                     node, field="gather_output", must_have=True
                 )
-                new_mod = ColumnParallelLinear(ctx, mod, gather_output, config.weight_init_fn)
+                new_mod = ColumnParallelLinear(ctx, mod, gather_output)
             else:
                 input_is_parallel = ParallelLayerAnnotatePass.get_stored_field_info(
                     node, field="input_is_parallel", must_have=True
                 )
-                new_mod = RowParallelLinear(ctx, mod, input_is_parallel, config.weight_init_fn)
+                new_mod = RowParallelLinear(ctx, mod, input_is_parallel)
             layer_cache[key] = new_mod
         setattr(parent_mod, field, new_mod)
 
     @staticmethod
-    def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None:
+    def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
         graph_module = node.graph.owning_module
         axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
         if axis is None:
@@ -426,7 +426,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> N
         if key in layer_cache:
             new_mod = layer_cache[key]
         else:
-            new_mod = VocabParallelEmbedding(ctx, mod, config.weight_init_fn)
+            new_mod = VocabParallelEmbedding(ctx, mod)
             layer_cache[key] = new_mod
         setattr(parent_mod, field, new_mod)
 
@@ -468,9 +468,9 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int):
     def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
         for node in graph_module.graph.nodes:
             if is_linear(node):
-                self.handle_linear(node, ctx, config)
+                self.handle_linear(node, ctx)
             elif is_embedding(node):
-                self.handle_embedding(node, ctx, config)
+                self.handle_embedding(node, ctx)
             # correct the attention head num in parallel setting
             elif is_shape_consumer(node):
                 self.handle_hard_coded_axis_param(node, ctx)
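
For reference, a minimal standalone sketch of the node-handler pattern this pass uses after the change: handlers take only `(node, ctx)`, swap the submodule behind a `call_module` node, and cache replacements by target name. `ExecutionCtx`, `handle_linear`, and `Tiny` below are simplified stand-ins (a fresh `nn.Linear` plays the role of `ColumnParallelLinear`/`RowParallelLinear`), not the optimum implementation.

```python
from dataclasses import dataclass, field
from typing import Dict

import torch
import torch.nn as nn
from torch.fx import GraphModule, Node, symbolic_trace


@dataclass
class ExecutionCtx:
    # stand-in for ParallelExecutionCtx: only carries the layer cache here
    parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)


def handle_linear(node: Node, ctx: ExecutionCtx) -> None:
    # replace the nn.Linear behind a call_module node, caching by target name
    graph_module: GraphModule = node.graph.owning_module
    if node.op != "call_module":
        return
    mod = graph_module.get_submodule(node.target)
    if not isinstance(mod, nn.Linear):
        return
    key, layer_cache = node.target, ctx.parallel_layer_cache
    if key in layer_cache:
        new_mod = layer_cache[key]
    else:
        # a fresh nn.Linear of the same shape stands in for the parallel layer
        new_mod = nn.Linear(mod.in_features, mod.out_features, bias=mod.bias is not None)
        layer_cache[key] = new_mod
    # resolve the parent module and swap the attribute on it
    prefix, _, field_name = node.target.rpartition(".")
    parent_mod = graph_module.get_submodule(prefix) if prefix else graph_module
    setattr(parent_mod, field_name, new_mod)


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


gm = symbolic_trace(Tiny())
ctx = ExecutionCtx()
for node in gm.graph.nodes:
    handle_linear(node, ctx)  # handlers only need (node, ctx)
gm.recompile()
print(gm(torch.randn(2, 4)).shape)
```

The swap itself goes through `get_submodule`/`setattr` on the parent module, so the graph's `call_module` nodes keep referring to the same target name and pick up the replacement on the next forward pass.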