diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 5d891fc4..0c9090df 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -53,8 +53,19 @@ def __init__( parallel_context: ParallelContext, mapping: dict = None, memory_priority: bool = False, + module_args: dict = None, ): super().__init__(module, parallel_context) + + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.module = module self.parallel_context = parallel_context self.memory_priority = memory_priority diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 7a6f607c..cd68334f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -57,6 +57,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping) self.module = module @@ -71,6 +72,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 86b36fe8..70e36a92 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -54,6 +54,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping) self.module = module @@ -68,6 +69,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index e16af883..4543cee2 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -55,6 +55,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping) self.module = module @@ -69,6 +70,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index d18e2212..ce7eb142 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -89,6 +89,7 @@ def __init__( parallel_context: Optional[ParallelContext] = None, mapping: dict = None, memory_priority: bool = False, + module_args: dict = None, ): super().__init__() self.parallel_context = get_parallel_context(module, parallel_context) @@ -103,14 +104,20 @@ def __init__( if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: self.module = _TensorParallel1D( - module, self.parallel_context, mapping, memory_priority + module, self.parallel_context, mapping, memory_priority, module_args ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D(module, self.parallel_context, mapping) + self.module = _TensorParallel2D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: - self.module = _TensorParallel2p5D(module, self.parallel_context, mapping) + self.module = _TensorParallel2p5D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D: - self.module = _TensorParallel3D(module, self.parallel_context, mapping) + self.module = _TensorParallel3D( + module, self.parallel_context, mapping, module_args + ) else: raise ValueError( "currently, only 1d, 2d, 2p5d tensor parallelism is supported."