From 83e6b29fe896c4af54fe7c3572f7ed6c15874f53 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Sat, 9 Sep 2023 06:48:41 +0000 Subject: [PATCH] Fix sample_packing support for wandb prediction table --- .vscode/launch.json | 3 ++- examples/llama-2/tiny-random.yml | 26 +++++++++++-------- .../monkeypatch/llama_attn_hijack_flash.py | 6 ++++- src/axolotl/utils/callbacks.py | 5 ++-- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index e116653768..e264f9d69f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -17,7 +17,8 @@ "localRoot": "${workspaceFolder}", "remoteRoot": "/workspace/axolotl/" } - ] + ], + "justMyCode": false }, { "name": "train", diff --git a/examples/llama-2/tiny-random.yml b/examples/llama-2/tiny-random.yml index 4a841f4609..138aab2963 100644 --- a/examples/llama-2/tiny-random.yml +++ b/examples/llama-2/tiny-random.yml @@ -19,24 +19,27 @@ strict: false datasets: # - path: mhenrichsen/alpaca_2k_test # type: alpaca - - path: teknium/GPT4-LLM-Cleaned - type: alpaca - # - path: Glavin001/startup-interviews + # - path: teknium/GPT4-LLM-Cleaned # type: alpaca + - path: Glavin001/startup-interviews + type: alpaca dataset_prepared_path: last_run_prepared # val_set_size: 0.01 -val_set_size: 0.001 +val_set_size: 0.02 +# val_set_size: 0.05 +# val_set_size: 0.001 # val_set_size: 0.1 # output_dir: ./lora-out # output_dir: ./lora-2-out -output_dir: ./lora-5-out +output_dir: ./lora-6-out # sequence_len: 4096 -sequence_len: 2048 +# sequence_len: 2048 # sequence_len: 256 # sequence_len: 512 -# sample_packing: true -sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py +sequence_len: 1024 +sample_packing: true +# sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py adapter: lora lora_model_dir: @@ -60,8 +63,9 @@ micro_batch_size: 16 # num_epochs: 3 # num_epochs: 0.001 # num_epochs: 0.01 -num_epochs: 1 +# num_epochs: 1 # num_epochs: 5 +num_epochs: 10 optimizer: adamw_bnb_8bit lr_scheduler: cosine learning_rate: 0.0002 @@ -81,9 +85,9 @@ xformers_attention: flash_attention: true warmup_steps: 10 -eval_steps: 10 +# eval_steps: 10 # eval_steps: 20 -# eval_steps: 2 +eval_steps: 2 # eval_steps: 1 save_steps: debug: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 39cfb5c173..d90d5e5497 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -194,7 +194,11 @@ def flashattn_forward( # only on first autoregressive step q,k,v have same seqlen is_causal = key_states.shape == query_states.shape - if cu_seqlens is not None and max_seqlen is not None: + # if cu_seqlens is not None and max_seqlen is not None: + # if cu_seqlens is not None and max_seqlen is not None and self.training: + # if cu_seqlens is not None and max_seqlen is not None and query_states.shape == key_states.shape: + # if cu_seqlens is not None and max_seqlen is not None and len(cu_seqlens[0]) > 2: + if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: # special handling using sample packing qkv = torch.stack( [query_states, key_states, value_states], dim=2 diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 6ee0243324..8ab19707bb 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -438,8 +438,9 @@ def log_table_from_dataloader(name: str, table_dataloader): with torch.no_grad(): 
generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=32, + # repetition_penalty=1.1, + max_new_tokens=128, + # max_new_tokens=32, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
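
Note (not part of the patch): the core of this fix is the extra cu_seqlens.dim() == 1 guard in flashattn_forward. The sketch below only illustrates that rank check, assuming that sample-packed training batches carry cu_seqlens as a flat 1-D tensor of cumulative sequence offsets, while the batches built for the wandb prediction table can arrive with an extra batch dimension and should fall through to the regular attention path. The helper name use_packed_attention_path and the tensor values are hypothetical, not part of axolotl's API.

import torch


def use_packed_attention_path(cu_seqlens, max_seqlen):
    # Mirrors the guard added to flashattn_forward (assumed semantics):
    # only take the packed/varlen branch for a flat 1-D offsets tensor.
    return (
        cu_seqlens is not None
        and max_seqlen is not None
        and cu_seqlens.dim() == 1
    )


# Packed training batch: three sequences of lengths 5, 3, and 8.
packed = torch.tensor([0, 5, 8, 16], dtype=torch.int32)
print(use_packed_attention_path(packed, 16))   # True  -> packed/varlen path

# Prediction-table style batch: same offsets but with a batch dimension.
batched = packed.unsqueeze(0)                  # shape (1, 4)
print(use_packed_attention_path(batched, 16))  # False -> regular attention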
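
The callbacks.py hunk also changes the evaluation-time generation settings: repetition_penalty is commented out and max_new_tokens is raised from 32 to 128. A minimal sketch of an equivalent GenerationConfig follows, assuming placeholder special-token ids (in the callback these come from the run's tokenizer).

from transformers import GenerationConfig

# Placeholder ids for illustration only; in the callback they are read
# from tokenizer.bos_token_id / eos_token_id / pad_token_id.
generation_config = GenerationConfig(
    max_new_tokens=128,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
)
print(generation_config)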