Support for Llama 3.1 and 3.2 fine tuning #114

Open

DimensionSTP opened this issue Nov 19, 2024 · 3 comments


DimensionSTP commented Nov 19, 2024

Hello,

I am deeply interested in your Optimum-TPU project.
I am currently planning to fine-tune the Llama 3.1 and 3.2 models on my native language and a specific domain, using a fairly large dataset (approximately 60B tokens).
I am using Google TPU Pods, but I have faced significant challenges implementing model-parallel training from scratch, saving unified checkpoints in the safetensors format, setting up appropriate logging, and configuring hyperparameters.

While exploring solutions, I came across the Optimum-TPU project, which seems incredibly useful. However, I noticed that it currently only supports up to Llama 3.
Are there any plans to extend support to Llama 3.1 and 3.2 for fine-tuning?
I strongly hope that future updates will include support for these versions as well.

Thank you for considering this request!

tengomucho (Collaborator) commented Nov 19, 2024

Hi @DimensionSTP !
We do not support Llama 3.1 or 3.2 yet, but we should add that support before the end of the year.
Having said that, if all you want is to fine-tune these models, you can probably just follow the steps in our Llama fine-tuning example and it should work (though this is still untested).
For serving/inference you would still need better sharding support, but for fine-tuning it should be fine.
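
For reference, a minimal sketch of what that could look like, assuming you reuse the existing Llama fine-tuning example as-is and only swap in a 3.1 checkpoint (the model id and the rest of the setup here are illustrative and untested):

```python
# Sketch only: load a Llama 3.1 checkpoint the same way the existing Llama 3
# fine-tuning example does, then reuse that example's dataset/Trainer setup unchanged.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B"  # assumed substitution for the Llama 3 id used in the example

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# ...continue with the tokenization and Trainer configuration from the example.
```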

DimensionSTP (Author) commented

Hello again,

Thank you for your response and clarification regarding the support for Llama 3.1 and 3.2.

I attempted to follow the Llama fine-tuning example provided in your comment, but unfortunately, I encountered an issue related to the rope_scaling configuration.

In Llama 3, the configuration includes:

```json
"rope_scaling": null,
"rope_theta": 500000.0
```

However, in Llama 3.1 and 3.2, the rope_scaling field differs significantly:

Llama 3.1:

```json
"rope_scaling": {
  "factor": 8.0,
  "low_freq_factor": 1.0,
  "high_freq_factor": 4.0,
  "original_max_position_embeddings": 8192,
  "rope_type": "llama3"
},
"rope_theta": 500000.0
```

Llama 3.2:

```json
"rope_scaling": {
  "factor": 32.0,
  "high_freq_factor": 4.0,
  "low_freq_factor": 1.0,
  "original_max_position_embeddings": 8192,
  "rope_type": "llama3"
},
"rope_theta": 500000.0
```

When attempting to fine-tune these models, I encountered the following error:

```
ValueError: rope_scaling must be a dictionary with two fields, type and factor, got {'factor': 8.0, 'high_freq_factor': 4.0, 'low_freq_factor': 1.0, 'original_max_position_embeddings': 8192, 'rope_type': 'llama3'}
```

I tried this both with the Transformers version currently supported by Optimum-TPU and after upgrading Transformers to the latest version, but the same error persisted. It seems that Optimum-TPU relies on an older version of Transformers and may not fully support the more complex rope_scaling configurations introduced in Llama 3.1 and 3.2.
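
A quick sanity check for which transformers version is actually being imported (a dependency pin can silently override a manual upgrade):

```python
# Sanity check: print the transformers version actually resolved in this environment;
# a pinned dependency can silently override a manual "pip install -U transformers".
import transformers

print(transformers.__version__)
```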

Would it be possible to update Optimum-TPU to handle these changes in rope_scaling? Alternatively, could you provide guidance on modifying the library locally to accommodate the new configuration while waiting for official support?

Thank you for your efforts in maintaining this amazing project. I'm looking forward to the updates!

tengomucho (Collaborator) commented

Hi @DimensionSTP, as I mentioned before, Llama 3.1 is on the roadmap, and we will work on this soon.
As for the error you see, I searched the transformers code: this error used to be raised for Llama models, but it was removed in a refactor, and rope_scaling validation is handled differently in v4.43.0 and newer.
So I think you should retry and make sure you are actually using a newer transformers version; that should at least improve the situation, and hopefully make the error go away completely.
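
A minimal way to check this before launching a full fine-tuning run, assuming transformers v4.43.0 or newer is installed (the model id below is only an example, and the repository is gated, so it may require authentication):

```python
# Sketch: with transformers >= 4.43.0, the Llama 3.1/3.2 rope_scaling dict should
# load without raising the "two fields, type and factor" ValueError.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")  # example (gated) checkpoint
print(config.rope_scaling)  # expect the llama3-style dict with factor, low/high_freq_factor, etc.
```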
