Parallel sharding #21
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
It will be used to adapt it for sharding. Only imports have been adapted, and only code relevant for GemmaForCausalLM has been added.
It seems that the device_map parameter triggers a chain of calls that try to use accelerate to load the model with less memory. The problem is that this path skips the load-state pre-hooks, making weight loading impossible.
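Below is a minimal sketch of the two loading paths described above (the model id comes from the PR's example; the rest is illustrative and not the PR's actual code):

from transformers import AutoModelForCausalLM

# Passing device_map routes loading through accelerate's low-memory path,
# which bypasses the load-state pre-hooks this PR relies on:
# model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

# A plain load goes through the regular load_state_dict machinery,
# so registered pre-hooks run as expected:
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")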
It will now be running in parallel. More changes to come.
This will lead to loading the model in bfloat16 when specified in the config.
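As a hedged illustration of what that means (it mirrors the change to optimum/tpu/distributed_model.py shown further down, but is not the PR's exact code): read the dtype from the checkpoint's config and pass it explicitly to from_pretrained.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)

# config.torch_dtype is torch.bfloat16 for the Gemma checkpoints; passing it
# explicitly avoids transformers' float32 fallback.
dtype = config.torch_dtype if config.torch_dtype is not None else torch.float32
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
print(model.model.layers[0].self_attn.o_proj.weight.dtype)  # torch.bfloat16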
Force-pushed from 941fdf2 to 2215595
I left a couple of comments, I'll review the modeling file tomorrow!
Force-pushed from f334bbd to fe888a9
API change when transformers was updated.
I wrongly chose the model's generation config instead of the one passed to the token selector.
LGTM - my only concern is the explicit need to provide the torch_dtype in from_pretrained, which I find a bit spurious, but it's OK to merge and dig into it in another PR.
@@ -56,7 +56,7 @@ def main():
    model_id = "google/gemma-2b"
    torch_dtype = torch.bfloat16

-   model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
+   model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
Do we need the torch_dtype=torch_dtype? It should be taken from the config, no?
Well, it doesn't look like it works this way:
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`. If you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████████████████| 2/2 [00:00<00:00, 2.65it/s]
>>> print(model.config.torch_dtype)
torch.bfloat16
>>> print(model.model.layers[0].self_attn.o_proj.weight.dtype)
torch.float32
optimum/tpu/distributed_model.py (outdated)
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=config.torch_dtype)
I have a hard time understanding why we need to do it this way. Aren't we overriding the default behaviour with the default behaviour? @regisss do you know?
It seems the default is to load in fp32 regardless of the dtype specified in the config: https://huggingface.slack.com/archives/C014N4749J9/p1712757959601599
So I got some insights on the design for this. It seems that transformers uses the default PyTorch type, i.e. torch.float32. So I will probably need to change this code later, as it might not work for models whose weights were not trained in float32/bfloat16. I have already seen that we cannot use bf16 everywhere, because some operations cannot be performed in it (I saw this in a unit test with gpt2). It is probably a custom configuration we need to add to the model. I pushed a cleaner fix than this.
bfloat16 will now be set by default for Gemma models; other models will still load in float32 by default.
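One possible shape for that per-model default (the helper name is hypothetical; the actual fix is in the PR's commits): fall back to the config's torch_dtype for Gemma only, and keep float32 for everything else.

import torch
from transformers import AutoConfig

def pick_torch_dtype(model_id: str) -> torch.dtype:
    # Hypothetical helper: use the checkpoint dtype only for Gemma models,
    # since bf16 is known not to work for every operation (e.g. with gpt2).
    config = AutoConfig.from_pretrained(model_id)
    if config.model_type == "gemma" and config.torch_dtype is not None:
        return config.torch_dtype  # torch.bfloat16 for google/gemma-2b and gemma-7b
    return torch.float32  # transformers' usual default

print(pick_torch_dtype("google/gemma-2b"))  # expected: torch.bfloat16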
What does this PR do?
This enables sharding on the Gemma model, making it possible to load google/gemma-7b and run inference on it. TGI integration is yet to come, but it should be done soon!
Before submitting