Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Phi2 rewrite #1058

Merged
merged 21 commits into from
Jan 8, 2024
Merged

Phi2 rewrite #1058

merged 21 commits into from
Jan 8, 2024

Conversation

winglian
Copy link
Collaborator

@winglian winglian commented Jan 7, 2024

@casper-hansen
Copy link
Collaborator

Looks good to me. I would remove the code you commented out. You can come back to the PR if you need to look up what changed.

Also, what are the possibilities of using sample packing with Phi2?

@winglian
Copy link
Collaborator Author

winglian commented Jan 7, 2024

@casper-hansen will clean up the commented out code

As far as sample packing, it should be pretty straightforward. I started working on a fix for the previous implementation #877 but I may simply start over.

You had mentioned last year figuring out a way to manage sample packing across all the architectures in a more manageable way. I'm happy to take a stab at it if you have a prrof of concept or anything.

@casper-hansen
Copy link
Collaborator

You had mentioned last year figuring out a way to manage sample packing across all the architectures in a more manageable way. I'm happy to take a stab at it if you have a prrof of concept or anything.

I had a branch going but didn't get to test and further implement it as I got busy with other stuff. The concept is to have one implementation that can be managed more easily managed in one module.

https://github.com/OpenAccess-AI-Collective/axolotl/tree/refactor-flash-attention

@@ -843,7 +844,14 @@ def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)

return BatchSamplerDataCollatorForSeq2Seq(
if training_args.sample_packing:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend maybe consolidate the class?

data_collator = BatchSamplerDataCollatorForSeq2Seq if training_args.sample_packing else DataCollatorForSeq2Seq

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screenshot 2024-01-08 at 12 11 10 PM

my IDE doesn't like that 😭

@winglian
Copy link
Collaborator Author

winglian commented Jan 8, 2024

alright, looks good on a single 4090 https://api.wandb.ai/links/oaaic/51qvcv4z

@winglian winglian merged commit 732851f into main Jan 8, 2024
6 checks passed
@winglian winglian deleted the phi2-rewrite branch January 8, 2024 19:04
@fakerybakery
Copy link

Hi, do I need to change any configuration options or just use the default ones w/ Phi 2?

djsaunde pushed a commit that referenced this pull request Dec 17, 2024
* restore to current phi modeling code from phi-2

* enable gradient checkpointing

* don't cast everything to float32 all the time

* gradient checkpointing for phi2 ParallelBlock module too

* fix enabling flash attn for phi2

* add comment about import

* fix phi2 example

* fix model type check for tokenizer

* revert float32 -> bf16 casting changes

* support fused dense flash attn

* fix the repo for flash-attn

* add package name for subdir pkg

* fix the data collator when not using sample packing

* install packaging for pytests in ci

* also fix setup to not install flash attn fused dense subdir if not extras

* split out the fused-dense-lib in extra requires

* don't train w group_by_length for phi

* update integration test to use phi2

* set max steps and save steps for phi e2e tests

* try to workaround ssave issue in ci

* skip phi2 e2e test for now
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants