Fix sample_packing support for wandb prediction table
Glavin001 committed Sep 9, 2023
1 parent e9eae77 commit 83e6b29
Showing 4 changed files with 25 additions and 15 deletions.
.vscode/launch.json (3 changes: 2 additions & 1 deletion)
@@ -17,7 +17,8 @@
"localRoot": "${workspaceFolder}",
"remoteRoot": "/workspace/axolotl/"
}
]
],
"justMyCode": false
},
{
"name": "train",
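Setting "justMyCode": false tells the VS Code Python debugger to step into third-party frames (transformers, peft, flash-attn, and so on) rather than stopping only in project files, which helps when chasing a bug like the one in callbacks.py. Below is a minimal sketch of the remote side of the "attach" configuration above, assuming the training process is launched under debugpy; this commit does not show that part, so the setup and port are illustrative.

```python
# Hypothetical remote-side setup for the "attach" launch configuration above.
# The port is illustrative; it must match whatever the attach config uses
# (5678 is debugpy's conventional default).
import debugpy

debugpy.listen(("0.0.0.0", 5678))
print("Waiting for the VS Code debugger to attach...")
debugpy.wait_for_client()  # block until the editor connects, then continue training
```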
examples/llama-2/tiny-random.yml (26 changes: 15 additions & 11 deletions)
@@ -19,24 +19,27 @@ strict: false
datasets:
# - path: mhenrichsen/alpaca_2k_test
# type: alpaca
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
# - path: Glavin001/startup-interviews
# - path: teknium/GPT4-LLM-Cleaned
# type: alpaca
- path: Glavin001/startup-interviews
type: alpaca
dataset_prepared_path: last_run_prepared
# val_set_size: 0.01
val_set_size: 0.001
val_set_size: 0.02
# val_set_size: 0.05
# val_set_size: 0.001
# val_set_size: 0.1
# output_dir: ./lora-out
# output_dir: ./lora-2-out
output_dir: ./lora-5-out
output_dir: ./lora-6-out

# sequence_len: 4096
sequence_len: 2048
# sequence_len: 2048
# sequence_len: 256
# sequence_len: 512
# sample_packing: true
sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py
sequence_len: 1024
sample_packing: true
# sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py

adapter: lora
lora_model_dir:
@@ -60,8 +63,9 @@ micro_batch_size: 16
# num_epochs: 3
# num_epochs: 0.001
# num_epochs: 0.01
num_epochs: 1
# num_epochs: 1
# num_epochs: 5
num_epochs: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
@@ -81,9 +85,9 @@ xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 10
# eval_steps: 10
# eval_steps: 20
# eval_steps: 2
eval_steps: 2
# eval_steps: 1
save_steps:
debug:
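These config changes switch training to the Glavin001/startup-interviews dataset, re-enable sample_packing (previously forced off because of the callbacks.py bug), shorten sequence_len to 1024, and use val_set_size: 0.02 with eval_steps: 2 so the wandb prediction-table callback fires early and often against packed batches. A small sanity-check sketch, assuming a local checkout of the repo:

```python
# Illustrative only: print the config keys this commit touches.
import yaml

with open("examples/llama-2/tiny-random.yml") as f:
    cfg = yaml.safe_load(f)

for key in ("sample_packing", "sequence_len", "val_set_size", "eval_steps", "num_epochs"):
    print(f"{key} = {cfg.get(key)}")
# Expected after this commit: sample_packing = True, sequence_len = 1024,
# val_set_size = 0.02, eval_steps = 2, num_epochs = 10
```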
src/axolotl/monkeypatch/llama_attn_hijack_flash.py (6 changes: 5 additions & 1 deletion)
@@ -194,7 +194,11 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape

if cu_seqlens is not None and max_seqlen is not None:
# if cu_seqlens is not None and max_seqlen is not None:
# if cu_seqlens is not None and max_seqlen is not None and self.training:
# if cu_seqlens is not None and max_seqlen is not None and query_states.shape == key_states.shape:
# if cu_seqlens is not None and max_seqlen is not None and len(cu_seqlens[0]) > 2:
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
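The new guard only takes the sample-packing branch when cu_seqlens is the flat 1-D tensor of cumulative sequence boundaries that the varlen attention path expects. During evaluation and generation (the case the wandb prediction table exercises) the tensor may be absent or shaped differently, as the commented-out experiments above suggest, so the code falls back to regular attention. A short sketch of the idea in plain PyTorch, with hypothetical names not taken from axolotl's code:

```python
import torch

# Three samples of lengths 5, 3, and 8 packed into one row; cu_seqlens holds
# the cumulative boundaries that varlen flash-attention kernels expect.
sample_lens = torch.tensor([5, 3, 8])
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), sample_lens.cumsum(0).to(torch.int32)]
)
# tensor([ 0,  5,  8, 16], dtype=torch.int32)

def use_packed_path(cu_seqlens, max_seqlen) -> bool:
    # Mirrors the condition added in flashattn_forward: only use the
    # packed (varlen) path when cu_seqlens is in its flat 1-D form.
    return (
        cu_seqlens is not None
        and max_seqlen is not None
        and cu_seqlens.dim() == 1
    )

print(use_packed_path(cu_seqlens, int(sample_lens.max())))  # True  (packed training batch)
print(use_packed_path(cu_seqlens.unsqueeze(0), 8))          # False (batched 2-D form)
print(use_packed_path(None, None))                          # False (no packing info)
```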
src/axolotl/utils/callbacks.py (5 changes: 3 additions & 2 deletions)
@@ -438,8 +438,9 @@ def log_table_from_dataloader(name: str, table_dataloader):

with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=32,
# repetition_penalty=1.1,
max_new_tokens=128,
# max_new_tokens=32,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
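The prediction-table callback now generates up to 128 new tokens per prompt and drops the repetition penalty. A minimal sketch of how such a GenerationConfig might be used to produce one table row; this is not axolotl's exact callback code, and the helper name is hypothetical:

```python
import torch
from transformers import GenerationConfig

def generate_prediction(model, tokenizer, prompt: str) -> str:
    generation_config = GenerationConfig(
        max_new_tokens=128,  # raised from 32 in this commit
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, generation_config=generation_config)
    # Keep only the completion, not the echoed prompt tokens.
    completion = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True)
```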
