diff --git a/README.md b/README.md
index d502eec0b5..3a1eb0cd7a 100644
--- a/README.md
+++ b/README.md
@@ -413,9 +413,10 @@ tokenizer_legacy:
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
-# used to identify if the model is falcon/llama based
+# used to identify which model family the model is based on
is_falcon_derived_model:
is_llama_derived_model:
+is_mistral_derived_model:
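+# e.g. set is_mistral_derived_model: true when fine-tuning a Mistral-based model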
# whether you are training a 4-bit GPTQ quantized model
gptq: true
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
new file mode 100644
index 0000000000..d199f947be
--- /dev/null
+++ b/examples/mistral/config.yml
@@ -0,0 +1,62 @@
+base_model: mistralai/Mistral-7B-v0.1
+base_model_config: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
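+# enables Mistral-specific handling in axolotl (e.g. attention patches)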
+is_mistral_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
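+# a small 2k-example alpaca-format dataset, useful for smoke tests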
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./out
+
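+# Mistral-7B-v0.1 uses sliding-window attention, keeping 8k contexts tractable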
+sequence_len: 8192
+sample_packing:
+pad_to_sequence_len:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
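+# effective batch size = micro_batch_size * gradient_accumulation_steps * num GPUs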
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
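+# bf16 requires an Ampere-or-newer GPU; on older cards use fp16 instead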
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
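+# enable at most one attention backend; flash_attention requires the flash-attn package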
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index ac067b5055..67f9490c47 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -82,7 +82,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
- or "llama" in cfg.base_model
+ or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
@@ -98,10 +98,23 @@ def normalize_config(cfg):
]
)
or cfg.is_falcon_derived_model
- or "falcon" in cfg.base_model
+ or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
)
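+    # same detection pattern as the llama/falcon checks above: trust the HF
+    # model config, an explicit user flag, or a "mistral" substring in the
+    # base model name / model_type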
+ cfg.is_mistral_derived_model = (
+ (
+ hasattr(model_config, "model_type")
+ and model_config.model_type
+ in [
+ "mistral",
+ ]
+ )
+ or cfg.is_mistral_derived_model
+ or "mistral" in cfg.base_model.lower()
+ or (cfg.model_type and "mistral" in cfg.model_type.lower())
+ )
+
log_gpu_memory_usage(LOG, "baseline", cfg.device)