diff --git a/examples/hymba/fft-1.5b.yml b/examples/hymba/fft-1.5b.yml
new file mode 100644
index 0000000000..e11a08ae66
--- /dev/null
+++ b/examples/hymba/fft-1.5b.yml
@@ -0,0 +1,58 @@
+base_model: nvidia/Hymba-1.5B-Base
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: tatsu-lab/alpaca
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+trust_remote_code: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 5
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: <|end_of_text|>
diff --git a/examples/hymba/qlora-1.5b.yml b/examples/hymba/qlora-1.5b.yml
new file mode 100644
index 0000000000..472f8706fb
--- /dev/null
+++ b/examples/hymba/qlora-1.5b.yml
@@ -0,0 +1,73 @@
+base_model: nvidia/Hymba-1.5B-Base
+
+load_in_8bit: false
+load_in_4bit: True
+strict: false
+
+datasets:
+ - path: tatsu-lab/alpaca
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+ - gate_proj
+ - down_proj
+ - up_proj
+ - q_proj
+ - v_proj
+ - k_proj
+ - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+trust_remote_code: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 5
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: <|end_of_text|>
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 3ee89d2e5c..66e615516c 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -25,6 +25,7 @@
"gemmoe",
"starcoder2",
"deepseek_v2",
+ "hymba",
]
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index ffe5e24853..ac923164b5 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -31,6 +31,7 @@
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
+ "hymba": "{{'System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n ' + tool|tojson + ' ' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n ' + context.strip() + ' ' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ 'Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'Assistant\n'}}{%- endif %}",
}
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3671e1bb93..7420cbbbd0 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1581,3 +1581,19 @@ def check_adopt_torch_version(cls, data):
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_hymba_torch_version(cls, data):
+ if "hymba" in data.get("base_model", {}).lower():
+ env_capabilities = data.get("env_capabilities", {})
+ torch_version = env_capabilities.get("torch_version")
+
+ if torch_version is None:
+ import torch
+
+ torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+ if version.parse(torch_version) < version.parse("2.5.0"):
+ raise ValueError("Hymba requires torch version >= 2.5")
+ return data
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 11f4c6d0fe..3ad9a45d42 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -420,6 +420,7 @@ def apply_patches(self) -> None:
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
+ # some model config objects are not subscriptable
try:
auto_map_config = self.model_config["auto_map"]
except TypeError:
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 2317bfb97a..58e2493c25 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -67,8 +67,8 @@ def test_optimi_adamw(self, temp_dir):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
- @with_temp_dir
@require_torch_2_5_1
+ @with_temp_dir
def test_adopt_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index dd0af32f3c..43f623ca6c 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -14,7 +14,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_tensorboard, require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,3 +68,129 @@ def test_loss_packed(self, temp_dir):
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
+
+
+class TestPackedHymba(unittest.TestCase):
+ """
+ Test case for Packed training of hymba models
+ """
+
+ @require_torch_2_5_1
+ @with_temp_dir
+ def test_loss_packed(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "nvidia/Hymba-1.5B-Base",
+ "trust_remote_code": True,
+ "load_in_4bit": True,
+ "adapter": "qlora",
+ "lora_r": 32,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_modules": [
+ "gate_proj",
+ "down_proj",
+ "up_proj",
+ "q_proj",
+ "v_proj",
+ "k_proj",
+ "o_proj",
+ ],
+ "sequence_len": 1024,
+ "sample_packing": True,
+ "flash_attention": True,
+ "val_set_size": 0.0,
+ "datasets": [
+ {
+ "path": "vicgalle/alpaca-gpt4",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 5,
+ "use_tensorboard": True,
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+ check_tensorboard(
+ temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+ )
+
+
+class TestUnpackedHymba(unittest.TestCase):
+ """
+ Test case for Unpacked training of hymba models
+ """
+
+ @require_torch_2_5_1
+ @with_temp_dir
+ def test_loss_unpacked(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "nvidia/Hymba-1.5B-Base",
+ "trust_remote_code": True,
+ "load_in_4bit": True,
+ "adapter": "qlora",
+ "lora_r": 32,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_modules": [
+ "gate_proj",
+ "down_proj",
+ "up_proj",
+ "q_proj",
+ "v_proj",
+ "k_proj",
+ "o_proj",
+ ],
+ "sequence_len": 1024,
+ "sample_packing": False,
+ "flash_attention": True,
+ "val_set_size": 0.0,
+ "datasets": [
+ {
+ "path": "vicgalle/alpaca-gpt4",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 5,
+ "use_tensorboard": True,
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+ check_tensorboard(
+ temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+ )