Automatic Model Parallelism Through FX #1933
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
this is super cool!
rank = dist.get_rank(group=group)

tensor = tensor.contiguous()
tensors = [torch.empty_like(tensor) for _ in range(world_size)]
I have not tried it, but maybe https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_into_tensor is more efficient, with a single empty call, like https://github.com/huggingface/text-generation-inference/blob/d0225b10156320f294647ac676c130d03626473d/server/text_generation_server/layers/tensor_parallel.py#L98
yes
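For reference, a minimal sketch of what the all_gather_into_tensor variant could look like (assuming tensor, world_size, group and dist are defined as in the snippet above; this is not the PR's actual code):

tensor = tensor.contiguous()
# Preallocate a single output buffer holding all shards concatenated along dim 0.
output = torch.empty(
    (world_size * tensor.size(0), *tensor.size()[1:]),
    dtype=tensor.dtype,
    device=tensor.device,
)
dist.all_gather_into_tensor(output, tensor, group=group)
# output can be split back into per-rank shards with torch.chunk(output, world_size, dim=0) if needed.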
size = tensor.size()
assert size[split_dim] % world_size == 0
tensors = torch.split(tensor, size[split_dim] // world_size, dim=split_dim)
tensor = tensors[rank].contiguous()
why contiguous?
tensors after split may not be contiguous, I think it's better to make them contiguous
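A small standalone illustration of the point (not code from the PR):

x = torch.randn(4, 8)
shard = torch.split(x, 4, dim=1)[1]  # a view into x, not a copy
print(shard.is_contiguous())         # False: splitting along a non-leading dim yields non-contiguous views
shard = shard.contiguous()           # materializes the shard in its own contiguous buffer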
self.bias.zero_()


def forward(self, input: torch.Tensor) -> torch.Tensor:
    input = differentiable_identity(input, self.process_group)
why do we need an identity here?
to take care of the gradient reduction in the backward pass
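For context, a rough sketch of the usual tensor-parallel pattern such a differentiable identity implements (the PR's actual implementation may differ; assumes torch and torch.distributed as dist are in scope):

class DifferentiableIdentity(torch.autograd.Function):
    # Identity in forward, all-reduce of the gradient in backward, so each rank's
    # partial gradient w.r.t. the shared input is summed across the tensor-parallel group.
    @staticmethod
    def forward(ctx, tensor, process_group):
        ctx.process_group = process_group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.process_group)
        return grad_output, None


def differentiable_identity(tensor, process_group):
    return DifferentiableIdentity.apply(tensor, process_group)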
optimum/fx/parallelization/passes.py
Outdated
self.clear_marker_per_node(node)


class ParallelLinearAnnotatePass(AnalyzeBase):
I don't get the name parallel here? Isn't it more like successive?
it actually means annotating some layers so they are replaced by their parallel counterparts
tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)

# check results at main worker process
if rank == 0:
    assert len(tensors) == world_size
    for i in range(1, world_size):
        torch.testing.assert_close(tensors[i - 1].cpu(), tensors[i].cpu(), rtol=1e-4, atol=1e-4)
it should probably be checked on all ranks
checking at the main process should be enough, because it gathers results from the other ranks onto the main process and does the comparison there
move_model_to_device(model, device=device)
initialize_parameter_mapping(model, ctx=ctx)

model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg))
can we compose with inductor?
Not confident enough to say right now, but at least it won't be a single graph.
I would say that is the hope for the future.
optimum/fx/parallelization/api.py
Outdated
model (Union[torch.nn.Module, str]):
    Model to parallelize, could either be a module or a model id in huggingface space.
Suggested change:
model (Union[torch.nn.Module, str]):
    Model to parallelize, could either be a module or a model id on the Hugging Face Hub.
optimum/fx/parallelization/api.py
Outdated
    Model to parallelize, could either be a module or a model id in huggingface space.
parallel_ctx (ParallelExecutionCtx):
    Parallel execution context containing process groups the current process belongs to.
model_args (additional postional arguments, optional):
Suggested change:
*model_args (Any):

Should we also add model_kwargs?
optimum/fx/parallelization/api.py
Outdated
    Whether to use local files only, will avoid downloading from remote if set to `True`.
skip_load_weights (`bool`, defaults to `False`):
    Whether to skip loading weights from disk to model.
kwargs (additional keyword arguments, optional):
Suggested change:
**kwargs (Dict[str, Any]):
optimum/fx/parallelization/api.py
Outdated
cache_dir (`Optional[str]`, defaults to `None`):
    Cache directory to store downloaded weights. Defaults to None.
local_files_only (`bool`, defaults to `False`):
    Whether to use local files only, will avoid downloading from remote if set to `True`.
skip_load_weights (`bool`, defaults to `False`):
    Whether to skip loading weights from disk to model.
kwargs (additional keyword arguments, optional):
We provide a lot of things here. IMO we should simplify that. Most of these arguments come from the from_pretrained method, so I would gather them as one keyword argument: model_kwargs.
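Something along these lines, perhaps (purely illustrative; the names and parameters below are hypothetical, not the PR's final API):

def parallelize_model(
    model,                 # torch.nn.Module instance or model id on the Hub
    parallel_ctx,          # ParallelExecutionCtx
    *model_args,
    model_kwargs=None,     # dict forwarded to from_pretrained (cache_dir, revision, ...)
    **parallel_kwargs,     # fields overriding the parallelization config
):
    ...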
optimum/fx/parallelization/api.py
Outdated
for k, v in kwargs.items():
    if k in parallel_config.__dict__:
        setattr(parallel_config, k, v)
kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__}
You can also iterate on a copy of kwargs and pop elements as follows:

Suggested change:
for k, v in dict(kwargs).items():
    if k in parallel_config.__dict__:
        setattr(parallel_config, k, v)
        kwargs.pop(k)
else:
    hf_folder = model

# should be able to load config using only local files
No because you only allowed patterns to be safetensors and bin files, and config is a json.
here I move all the download logic, including config and index files, into download_model_from_hf
optimum/fx/parallelization/api.py
Outdated
use_safetensors = False
for pattern in allow_patterns:
    if len(glob.glob(os.path.join(hf_folder, pattern))) > 0:
        use_safetensors = pattern == "*.safetensors"
        break
Can be simplified.
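For example, assuming allow_patterns only contains the safetensors and bin patterns, something like the following might suffice (an illustrative sketch, not a verified drop-in replacement):

# safetensors files take priority whenever any are present in the folder
use_safetensors = len(glob.glob(os.path.join(hf_folder, "*.safetensors"))) > 0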
optimum/fx/parallelization/api.py
Outdated
index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME)
if os.path.isfile(index_path):
    with open(index_path) as f:
        index_dict = json.load(f)
    parallel_ctx.weight_map = {k: os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()}
weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin"))
if not use_safetensors:
    weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {}
    convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map)
    parallel_ctx.weight_map = weight_map

# try directly construct weight_map from weight files, should have safetensors file on disk in any case
if not parallel_ctx.weight_map:
    from safetensors import safe_open

    weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors"))
    for weight_file in weight_files:
        with safe_open(filename=weight_file, framework="pt") as f:
            for key in f.keys():
                weight_map[key] = weight_file
    parallel_ctx.weight_map = weight_map
nit: I think overall it can be simplified.
I moved the logic into utils so it looks cleaner in api, but the logic itself is indeed complex, because we need to handle the case where a local directory is passed. Then the only thing we can do is peek inside the folder and see if there are safetensors/bin files; if there are only bin files, we need to convert them into safetensors; and if there is an index file, we load the weight_map directly from it, otherwise we scan all the weight files in the folder and assemble a weight_map out of them.
Not sure I am getting everything because it is a very long and complex PR but LGTM!
Let's iterate on smaller PRs from now on.
Thanks @zhenglongjiepheonix!
Merging this as an experimental first version; more fix-ups and features coming in following PRs!
options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/
env:
  NCCL_DEBUG: INFO
  HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
@zhenglongjiepheonix @michaelbenayoun is HF_TOKEN used for the tests (I can't see where), or can we remove it?
it's not used, you can remove it
removed in #2061
What does this PR do?
This PR adds an automatic parallelization backend for torch dynamo: it takes the dynamo-captured fx graph, runs a few passes to automatically identify the parts that can be parallelized, and transforms the graph into its parallelized version. For simplicity it focuses on models that support dynamo tracing in the transformers library right now, and might not support custom models because of the tricky parts in parallel pattern matching.
For now it only supports parallelization of linears in the graph (in the context of transformers these would be the attention and MLP layers), with the following milestones left:
Please feel free to review and provide suggestions even though it's still in progress and does not cover all features. According to @michaelbenayoun, we should try merging the first version, and iterations will come in following PRs.
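For readers unfamiliar with custom dynamo backends, the general contract such a backend plugs into looks roughly like this (a generic sketch of the torch.compile backend interface; the function name below is illustrative and the PR's actual pass implementation differs):

from typing import List

import torch

def parallelize_backend_sketch(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # torch.compile hands the backend the dynamo-captured fx graph; the PR's passes would
    # run here to rewrite eligible linears into their tensor-parallel counterparts.
    gm.graph.print_tabular()  # inspect the captured graph
    return gm.forward         # return the callable to use in place of the original forward

# compiled = torch.compile(model, fullgraph=True, backend=parallelize_backend_sketch)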