Automatic Model Parallelism Through FX #1933
Changes from 31 commits
@@ -0,0 +1,65 @@
```yaml
name: Automatic Model Parallelism Test on GPUs

on:
  pull_request:
    branches:
      - main
    paths:
      - 'optimum/fx/parallelization/**.py'
  push:
    branches:
      - main
    paths:
      - 'optimum/fx/parallelization/**.py'

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  run_gpu_tests:
    strategy:
      fail-fast: false
      matrix:
        config:
          - name: GPU-enabled Optimum Test Suite
            image: nvidia/cuda:12.4.1-devel-ubuntu22.04
            gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"]

    name: ${{ matrix.config.name }}
    runs-on:
      group: "${{matrix.gpu_target}}"

    container:
      image: ${{ matrix.config.image }}
      options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/
      env:
        NCCL_DEBUG: INFO
        HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
    defaults:
      run:
        shell: bash

    steps:
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Checkout optimum
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run nvidia-smi
        run: |
          nvidia-smi

      - name: Install dependencies
        run: |
          python3 -m pip install -U pip
          python3 -m pip install torch transformers
          python3 -m pip install .[tests]

      - name: Run automatic model parallelism tests
        run: |
          pytest -s -v -o log_cli=true tests/fx/parallelization
```
zhenglongjiepheonix marked this conversation as resolved.
@@ -0,0 +1,16 @@
```python
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .api import parallelize_backend, parallelize_model
from .core import Config, ParallelExecutionCtx
```
@@ -0,0 +1,155 @@
```python
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import importlib
import json
import os
from functools import partial
from typing import List, Optional, Union

import torch
from torch.fx import GraphModule

from .core import Config, ParallelExecutionCtx
from .passes import build_parallel_pass_pipeline
from .utils import (
    MetaAwareMethodsPatcher,
    convert_bin_to_safetensors,
    download_files_from_hf,
    initialize_parameter_meta,
    move_model_to_device,
)


def parallelize_backend(
    graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config
) -> GraphModule:
    ctx.example_inputs = example_inputs
    pass_pipeline = build_parallel_pass_pipeline()
    graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
    ctx.compile_times += 1
    ctx.last_optimized_graph_module = graph_module
    return graph_module


def parallelize_model(
    model: Union[torch.nn.Module, str],
    parallel_ctx: ParallelExecutionCtx,
    *model_args,
    revision: str = "main",
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
    skip_load_weights: bool = False,
    **kwargs,
):
    """
    API for automatic model parallelism through PyTorch FX.

    Args:
        model (Union[torch.nn.Module, str]):
            Model to parallelize, can either be a module or a model id on the Hugging Face Hub.
```
```python
        parallel_ctx (ParallelExecutionCtx):
            Parallel execution context containing process groups the current process belongs to.
        model_args (additional positional arguments, optional):
```
> **Review comment (suggested change):** Should we add also …
```python
            Additional positional arguments for initializing the model if a model id is passed.
        revision (`str`, defaults to `main`):
            Model revision for weights downloading if a model id is passed.
        cache_dir (`Optional[str]`, defaults to `None`):
            Cache directory to store downloaded weights. Defaults to None.
        local_files_only (`bool`, defaults to `False`):
            Whether to use local files only, will avoid downloading from remote if set to `True`.
        skip_load_weights (`bool`, defaults to `False`):
            Whether to skip loading weights from disk to model.
        kwargs (additional keyword arguments, optional):
```
> **Review comment:** We provide a lot of things here.
```python
            Additional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`.
    """
    parallel_config = Config()
    for k, v in kwargs.items():
        if k in parallel_config.__dict__:
            setattr(parallel_config, k, v)
    kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__}
```
> **Review comment:** You can also iterate on a copy of `kwargs` …
```python
    if isinstance(model, str):
        from transformers import AutoConfig
        from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

        is_local = os.path.isdir(model)
        allow_patterns = ["*.safetensors", "*.bin"]
        if not is_local:
            hf_folder = download_files_from_hf(
                model_name_or_path=model,
                cache_dir=cache_dir,
                allow_patterns=allow_patterns,
                revision=revision,
                local_files_only=local_files_only,
                skip_download_weights=skip_load_weights,
            )
        else:
            hf_folder = model

        # should be able to load config using only local files
```
> **Review comment:** No because you only allowed patterns to be safetensors and bin files, and config is a json.
>
> **Reply:** here I move all the download logic including config and index files into …
```python
        model_config, kwargs = AutoConfig.from_pretrained(
            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
        )

        # try getting model class info from config
        model_arch = model_config.architectures
        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

        if not skip_load_weights:
            use_safetensors = False
            for pattern in allow_patterns:
                if len(glob.glob(os.path.join(hf_folder, pattern))) > 0:
                    use_safetensors = pattern == "*.safetensors"
                    break
```
> **Review comment:** Can be simplified.
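As a sketch of one possible simplification: since `allow_patterns` only contains the two patterns and safetensors takes priority, the loop collapses to a single glob check.

```python
# Sketch: prefer safetensors whenever any *.safetensors file is present;
# otherwise use_safetensors stays False and *.bin files are used.
use_safetensors = bool(glob.glob(os.path.join(hf_folder, "*.safetensors")))
```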
```python
            index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME)
            if os.path.isfile(index_path):
                with open(index_path) as f:
                    index_dict = json.load(f)
                parallel_ctx.weight_map = {k: os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()}
            weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin"))
            if not use_safetensors:
                weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {}
                convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map)
                parallel_ctx.weight_map = weight_map

            # try directly construct weight_map from weight files, should have safetensors file on disk in any case
            if not parallel_ctx.weight_map:
                from safetensors import safe_open

                weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors"))
                for weight_file in weight_files:
                    with safe_open(filename=weight_file, framework="pt") as f:
                        for key in f.keys():
                            weight_map[key] = weight_file
                parallel_ctx.weight_map = weight_map
```
> **Review comment:** nit: I think overall it can be simplified.
>
> **Reply:** I move the logic into …
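The author's actual refactor target is truncated above, so the following is illustrative only: one way to consolidate the weight-map construction is to factor the safetensors scan (already present as the fallback branch above) into a helper.

```python
# Illustrative sketch (not the author's actual refactor): build the weight map
# directly from whatever safetensors files exist on disk, which covers both the
# indexed and non-indexed checkpoint layouts after bin->safetensors conversion.
def build_weight_map(hf_folder: str) -> dict:
    from safetensors import safe_open

    weight_map = {}
    for weight_file in glob.glob(os.path.join(hf_folder, "*.safetensors")):
        with safe_open(filename=weight_file, framework="pt") as f:
            for key in f.keys():
                weight_map[key] = weight_file
    return weight_map
```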
```python
        torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
        if torch_dtype is not None:
            dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

        with MetaAwareMethodsPatcher():
            model = model_cls(model_config, *model_args, **kwargs)
            # TODO: remove this once training-time tracing is supported
            model.eval()

        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

    move_model_to_device(model, device=parallel_ctx.current_device)
    initialize_parameter_meta(model)
    backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config)
    model = torch.compile(model, fullgraph=True, backend=backend)
    return model
```
> **Review comment (on the workflow's `HF_TOKEN` env var):** @zhenglongjiepheonix @michaelbenayoun is `HF_TOKEN` used for the tests (can't see where) or can we remove?
>
> **Reply:** it's not used, you can remove it
>
> **Reply:** removed in #2061