Phi2 rewrite (#1058)
* restore to current phi modeling code from phi-2

* enable gradient checkpointing

* don't cast everything to float32 all the time

* gradient checkpointing for phi2 ParallelBlock module too

* fix enabling flash attn for phi2

* add comment about import

* fix phi2 example

* fix model type check for tokenizer

* revert float32 -> bf16 casting changes

* support fused dense flash attn

* fix the repo for flash-attn

* add package name for subdir pkg

* fix the data collator when not using sample packing

* install packaging for pytests in ci

* also fix setup to not install flash attn fused dense subdir if not extras

* split out the fused-dense-lib in extra requires

* don't train w group_by_length for phi

* update integration test to use phi2

* set max steps and save steps for phi e2e tests

* try to work around save issue in ci

* skip phi2 e2e test for now
winglian authored Jan 8, 2024
1 parent 9ca358b commit 732851f
Showing 7 changed files with 228 additions and 97 deletions.
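
As background for the gradient-checkpointing items in the commit message, here is a minimal sketch of loading phi-2 the way the new example config implies (public transformers API only; the names and flags below are not taken from this PR, and whether each knob takes effect depends on the phi modeling code, which is exactly what this commit reworks):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Mirrors examples/phi/phi2-ft.yml: AutoModelForCausalLM + AutoTokenizer with
# trust_remote_code and bf16 weights.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

# Gradient checkpointing trades recompute for activation memory; it only helps
# if the model's blocks actually support it, which is one of the fixes here.
if getattr(model, "supports_gradient_checkpointing", False):
    model.gradient_checkpointing_enable()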
73 changes: 73 additions & 0 deletions examples/phi/phi2-ft.yml
@@ -0,0 +1,73 @@
base_model: microsoft/phi-2
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: garage-bAInd/Open-Platypus
    type: alpaca

dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out

sequence_len: 2048
sample_packing: false # currently unsupported
pad_to_sequence_len:

adapter:
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
  - embd
  - lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: paged_adamw_8bit
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 1e-5

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
  pad_token: "<|endoftext|>"
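
Axolotl consumes this yml directly, typically via something like accelerate launch -m axolotl.cli.train examples/phi/phi2-ft.yml. For readers more used to raw transformers, the trainer-related settings above map roughly onto a TrainingArguments like the sketch below; the mapping is approximate and assumed, not taken from axolotl's internals, and the bf16/tf32 flags presume an Ampere-class GPU just as the yml does.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./phi-sft-out",
    per_device_train_batch_size=1,   # micro_batch_size: 1
    gradient_accumulation_steps=1,
    num_train_epochs=4,
    optim="paged_adamw_8bit",
    adam_beta2=0.95,
    adam_epsilon=1e-5,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    warmup_steps=100,
    weight_decay=0.1,
    bf16=True,
    tf32=True,
    gradient_checkpointing=True,
    logging_steps=1,
)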
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,6 +12,7 @@ fire
 PyYAML>=6.0
 datasets>=2.15.0
 flash-attn==2.3.3
+fused-dense-lib @ git+https://github.com/Dao-AILab/[email protected]#subdirectory=csrc/fused_dense_lib
 sentencepiece
 wandb
 einops
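
Note: the #subdirectory=csrc/fused_dense_lib fragment tells pip to build only that sub-project of the flash-attention repository (the fused dense kernels behind the "support fused dense flash attn" bullet); the equivalent standalone install would be roughly pip install 'fused-dense-lib @ git+https://github.com/Dao-AILab/[email protected]#subdirectory=csrc/fused_dense_lib'.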
4 changes: 4 additions & 0 deletions setup.py
@@ -17,6 +17,7 @@ def parse_requirements():
                 _dependency_links.append(url)
             elif (
                 "flash-attn" not in line
+                and "flash-attention" not in line
                 and "deepspeed" not in line
                 and line
                 and line[0] != "#"
@@ -51,6 +52,9 @@ def parse_requirements():
         "flash-attn": [
             "flash-attn==2.3.3",
         ],
+        "fused-dense-lib": [
+            "fused-dense-lib @ git+https://github.com/Dao-AILab/[email protected]#subdirectory=csrc/fused_dense_lib",
+        ],
         "deepspeed": [
             "deepspeed",
         ],
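
Two things happen in this file. The new extras entry means optional installs look roughly like pip install -e '.[flash-attn,fused-dense-lib]'. The extra "flash-attention" check matters because the fused-dense-lib requirement contains "flash-attention" in its git URL but not the literal string "flash-attn", so the old condition alone would not have kept it out of the base install_requires, which appears to be the point of the "also fix setup to not install flash attn fused dense subdir if not extras" bullet. A standalone sketch of the filtering (an illustrative helper, not axolotl's actual function):

# Illustrative only: mimics the substring filtering parse_requirements() performs.
EXCLUDED = ("flash-attn", "flash-attention", "deepspeed")

def keep_in_base_requirements(line: str) -> bool:
    line = line.strip()
    return bool(line) and not line.startswith("#") and not any(name in line for name in EXCLUDED)

lines = [
    "flash-attn==2.3.3",
    "fused-dense-lib @ git+https://github.com/Dao-AILab/"
    "[email protected]#subdirectory=csrc/fused_dense_lib",
    "sentencepiece",
]
# Only sentencepiece survives: flash-attn==2.3.3 was already excluded, and the
# fused-dense-lib line is now caught by the "flash-attention" check on its URL.
print([line for line in lines if keep_in_base_requirements(line)])  # ['sentencepiece']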
10 changes: 9 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -34,6 +34,7 @@
 )
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
+    DataCollatorForSeq2Seq,
     MambaDataCollator,
 )
 from axolotl.utils.samplers import MultipackBatchSampler
@@ -843,7 +844,14 @@ def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
 
-        return BatchSamplerDataCollatorForSeq2Seq(
+        if training_args.sample_packing:
+            return BatchSamplerDataCollatorForSeq2Seq(
+                self.tokenizer,
+                return_tensors="pt",
+                **kwargs,
+            )
+
+        return DataCollatorForSeq2Seq(
             self.tokenizer,
             return_tensors="pt",
             **kwargs,
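
Behaviorally, sample-packed runs keep the batch-sampler collator, while unpacked data (as in the phi-2 example, which sets sample_packing: false) now falls through to a plain seq2seq collator. A hedged illustration of that fallback, using the transformers collator that axolotl's DataCollatorForSeq2Seq is assumed to mirror; the padding behavior shown is the transformers default, not something asserted by this PR:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # mirrors special_tokens.pad_token in the yml

collator = DataCollatorForSeq2Seq(tokenizer, return_tensors="pt")
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
batch = collator(features)
# Ragged examples are padded to a common length (labels with -100), giving
# (2, 3) tensors -- no sample packing involved.
print(batch["input_ids"].shape)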
(Diffs for the remaining 3 changed files are not shown.)
