add ia3 peft support #601
base: main
Conversation
we can support 4-bit IA3 once huggingface/peft#864 is merged.
I have not used IA3 before, but here are my comments from looking at the linked PR.
src/axolotl/utils/models.py
Outdated
if (
    (cfg.adapter == "lora" and cfg.load_in_8bit)
    or (cfg.adapter == "qlora" and cfg.load_in_4bit)
    or (cfg.adapter == "ia3" and cfg.load_in_8bit)
):
Second point: is IA3 load_in_8bit or 4-bit? The linked PR seems to be a 4-bit addition but also supports 8-bit?
It's 8-bit only for now. I added some checks to warn in the config validation.
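For illustration, a minimal sketch of what such a validation check could look like, assuming a `cfg` object with `adapter`, `load_in_4bit`, and `load_in_8bit` fields (the messages and function name are illustrative, not the PR's actual code):

import logging

LOG = logging.getLogger(__name__)

def validate_ia3_quantization(cfg):
    # IA3 in PEFT is 8-bit only at this point, so reject 4-bit loading
    # and nudge users toward 8-bit.
    if cfg.adapter == "ia3":
        if cfg.load_in_4bit:
            raise ValueError("IA3 does not currently support `load_in_4bit: true`")
        if not cfg.load_in_8bit:
            LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")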
I will need to run this myself when I have time to verify, since there are a lot of changes.
src/axolotl/utils/models.py
Outdated
@@ -450,11 +452,11 @@ def load_llama_adapter(model, cfg):
         task_type="CAUSAL_LM",
     )

-    if cfg.lora_model_dir:
+    if cfg.peft_model_dir or cfg.lora_model_dir:
Since we're updating to `peft_model_dir`, we could add a deprecation warning to the config validation to reduce the need for checking both, like on this line. For backward compatibility, we can assign `cfg.peft_model_dir = cfg.lora_model_dir` if it's not None.
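A minimal sketch of that backward-compatibility shim, assuming a mutable `cfg` object (the function name and warning text are illustrative, not the PR's actual code):

import logging

LOG = logging.getLogger(__name__)

def normalize_peft_model_dir(cfg):
    # Map the deprecated `lora_model_dir` onto `peft_model_dir` once, so the
    # rest of the code only needs to check `cfg.peft_model_dir`.
    if cfg.lora_model_dir is not None:
        LOG.warning("`lora_model_dir` is deprecated, please use `peft_model_dir` instead")
        if cfg.peft_model_dir is None:
            cfg.peft_model_dir = cfg.lora_model_dir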
README.md
Outdated
@@ -519,6 +519,9 @@ lora_modules_to_save:
 #  - lm_head
 lora_out_dir:
 lora_fan_in_fan_out: false
+ia3_target_modules: # target modules for IA3, for llama, k, v, and down projections
+ia3_feedforward_modules: # ffn modules for IA3, for llama down projection
+ia3_fan_in_fan_out:
Having separate `target_modules` and `fan_in_fan_out` options for IA3 feels a bit redundant, since we end up with two similarly named settings.
Hi, I wrote the PR for 4-bit IA3. I adjusted my own installation of Axolotl to support IA3 (I didn't submit a PR because it was a hack built by rewriting the existing LoRA support, which naturally broke it for LoRA purposes, and since I have no formal training or experience in coding, I wasn't confident adding new functionality without breaking everything else), and I found IA3 trained properly with no other major changes required. Comparing my hack to this PR, the changes here look nearly identical. Fortunately, I have found IA3 and LoRA are mostly interchangeable from a code perspective, and there were no misleading adjustments I had to make to get it working. I did not test loading, inference, or merging weights in Axolotl using IA3, as I did those tasks with my own scripts or adaptations of existing scripts. The only points I would raise are that:
Just a reminder, huggingface/peft#864 has been merged.
Co-authored-by: NanoCode012 <[email protected]>
I was trying PEFT's version of IA3 back in early August and it would not work, regardless of what I tried. I'm curious to see what this will produce and will test it as soon as I can.
 # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
 # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
 # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
 # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
-lora_modules_to_save:
+peft_modules_to_save:
 #  - embed_tokens
 #  - lm_head
 # Once you complete training, the model will be saved to the following directory.
 # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
 # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
 lora_out_dir:
I think you may have missed this variable.
@@ -151,6 +153,13 @@ def flashattn_forward(
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)

+    if query_states.dtype == torch.float32:
Could we add a comment to explain this casting?
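For reference, a hedged sketch of what that cast plus an explanatory comment might look like, assuming the usual flash-attention dtype constraint (the helper name and target dtype are illustrative, not the PR's actual code):

import torch

def downcast_for_flash_attn(query_states, key_states, value_states, dtype=torch.float16):
    # flash-attn kernels only accept fp16/bf16 inputs, but PEFT adapter layers
    # (e.g. IA3 scaling vectors) can upcast activations to float32, so cast the
    # projections back down before calling the attention kernel.
    if query_states.dtype == torch.float32:
        query_states = query_states.to(dtype)
        key_states = key_states.to(dtype)
        value_states = value_states.to(dtype)
    return query_states, key_states, value_states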
-    LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+    LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")

+    if not cfg.load_in_8bit and cfg.adapter == "ia3":
We can consolidate the checks here into one: `cfg.adapter in ["lora", "ia3"]`.
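A minimal sketch of the consolidated check, assuming the same `cfg` object used in the surrounding validation code (the warning wording is illustrative):

import logging

LOG = logging.getLogger(__name__)

def warn_if_not_8bit(cfg):
    # One branch covers both adapters instead of duplicating the check.
    if not cfg.load_in_8bit and cfg.adapter in ["lora", "ia3"]:
        LOG.warning("We recommend setting `load_in_8bit: true` for %s finetuning", cfg.adapter)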
if "lm_head" in lora_module_names: # needed for 16-bit | ||
lora_module_names.remove("lm_head") | ||
if "lm_head" in peft_module_names: # needed for 16-bit |
It would be good to add a log message when this removal happens and the user explicitly set this module.
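A hedged sketch of that log, assuming `peft_module_names` from the diff above and a hypothetical `user_specified_modules` list of what the user configured:

import logging

LOG = logging.getLogger(__name__)

def drop_lm_head(peft_module_names, user_specified_modules):
    # `lm_head` is handled separately for 16-bit training, so strip it here,
    # but tell the user if they explicitly asked for it.
    if "lm_head" in peft_module_names:
        if user_specified_modules and "lm_head" in user_specified_modules:
            LOG.warning("Removing `lm_head` from target modules even though it was explicitly set; it is handled separately for 16-bit")
        peft_module_names.remove("lm_head")
    return peft_module_names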
@winglian any updates on this?