Commit

fix
zhenglongjiepheonix committed Jul 12, 2024
1 parent 473388b commit 0512b23
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions optimum/fx/parallelization/passes.py
@@ -372,7 +372,7 @@ class ParallelLayerReplacePass(PassBase):
     """
 
     @staticmethod
-    def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None:
+    def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None:
         graph_module = node.graph.owning_module
         axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
         if axis is None:
@@ -396,17 +396,17 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None:
             gather_output = ParallelLayerAnnotatePass.get_stored_field_info(
                 node, field="gather_output", must_have=True
             )
-            new_mod = ColumnParallelLinear(ctx, mod, gather_output, config.weight_init_fn)
+            new_mod = ColumnParallelLinear(ctx, mod, gather_output)
         else:
             input_is_parallel = ParallelLayerAnnotatePass.get_stored_field_info(
                 node, field="input_is_parallel", must_have=True
             )
-            new_mod = RowParallelLinear(ctx, mod, input_is_parallel, config.weight_init_fn)
+            new_mod = RowParallelLinear(ctx, mod, input_is_parallel)
         layer_cache[key] = new_mod
         setattr(parent_mod, field, new_mod)
 
     @staticmethod
-    def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None:
+    def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
         graph_module = node.graph.owning_module
         axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
         if axis is None:
@@ -426,7 +426,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None:
         if key in layer_cache:
             new_mod = layer_cache[key]
         else:
-            new_mod = VocabParallelEmbedding(ctx, mod, config.weight_init_fn)
+            new_mod = VocabParallelEmbedding(ctx, mod)
         layer_cache[key] = new_mod
         setattr(parent_mod, field, new_mod)
 
@@ -468,9 +468,9 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int):
     def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
         for node in graph_module.graph.nodes:
             if is_linear(node):
-                self.handle_linear(node, ctx, config)
+                self.handle_linear(node, ctx)
             elif is_embedding(node):
-                self.handle_embedding(node, ctx, config)
+                self.handle_embedding(node, ctx)
             # correct the attention head num in parallel setting
             elif is_shape_consumer(node):
                 self.handle_hard_coded_axis_param(node, ctx)

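For context, the pass edited here follows a common torch.fx rewrite pattern: run() walks the traced graph, dispatches on node type (is_linear, is_embedding, is_shape_consumer), and each handler swaps the targeted submodule for a parallel variant, caching replacements so that shared (tied) modules are converted only once. The sketch below is a minimal, runnable illustration of that pattern with the Optimum-specific pieces stripped out; FakeParallelLinear and replace_linears are hypothetical names used only for illustration, not the optimum API, which instead dispatches to ColumnParallelLinear / RowParallelLinear / VocabParallelEmbedding based on annotations stored by ParallelLayerAnnotatePass.

import torch
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace


class FakeParallelLinear(nn.Module):
    """Hypothetical stand-in for a parallel linear; wraps the original module."""

    def __init__(self, wrapped: nn.Linear) -> None:
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wrapped(x)


def replace_linears(gm: GraphModule) -> GraphModule:
    # Cache keyed by module identity, mirroring the `layer_cache[key] = new_mod`
    # pattern in the diff: tied modules map to a single replacement.
    layer_cache: dict[int, nn.Module] = {}
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        mod = gm.get_submodule(node.target)
        if not isinstance(mod, nn.Linear):
            continue
        key = id(mod)
        if key not in layer_cache:
            layer_cache[key] = FakeParallelLinear(mod)
        # Re-attach the replacement on the parent module, as
        # `setattr(parent_mod, field, new_mod)` does in the pass above.
        parent_name, _, field = node.target.rpartition(".")
        parent_mod = gm.get_submodule(parent_name) if parent_name else gm
        setattr(parent_mod, field, layer_cache[key])
    gm.recompile()
    return gm


# Usage: trace a tiny model and run the sketch pass over its graph.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
gm = replace_linears(symbolic_trace(model))
print(gm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])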