Commit 0ff39bb: address comments
zhenglongjiepheonix committed Sep 20, 2024
1 parent: b5b371f
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/api.py
@@ -61,7 +61,7 @@ def parallelize_model(
             Parallel execution context containing process groups the current process belongs to.
         *model_args (`Any`):
             Additional positional arguments for initializing the model if a model id is passed.
-        model_id_or_path (`str`):
+        model_id_or_path (`Optional[str]`, defaults to `None`):
             Model to parallelize, a model id on the Hugging Face Hub or path to a local directory containing config and weights
             of the model.
         model_cls (`Optional[Type[PreTrainedModel]]`, defaults to `None`):
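For context, a minimal usage sketch of the API this docstring describes. The keyword names follow the docstring above; everything else (the process-group setup, the model id, and the exact ParallelExecutionCtx constructor arguments) is an assumption for illustration, not a verbatim example from the repository.

    import torch
    import torch.distributed as dist

    from optimum.fx.parallelization import parallelize_model
    from optimum.fx.parallelization.core import ParallelExecutionCtx

    # Assumed setup: one process group spanning all tensor-parallel ranks.
    dist.init_process_group(backend="nccl")
    ctx = ParallelExecutionCtx(
        tp_group=dist.group.WORLD,
        current_device=torch.device("cuda", torch.cuda.current_device()),
    )

    # model_id_or_path is now Optional[str]: pass a Hub id or a local
    # directory, or leave it as None and rely on model_cls instead.
    model = parallelize_model(
        parallel_ctx=ctx,
        model_id_or_path="gpt2",  # illustrative model id
    )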
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/backend/__init__.py
@@ -12,4 +12,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import BackEnd, DefaultBackend
+from .base import Backend, DefaultBackend
6 changes: 3 additions & 3 deletions optimum/fx/parallelization/backend/base.py
@@ -38,7 +38,7 @@
 )


-class BackEnd(ABC):
+class Backend(ABC):
     @abstractmethod
     def create_column_parallel_linear(
         self,
@@ -85,7 +85,7 @@ def create_parallel_cross_entropy(
     def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule:
         """
         Mark tie information right before we run passes, because dynamo tracing will alter parameter names while our
-        passes don't.
+        passes do not.
         """
         parameter_mp = {}
         for name, tensor in graph_module.named_parameters(remove_duplicate=False):
@@ -121,7 +121,7 @@ def init_parallelization_pass_pipeline(
 )


-class DefaultBackend(BackEnd):
+class DefaultBackend(Backend):
     def create_column_parallel_linear(
         self,
         mod: nn.Linear,
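The pre_process docstring above says tie information is marked before the passes run, since dynamo tracing renames parameters. A minimal sketch of that idea, assuming ties are recorded by tensor identity; the loop body is truncated in this diff, so the mapping below is illustrative rather than the actual pass:

    import torch.nn as nn
    from torch.fx import symbolic_trace

    class TiedModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.lm_head = nn.Linear(4, 10, bias=False)
            self.lm_head.weight = self.embed.weight  # weight tying

        def forward(self, x):
            return self.lm_head(self.embed(x))

    gm = symbolic_trace(TiedModel())

    # remove_duplicate=False yields every alias of a shared parameter,
    # so aliases can be grouped by the id of the underlying tensor.
    parameter_mp, ties = {}, {}
    for name, tensor in gm.named_parameters(remove_duplicate=False):
        key = id(tensor)
        if key in parameter_mp:
            ties[name] = parameter_mp[key]
        else:
            parameter_mp[key] = name

    print(ties)  # {'lm_head.weight': 'embed.weight'}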
6 changes: 3 additions & 3 deletions optimum/fx/parallelization/backend/nanotron.py
@@ -22,13 +22,13 @@
 from torch.fx import GraphModule

 from ..core import Config, ParallelExecutionCtx, ParameterMeta
-from .base import BackEnd
+from .base import Backend


 # Check if nanotron is installed
 _nanotron_available = importlib.util.find_spec("nanotron") is not None

-if TYPE_CHECKING:
+if TYPE_CHECKING and _nanotron_available:
     from nanotron.config import Config as NanotronConfig
     from nanotron.parallel import ParallelContext
     from nanotron.parallel.tensor_parallel.nn import (
@@ -38,7 +38,7 @@
 )


-class NanotronBackend(BackEnd):
+class NanotronBackend(Backend):
     """
     Backend class that glues the optimum fx parallelization context and the nanotron context together.
     """
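The `if TYPE_CHECKING and _nanotron_available` guard above keeps type checkers from resolving nanotron imports in environments where the optional dependency is missing. A generic sketch of the pattern, with somepkg standing in as a placeholder dependency:

    import importlib.util
    from typing import TYPE_CHECKING

    # Probe for the optional package without importing it at runtime.
    _somepkg_available = importlib.util.find_spec("somepkg") is not None

    if TYPE_CHECKING and _somepkg_available:
        from somepkg import SomeType  # resolved only during type checking

    def use(obj: "SomeType") -> None:
        # The string annotation is never evaluated at runtime, so this
        # module imports cleanly whether or not somepkg is installed.
        ...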
6 changes: 3 additions & 3 deletions optimum/fx/parallelization/core.py
@@ -22,7 +22,7 @@


 if TYPE_CHECKING:
-    from .backend import BackEnd
+    from .backend import Backend


 class HashableSlice:
@@ -115,7 +115,7 @@ class ParallelExecutionCtx:
     - current_device (`torch.device`):
         Device corresponding to the current process.

-    - backend (`Optional[BackEnd]`, defaults to `None`):
+    - backend (`Optional[Backend]`, defaults to `None`):
         Backend instance which converts layers into their parallelized counterparts.

     - example_inputs (`List[Any]`):
@@ -146,7 +146,7 @@ class ParallelExecutionCtx:

     tp_group: dist.ProcessGroup
     current_device: torch.device
-    backend: Optional["BackEnd"] = None
+    backend: Optional["Backend"] = None
     example_inputs: List[Any] = field(default_factory=list)
     parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
     param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)
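The `field(default_factory=...)` defaults visible above exist because a mutable default value would be shared by every ParallelExecutionCtx instance. A self-contained toy, not the optimum class itself, showing the behavior the factory guarantees:

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    @dataclass
    class Ctx:
        example_inputs: List[Any] = field(default_factory=list)
        parallel_layer_cache: Dict[str, Any] = field(default_factory=dict)

    a, b = Ctx(), Ctx()
    a.example_inputs.append("input_ids")
    assert b.example_inputs == []  # each instance owns a fresh list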
