Skip to content
This repository has been archived by the owner on Aug 26, 2022. It is now read-only.

Commit

Permalink
add module args for save_parallelized
Browse files Browse the repository at this point in the history
  • Loading branch information
jason9693 committed Aug 16, 2022
1 parent 20204ea commit 88ac33f
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 4 deletions.
11 changes: 11 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,19 @@ def __init__(
parallel_context: ParallelContext,
mapping: dict = None,
memory_priority: bool = False,
module_args: dict = None,
):
super().__init__(module, parallel_context)

if module_args is None:
if is_huggingface_model(module):
module_args = module.config
else:
raise ValueError(
"`config` must be input if the model is not huggingface model."
)

self.config = module_args
self.module = module
self.parallel_context = parallel_context
self.memory_priority = memory_priority
Expand Down
10 changes: 10 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
module: nn.Module,
parallel_context: ParallelContext,
mapping: dict = None,
module_args: dict = None,
):
super().__init__(module, parallel_context, mapping)
self.module = module
Expand All @@ -71,6 +72,15 @@ def __init__(
"`mapping` must be input if the model is not huggingface model."
)

if module_args is None:
if is_huggingface_model(module):
module_args = module.config
else:
raise ValueError(
"`config` must be input if the model is not huggingface model."
)

self.config = module_args
self.tensor_parallel_mapping = TensorParallelMapping(mapping)
self._parallelize()

Expand Down
10 changes: 10 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
module: nn.Module,
parallel_context: ParallelContext,
mapping: dict = None,
module_args: dict = None,
):
super().__init__(module, parallel_context, mapping)
self.module = module
Expand All @@ -68,6 +69,15 @@ def __init__(
"`mapping` must be input if the model is not huggingface model."
)

if module_args is None:
if is_huggingface_model(module):
module_args = module.config
else:
raise ValueError(
"`config` must be input if the model is not huggingface model."
)

self.config = module_args
self.tensor_parallel_mapping = TensorParallelMapping(mapping)
self._parallelize()

Expand Down
10 changes: 10 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
module: nn.Module,
parallel_context: ParallelContext,
mapping: dict = None,
module_args: dict = None,
):
super().__init__(module, parallel_context, mapping)
self.module = module
Expand All @@ -69,6 +70,15 @@ def __init__(
"`mapping` must be input if the model is not huggingface model."
)

if module_args is None:
if is_huggingface_model(module):
module_args = module.config
else:
raise ValueError(
"`config` must be input if the model is not huggingface model."
)

self.config = module_args
self.tensor_parallel_mapping = TensorParallelMapping(mapping)
self._parallelize()

Expand Down
15 changes: 11 additions & 4 deletions oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
parallel_context: Optional[ParallelContext] = None,
mapping: dict = None,
memory_priority: bool = False,
module_args: dict = None,
):
super().__init__()
self.parallel_context = get_parallel_context(module, parallel_context)
Expand All @@ -103,14 +104,20 @@ def __init__(

if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D:
self.module = _TensorParallel1D(
module, self.parallel_context, mapping, memory_priority
module, self.parallel_context, mapping, memory_priority, module_args
)
elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D:
self.module = _TensorParallel2D(module, self.parallel_context, mapping)
self.module = _TensorParallel2D(
module, self.parallel_context, mapping, module_args
)
elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D:
self.module = _TensorParallel2p5D(module, self.parallel_context, mapping)
self.module = _TensorParallel2p5D(
module, self.parallel_context, mapping, module_args
)
elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D:
self.module = _TensorParallel3D(module, self.parallel_context, mapping)
self.module = _TensorParallel3D(
module, self.parallel_context, mapping, module_args
)
else:
raise ValueError(
"currently, only 1d, 2d, 2p5d tensor parallelism is supported."
Expand Down

0 comments on commit 88ac33f

Please sign in to comment.