diff --git a/README.md b/README.md index d502eec0b5..3a1eb0cd7a 100644 --- a/README.md +++ b/README.md @@ -413,9 +413,10 @@ tokenizer_legacy: # this is reported to improve training speed on some models resize_token_embeddings_to_32x: -# used to identify if the model is falcon/llama based +# used to identify which the model is based on is_falcon_derived_model: is_llama_derived_model: +is_mistral_derived_model: # whether you are training a 4-bit GPTQ quantized model gptq: true diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml new file mode 100644 index 0000000000..d199f947be --- /dev/null +++ b/examples/mistral/config.yml @@ -0,0 +1,62 @@ +base_model: mistralai/Mistral-7B-v0.1 +base_model_config: mistralai/Mistral-7B-v0.1 +model_type: MistralForCausalLM +tokenizer_type: LlamaTokenizer +is_mistral_derived_model: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./out + +sequence_len: 8192 +sample_packing: +pad_to_sequence_len: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 +eval_table_max_new_tokens: 128 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index ac067b5055..67f9490c47 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -82,7 +82,7 @@ def normalize_config(cfg): cfg.is_llama_derived_model = ( (hasattr(model_config, "model_type") and model_config.model_type == "llama") or cfg.is_llama_derived_model - or "llama" in cfg.base_model + or "llama" in cfg.base_model.lower() or (cfg.model_type and "llama" in cfg.model_type.lower()) ) @@ -98,10 +98,23 @@ def normalize_config(cfg): ] ) or cfg.is_falcon_derived_model - or "falcon" in cfg.base_model + or "falcon" in cfg.base_model.lower() or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower()) ) + cfg.is_mistral_derived_model = ( + ( + hasattr(model_config, "model_type") + and model_config.model_type + in [ + "mistral", + ] + ) + or cfg.is_mistral_derived_model + or "mistral" in cfg.base_model.lower() + or (cfg.model_type and "mistral" in cfg.model_type.lower()) + ) + log_gpu_memory_usage(LOG, "baseline", cfg.device)