diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 728aa807a53..79100e7154a 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -26,8 +26,8 @@ from ..parallel_layers import ( ColumnParallelLinear, RowParallelLinear, - VocabParallelEmbedding, VocabParallelCrossEntropyLoss, + VocabParallelEmbedding, sharded_cross_entropy_wrapper_fn, ) from ..passes import ( @@ -39,6 +39,22 @@ class Backend(ABC): + """ + Abstract base class for implementing parallelization backends. + + This class defines the interface for creating parallel versions of various + PyTorch modules and operations. Subclasses should implement the abstract + methods to provide specific parallelization strategies. + + Methods: + create_column_parallel_linear: Create a column-parallel version of a linear layer. + create_row_parallel_linear: Create a row-parallel version of a linear layer. + create_parallel_embedding: Create a parallel version of an embedding layer. + create_parallel_cross_entropy: Create a parallel version of cross entropy loss. + pre_process: Perform pre-processing on the graph module before parallelization. + post_process: Perform post-processing on the graph module after parallelization. + init_parallelization_pass_pipeline: Initialize the parallelization pass pipeline. + """ @abstractmethod def create_column_parallel_linear( self, @@ -82,6 +98,7 @@ def create_parallel_cross_entropy( else: return sharded_cross_entropy_wrapper_fn(process_group=parallel_ctx.tp_group) + @abstractmethod def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our @@ -97,12 +114,16 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co param_meta.tied_to = parameter_mp[key] return graph_module + @abstractmethod def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: + """ + This method is called after the parallelization passes have been applied. It is used to perform any backend-specific + post-processing on the graph module. + """ return graph_module - def init_parallelization_pass_pipeline( - self, - ) -> PassPipeline: + @abstractmethod + def init_parallelization_pass_pipeline(self) -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. @@ -165,6 +186,9 @@ def create_parallel_embedding( return VocabParallelEmbedding(parallel_ctx, mod) def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: + """ + Initialize or load parameters from checkpoint, and tie them if needed. + """ world_size = dist.get_world_size(ctx.tp_group) tp_rank = dist.get_rank(ctx.tp_group) diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index 847b7c4b924..76e0553a795 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -152,6 +152,10 @@ def create_parallel_embedding( def post_process( self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config" ) -> nn.Module: + """ + Convert parameters to `NanotronParameter` and tie them if needed. Note that we don't initialize or load weights here + because nanotron will do that for us in the trainer class. + """ from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.tied_parameters import tie_parameters