diff --git a/README.md b/README.md index 6d1a383..db34937 100644 --- a/README.md +++ b/README.md @@ -170,8 +170,8 @@ model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca") - [x] INT4 LLaMA LoRA fine-tuning with INT4 generation - [x] Support for a `Generic model` wrapper - [x] Support for `Falcon-7B` model +- [X] INT4 low-precision fine-tuning support - [ ] Evaluation of LLM models -- [ ] INT4 low-precision fine-tuning support - [ ] INT3, INT2, INT1 low-precision fine-tuning support - [ ] Support for Stable Diffusion diff --git a/pyproject.toml b/pyproject.toml index 125ae30..78edad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "xturing" -version = "0.1.5" +version = "0.1.6" description = "Fine-tuning, evaluation and data generation for LLMs" authors = [ diff --git a/src/xturing/__about__.py b/src/xturing/__about__.py index 1276d02..0a8da88 100644 --- a/src/xturing/__about__.py +++ b/src/xturing/__about__.py @@ -1 +1 @@ -__version__ = "0.1.5" +__version__ = "0.1.6" diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml index 5052d53..a3c5d5d 100644 --- a/src/xturing/config/finetuning_config.yaml +++ b/src/xturing/config/finetuning_config.yaml @@ -14,80 +14,87 @@ defaults: optimizer_name: adamw output_dir: saved_model -llama: +bloom: learning_rate: 5e-5 weight_decay: 0.01 num_train_epochs: 3 - optimizer_name: cpu_adam - -llama_lora: - learning_rate: 1e-4 - weight_decay: 0.01 - num_train_epochs: 3 - batch_size: 1 -llama_lora_int8: +bloom_lora: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 8 - max_length: 256 + batch_size: 4 -llama_lora_int4: +bloom_lora_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 max_length: 256 -gptj: +cerebras: learning_rate: 5e-5 weight_decay: 0.01 num_train_epochs: 3 - optimizer_name: cpu_adam -gptj_lora: +cerebras_lora: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 1 + batch_size: 4 -gptj_lora_int8: +cerebras_lora_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 max_length: 256 -gpt2: +distilgpt2: learning_rate: 1e-3 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 -gpt2_lora: +distilgpt2_lora: learning_rate: 3e-3 weight_decay: 0.01 num_train_epochs: 3 batch_size: 16 -gpt2_lora_int8: - learning_rate: 3e-3 +falcon: + learning_rate: 5e-5 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 16 + batch_size: 1 + max_length: 256 -distilgpt2: - learning_rate: 1e-3 +falcon_int8: + learning_rate: 1e-4 + weight_decay: 0.01 + num_train_epochs: 3 + batch_size: 1 + max_length: 256 + +falcon_lora: + learning_rate: 1e-4 + weight_decay: 0.01 + num_train_epochs: 3 + batch_size: 1 + +falcon_lora_int8: + learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 + max_length: 256 -distilgpt2_lora: - learning_rate: 3e-3 +falcon_lora_kbit: + learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 16 + batch_size: 8 + max_length: 256 galactica: learning_rate: 5e-5 @@ -108,109 +115,130 @@ galactica_lora_int8: batch_size: 8 max_length: 256 -opt: - learning_rate: 5e-5 - weight_decay: 0.01 - num_train_epochs: 3 - -opt_lora: +generic: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 1 + batch_size: 8 + max_length: 256 -opt_lora_int8: +generic_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 max_length: 256 -cerebras: - learning_rate: 5e-5 +generic_lora: + learning_rate: 1e-4 weight_decay: 0.01 
num_train_epochs: 3 + batch_size: 8 + max_length: 256 -cerebras_lora: +generic_lora_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 4 + batch_size: 8 + max_length: 256 -cerebras_lora_int8: +generic_lora_kbit: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 max_length: 256 -bloom: +gptj: learning_rate: 5e-5 weight_decay: 0.01 num_train_epochs: 3 + optimizer_name: cpu_adam -bloom_lora: +gptj_lora: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 4 + batch_size: 1 -bloom_lora_int8: +gptj_lora_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 max_length: 256 -generic: - learning_rate: 1e-4 +gpt2: + learning_rate: 1e-3 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 - max_length: 256 -generic_int8: - learning_rate: 1e-4 +gpt2_lora: + learning_rate: 3e-3 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 8 - max_length: 256 + batch_size: 16 -generic_int8_lora: +gpt2_lora_int8: + learning_rate: 3e-3 + weight_decay: 0.01 + num_train_epochs: 3 + batch_size: 16 + +llama: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +llama_lora: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 8 - max_length: 256 + batch_size: 1 -generic_lora: +llama_lora_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 8 max_length: 256 - -falcon: - learning_rate: 5e-5 - weight_decay: 0.01 + +llama_lora_kbit: + learning_rate: 3e-4 num_train_epochs: 3 batch_size: 1 max_length: 256 + lora_r: 32 + lora_alpha: 128 + lora_groupsize: 128 + lora_dropout: 0.05 + seed: 0 + cache: False + seqlen: 2048 + kl_weight: 1.0 + ce_weight: 200.0 + save_freq: 1 + trainable_kl_weight: False + trainable_ce_weight: False + weight_decay: 1e-5 + intra_save_freq: 200 + groupsize: 128 -falcon_int8: - learning_rate: 1e-4 +opt: + learning_rate: 5e-5 weight_decay: 0.01 num_train_epochs: 3 - batch_size: 1 - max_length: 256 -falcon_lora: +opt_lora: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 batch_size: 1 -falcon_lora_int8: +opt_lora_int8: learning_rate: 1e-4 weight_decay: 0.01 num_train_epochs: 3 diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml index 7e01bfb..8163bb7 100644 --- a/src/xturing/config/generation_config.yaml +++ b/src/xturing/config/generation_config.yaml @@ -7,84 +7,81 @@ defaults: max_new_tokens: 256 # Contrastive search -llama: +bloom: penalty_alpha: 0.6 top_k: 4 max_new_tokens: 256 do_sample: false # Contrastive search -llama_lora: +bloom_lora: penalty_alpha: 0.6 top_k: 4 max_new_tokens: 256 do_sample: false # Greedy search -llama_lora_int8: - max_new_tokens: 256 - do_sample: false - -# Contrastive search -llama_lora_int4: - penalty_alpha: 0.6 - top_k: 4 +bloom_lora_int8: max_new_tokens: 256 do_sample: false # Contrastive search -gptj: +cerebras: penalty_alpha: 0.6 top_k: 4 max_new_tokens: 256 do_sample: false # Contrastive search -gptj_lora: +cerebras_lora: penalty_alpha: 0.6 top_k: 4 max_new_tokens: 256 do_sample: false # Greedy search -gptj_lora_int8: +cerebras_lora_int8: max_new_tokens: 256 do_sample: false # Top-p sampling -gpt2: +distilgpt2: do_sample: true top_k: 0 top_p: 0.92 max_new_tokens: 256 # Top-p sampling -gpt2_lora: +distilgpt2_lora: do_sample: true top_k: 0 top_p: 0.92 max_new_tokens: 256 -# Top-p sampling -gpt2_lora_int8: - do_sample: true - top_k: 0 - top_p: 0.92 +# Greedy search +falcon: max_new_tokens: 256 + do_sample: false -# Top-p sampling 
-distilgpt2: - do_sample: true - top_k: 0 - top_p: 0.92 +# Greedy search +falcon_int8: max_new_tokens: 256 + do_sample: false -# Top-p sampling -distilgpt2_lora: - do_sample: true - top_k: 0 - top_p: 0.92 +# Greedy search +falcon_lora: max_new_tokens: 256 + do_sample: false + +# Greedy search +falcon_lora_int8: + max_new_tokens: 256 + do_sample: false + +# Greedy search +falcon_lora_kbit: + max_new_tokens: 256 + do_sample: false # Contrastive search galactica: @@ -105,100 +102,110 @@ galactica_lora_int8: max_new_tokens: 256 do_sample: false -# Contrastive search -opt: - penalty_alpha: 0.6 - top_k: 4 - max_new_tokens: 256 - do_sample: false - -# Contrastive search -opt_lora: - penalty_alpha: 0.6 - top_k: 4 +# Greedy search +generic: max_new_tokens: 256 do_sample: false # Greedy search -opt_lora_int8: +generic_int8: max_new_tokens: 256 do_sample: false -# Contrastive search -cerebras: - penalty_alpha: 0.6 - top_k: 4 +# Greedy search +generic_lora: max_new_tokens: 256 do_sample: false -# Contrastive search -cerebras_lora: - penalty_alpha: 0.6 - top_k: 4 +# Greedy search +generic_lora_int8: max_new_tokens: 256 do_sample: false # Greedy search -cerebras_lora_int8: +generic_lora_kbit: max_new_tokens: 256 do_sample: false # Contrastive search -bloom: +gptj: penalty_alpha: 0.6 top_k: 4 max_new_tokens: 256 do_sample: false # Contrastive search -bloom_lora: +gptj_lora: penalty_alpha: 0.6 top_k: 4 max_new_tokens: 256 do_sample: false # Greedy search -bloom_lora_int8: +gptj_lora_int8: max_new_tokens: 256 do_sample: false -# Greedy search -generic: +# Top-p sampling +gpt2: + do_sample: true + top_k: 0 + top_p: 0.92 max_new_tokens: 256 - do_sample: false -# Greedy search -generic_int8: +# Top-p sampling +gpt2_lora: + do_sample: true + top_k: 0 + top_p: 0.92 max_new_tokens: 256 - do_sample: false -# Greedy search -generic_lora: +# Top-p sampling +gpt2_lora_int8: + do_sample: true + top_k: 0 + top_p: 0.92 + max_new_tokens: 256 + +# Contrastive search +llama: + penalty_alpha: 0.6 + top_k: 4 max_new_tokens: 256 do_sample: false +# Contrastive search +llama_lora: + penalty_alpha: 0.6 + top_k: 4 + max_new_tokens: 256 + do_sample: false # Greedy search -generic_lora_int8: +llama_lora_int8: max_new_tokens: 256 do_sample: false # Greedy search -falcon: +llama_lora_kbit: max_new_tokens: 256 do_sample: false -# Greedy search -falcon_lora: +# Contrastive search +opt: + penalty_alpha: 0.6 + top_k: 4 max_new_tokens: 256 do_sample: false -# Greedy search -falcon_int8: +# Contrastive search +opt_lora: + penalty_alpha: 0.6 + top_k: 4 max_new_tokens: 256 do_sample: false # Greedy search -falcon_lora_int8: +opt_lora_int8: max_new_tokens: 256 do_sample: false diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py index 56a9192..4c10a45 100644 --- a/src/xturing/engines/__init__.py +++ b/src/xturing/engines/__init__.py @@ -17,6 +17,7 @@ FalconInt8Engine, FalconLoraEngine, FalconLoraInt8Engine, + FalconLoraKbitEngine, ) from .galactica_engine import ( GalacticaEngine, @@ -29,6 +30,7 @@ GenericInt8Engine, GenericLoraEngine, GenericLoraInt8Engine, + GenericLoraKbitEngine, ) from .gpt2_engine import GPT2Engine, GPT2Int8Engine, GPT2LoraEngine, GPT2LoraInt8Engine from .gptj_engine import GPTJEngine, GPTJInt8Engine, GPTJLoraEngine, GPTJLoraInt8Engine @@ -36,47 +38,49 @@ LLamaEngine, LLamaInt8Engine, LlamaLoraEngine, - LlamaLoraInt4Engine, LlamaLoraInt8Engine, + LlamaLoraKbitEngine, ) from .opt_engine import OPTEngine, OPTInt8Engine, OPTLoraEngine, OPTLoraInt8Engine 
+BaseEngine.add_to_registry(BloomEngine.config_name, BloomEngine)
+BaseEngine.add_to_registry(BloomInt8Engine.config_name, BloomInt8Engine)
+BaseEngine.add_to_registry(BloomLoraEngine.config_name, BloomLoraEngine)
+BaseEngine.add_to_registry(BloomLoraInt8Engine.config_name, BloomLoraInt8Engine)
+BaseEngine.add_to_registry(CerebrasEngine.config_name, CerebrasEngine)
+BaseEngine.add_to_registry(CerebrasInt8Engine.config_name, CerebrasInt8Engine)
+BaseEngine.add_to_registry(CerebrasLoraEngine.config_name, CerebrasLoraEngine)
+BaseEngine.add_to_registry(CerebrasLoraInt8Engine.config_name, CerebrasLoraInt8Engine)
 BaseEngine.add_to_registry(DistilGPT2Engine.config_name, DistilGPT2Engine)
 BaseEngine.add_to_registry(DistilGPT2LoraEngine.config_name, DistilGPT2LoraEngine)
+BaseEngine.add_to_registry(FalconEngine.config_name, FalconEngine)
+BaseEngine.add_to_registry(FalconInt8Engine.config_name, FalconInt8Engine)
+BaseEngine.add_to_registry(FalconLoraEngine.config_name, FalconLoraEngine)
+BaseEngine.add_to_registry(FalconLoraInt8Engine.config_name, FalconLoraInt8Engine)
+BaseEngine.add_to_registry(FalconLoraKbitEngine.config_name, FalconLoraKbitEngine)
+BaseEngine.add_to_registry(GalacticaEngine.config_name, GalacticaEngine)
+BaseEngine.add_to_registry(GalacticaInt8Engine.config_name, GalacticaInt8Engine)
+BaseEngine.add_to_registry(GalacticaLoraEngine.config_name, GalacticaLoraEngine)
+BaseEngine.add_to_registry(GalacticaLoraInt8Engine.config_name, GalacticaLoraInt8Engine)
+BaseEngine.add_to_registry(GenericEngine.config_name, GenericEngine)
+BaseEngine.add_to_registry(GenericInt8Engine.config_name, GenericInt8Engine)
+BaseEngine.add_to_registry(GenericLoraEngine.config_name, GenericLoraEngine)
+BaseEngine.add_to_registry(GenericLoraInt8Engine.config_name, GenericLoraInt8Engine)
+BaseEngine.add_to_registry(GenericLoraKbitEngine.config_name, GenericLoraKbitEngine)
 BaseEngine.add_to_registry(GPTJEngine.config_name, GPTJEngine)
-BaseEngine.add_to_registry(GPTJLoraEngine.config_name, GPTJLoraEngine)
 BaseEngine.add_to_registry(GPTJInt8Engine.config_name, GPTJInt8Engine)
+BaseEngine.add_to_registry(GPTJLoraEngine.config_name, GPTJLoraEngine)
 BaseEngine.add_to_registry(GPTJLoraInt8Engine.config_name, GPTJLoraInt8Engine)
 BaseEngine.add_to_registry(GPT2Engine.config_name, GPT2Engine)
-BaseEngine.add_to_registry(GPT2LoraEngine.config_name, GPT2LoraEngine)
 BaseEngine.add_to_registry(GPT2Int8Engine.config_name, GPT2Int8Engine)
+BaseEngine.add_to_registry(GPT2LoraEngine.config_name, GPT2LoraEngine)
 BaseEngine.add_to_registry(GPT2LoraInt8Engine.config_name, GPT2LoraInt8Engine)
 BaseEngine.add_to_registry(LLamaEngine.config_name, LLamaEngine)
-BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine)
 BaseEngine.add_to_registry(LLamaInt8Engine.config_name, LLamaInt8Engine)
+BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine)
 BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
-BaseEngine.add_to_registry(LlamaLoraInt4Engine.config_name, LlamaLoraInt4Engine)
-BaseEngine.add_to_registry(GalacticaEngine.config_name, GalacticaEngine)
-BaseEngine.add_to_registry(GalacticaInt8Engine.config_name, GalacticaInt8Engine)
-BaseEngine.add_to_registry(GalacticaLoraEngine.config_name, GalacticaLoraEngine)
-BaseEngine.add_to_registry(GalacticaLoraInt8Engine.config_name, GalacticaLoraInt8Engine)
+BaseEngine.add_to_registry(LlamaLoraKbitEngine.config_name, LlamaLoraKbitEngine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
-BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine) BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine) +BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine) BaseEngine.add_to_registry(OPTLoraInt8Engine.config_name, OPTLoraInt8Engine) -BaseEngine.add_to_registry(CerebrasEngine.config_name, CerebrasEngine) -BaseEngine.add_to_registry(CerebrasLoraEngine.config_name, CerebrasLoraEngine) -BaseEngine.add_to_registry(CerebrasInt8Engine.config_name, CerebrasInt8Engine) -BaseEngine.add_to_registry(CerebrasLoraInt8Engine.config_name, CerebrasLoraInt8Engine) -BaseEngine.add_to_registry(BloomEngine.config_name, BloomEngine) -BaseEngine.add_to_registry(BloomLoraEngine.config_name, BloomLoraEngine) -BaseEngine.add_to_registry(BloomInt8Engine.config_name, BloomInt8Engine) -BaseEngine.add_to_registry(BloomLoraInt8Engine.config_name, BloomLoraInt8Engine) -BaseEngine.add_to_registry(GenericEngine.config_name, GenericEngine) -BaseEngine.add_to_registry(GenericInt8Engine.config_name, GenericInt8Engine) -BaseEngine.add_to_registry(GenericLoraEngine.config_name, GenericLoraEngine) -BaseEngine.add_to_registry(GenericLoraInt8Engine.config_name, GenericLoraInt8Engine) -BaseEngine.add_to_registry(FalconEngine.config_name, FalconEngine) -BaseEngine.add_to_registry(FalconLoraEngine.config_name, FalconLoraEngine) -BaseEngine.add_to_registry(FalconInt8Engine.config_name, FalconInt8Engine) -BaseEngine.add_to_registry(FalconLoraInt8Engine.config_name, FalconLoraInt8Engine) diff --git a/src/xturing/engines/causal.py b/src/xturing/engines/causal.py index b5609bf..7944463 100644 --- a/src/xturing/engines/causal.py +++ b/src/xturing/engines/causal.py @@ -17,6 +17,8 @@ LoraModel, prepare_model_for_int8_training, ) +from xturing.engines.quant_utils.peft_utils import LoraConfig as peftLoraConfig +from xturing.engines.quant_utils.peft_utils import prepare_model_for_kbit_training from xturing.utils.loss_fns import CrossEntropyLoss @@ -30,6 +32,7 @@ def __init__( tokenizer: Optional[Any] = None, load_8bit: Optional[bool] = False, trust_remote_code: Optional[bool] = False, + **kwargs, ): self.model_name = model_name @@ -45,11 +48,12 @@ def __init__( load_in_8bit=True, device_map=device_map, trust_remote_code=trust_remote_code, + **kwargs, ) self.model = prepare_model_for_int8_training(self.model) else: self.model = AutoModelForCausalLM.from_pretrained( - weights_path, torch_dtype=DEFAULT_DTYPE + weights_path, torch_dtype=DEFAULT_DTYPE, **kwargs ) self.tokenizer = AutoTokenizer.from_pretrained(weights_path) elif model is not None and tokenizer is not None: @@ -64,6 +68,7 @@ def __init__( load_in_8bit=True, device_map=device_map, trust_remote_code=trust_remote_code, + **kwargs, ) for param in self.model.parameters(): param.data = param.data.contiguous() @@ -73,6 +78,7 @@ def __init__( model_name, torch_dtype=DEFAULT_DTYPE, trust_remote_code=trust_remote_code, + **kwargs, ) self.tokenizer = AutoTokenizer.from_pretrained(model_name) else: @@ -197,3 +203,71 @@ def save(self, saving_path: Union[str, Path]): # Save tokenizer self.tokenizer.save_pretrained(saving_path) + + +class CausalLoraKbitEngine(CausalEngine): + def __init__( + self, + *, + model_name: Optional[str] = None, + weights_path: Optional[Union[str, Path]] = None, + model: Optional[Any] = None, + tokenizer: Optional[Any] = None, + target_modules: Optional[Union[List[str], str]] = None, + trust_remote_code: Optional[bool] = False, + ): + if model is None: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} + model = 
AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=DEFAULT_DTYPE, + device_map=device_map, + load_in_4bit=True, + trust_remote_code=trust_remote_code, + ) + + model = prepare_model_for_kbit_training(model) + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model_name) + + super().__init__( + model_name=model_name, + weights_path=None, + model=model, + tokenizer=tokenizer, + ) + + self.print_trainable_parameters() + + self.loss_fct = CrossEntropyLoss() + + def set_from_state_dict(self, state_dict, strict=False): + self.model.load_state_dict(state_dict, strict=strict) + + def save(self, saving_path: Union[str, Path]): + # Save HF config file + self.model.config.save_pretrained(str(saving_path)) + # Save model weights + model_weights = str(Path(saving_path).resolve() / "pytorch_model.bin") + + torch.save(self.model.state_dict(), model_weights) + # save adapter + self.model.save_pretrained(saving_path) + + # Save tokenizer + self.tokenizer.save_pretrained(saving_path) + + def print_trainable_parameters(self): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in self.model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) diff --git a/src/xturing/engines/falcon_engine.py b/src/xturing/engines/falcon_engine.py index 0f9268d..e08f6ac 100644 --- a/src/xturing/engines/falcon_engine.py +++ b/src/xturing/engines/falcon_engine.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Optional, Union -from xturing.engines.causal import CausalEngine, CausalLoraEngine +from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine class FalconEngine(CausalEngine): @@ -71,3 +71,25 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None): self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class FalconLoraKbitEngine(CausalLoraKbitEngine): + config_name: str = "falcon_lora_kbit_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + model_name = "tiiuae/falcon-7b" + super().__init__( + model_name=model_name, + weights_path=None, + target_modules=[ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + ], + trust_remote_code=True, + load_4bit=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id diff --git a/src/xturing/engines/generic_engine.py b/src/xturing/engines/generic_engine.py index 958725d..5a88849 100644 --- a/src/xturing/engines/generic_engine.py +++ b/src/xturing/engines/generic_engine.py @@ -2,19 +2,16 @@ from pathlib import Path from typing import List, Optional, Union -from xturing.engines.causal import CausalEngine, CausalLoraEngine +from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine class GenericEngine(CausalEngine): config_name: str = "generic_engine" def __init__( - self, model_name: str, weights_path: Optional[Union[str, Path]] = None + self, model_name: str, weights_path: Optional[Union[str, Path]] = None, **kwargs ): - super().__init__( - model_name=model_name, - weights_path=weights_path, - ) + super().__init__(model_name=model_name, weights_path=weights_path, **kwargs) self.tokenizer.pad_token = self.tokenizer.eos_token @@ -65,3 
+62,22 @@ def __init__( ) self.tokenizer.pad_token = self.tokenizer.eos_token + + +class GenericLoraKbitEngine(CausalLoraKbitEngine): + config_name: str = "generic+lora_kbit_engine" + + def __init__( + self, + model_name: str, + target_modules: List[str], + weights_path: Optional[Union[str, Path]] = None, + ): + super().__init__( + model_name=model_name, + weights_path=weights_path, + load_4bit=True, + target_modules=target_modules, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token diff --git a/src/xturing/engines/llama_engine.py b/src/xturing/engines/llama_engine.py index 2456a51..35669e6 100644 --- a/src/xturing/engines/llama_engine.py +++ b/src/xturing/engines/llama_engine.py @@ -1,3 +1,4 @@ +import argparse import os from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union @@ -6,10 +7,13 @@ import transformers from torch import nn -from xturing.engines.causal import CausalEngine, CausalLoraEngine +from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig +from xturing.config.read_config import load_config, read_yaml +from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from xturing.engines.lora_engine import prepare_model_for_int8_training from xturing.engines.quant_utils import autotune_warmup, make_quant +from xturing.engines.quant_utils.lrec import get_c4, prepare_models, train_model from xturing.utils.hub import ModelHub @@ -117,76 +121,53 @@ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): return res -class LlamaLoraInt4Engine(CausalLoraEngine): - config_name: str = "llama_lora_int4_engine" +class LlamaLoraKbitEngine(CausalLoraKbitEngine): + config_name: str = "llama_lora_kbit_engine" def __init__(self, weights_path: Optional[Union[str, Path]] = None): model_name = "decapoda-research/llama-7b-hf" - - if weights_path is None: - weights_path = ModelHub().load("x/llama_lora_int4") - - config = LlamaConfig.from_pretrained(model_name) - - saved_kaiming_uniform_ = torch.nn.init.kaiming_uniform_ - saved_uniform_ = torch.nn.init.uniform_ - saved_normal_ = torch.nn.init.normal_ - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - transformers.modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - model = model.eval() - - layers = find_layers(model) - - for name in ["lm_head"]: - if name in layers: - del layers[name] - - wbits = 4 - groupsize = 128 - warmup_autotune = True - - make_quant(model, layers, wbits, groupsize) - - state_dict = torch.load( - weights_path / Path("pytorch_model.bin"), map_location="cpu" - ) - - if warmup_autotune: - autotune_warmup(model) - - model.seqlen = 2048 - - model.gptq = True - - model.gradient_checkpointing_enable() - model.enable_input_require_grads() + # lrec_config = { + # "base_model": model_name, + # "intq_checkpoint": str( + # Path(__file__).parent / "llama7b-2bit-128g.pt" + # ), ## how to do this + # "wbits": wbits, + # "lora_target_modules": [ + # "q_proj", + # "v_proj", + # "k_proj", + # "o_proj", + # "up_proj", + # "down_proj", + # "gate_proj", + # ], + # # "n_samples": 100, + # # "train_cache_dir": "./train_cache/", + # # "val_cache_dir": "./val_cache/", + # # "ckpt_dir": "./ckpts/", + # # "save_dir": "./save/", + # } + + # # 
Finetuning config + # yml_content = read_yaml( + # Path(__file__).parent.parent / "config" / "finetuning_config.yaml", + # ) + # lrec_config.update(yml_content["defaults"]) + # lrec_config.update(yml_content[self.config_name.replace("_engine", "")]) + + # model, fp_model = prepare_models(argparse.Namespace(**lrec_config)) + + # # The model before applying LoRA + # self.base_model = fp_model tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False) tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id super().__init__( - model=model, + model_name=model_name, + weights_path=None, tokenizer=tokenizer, - target_modules=[ - "q_proj", - "v_proj", - ], + target_modules=["q_proj", "v_proj"], + load_4bit=True, ) - - torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_ - torch.nn.init.uniform_ = saved_uniform_ - torch.nn.init.normal_ = saved_normal_ - - self.set_from_state_dict(state_dict) diff --git a/src/xturing/engines/lora_engine/lora.py b/src/xturing/engines/lora_engine/lora.py index 8e81faf..2202c51 100644 --- a/src/xturing/engines/lora_engine/lora.py +++ b/src/xturing/engines/lora_engine/lora.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import importlib import json import math @@ -21,17 +22,19 @@ from dataclasses import asdict, dataclass, field from enum import Enum from typing import List, Optional, Union -import enum import torch import torch.nn as nn import torch.nn.functional as F +import transformers +from transformers import LlamaConfig, LlamaForCausalLM from transformers.pytorch_utils import Conv1D from xturing.engines.lora_engine.save_and_load import ( get_peft_model_state_dict, set_peft_model_state_dict, ) +from xturing.engines.quant_utils import QuantLinear, autotune_warmup def is_bnb_available(): @@ -45,18 +48,22 @@ def is_bnb_available(): def transpose(weight, fan_in_fan_out): return weight.T if fan_in_fan_out else weight + def is_gptq_available(): return importlib.util.find_spec("xturing.engines.quant_utils") is not None + if is_gptq_available(): from ..quant_utils import QuantLinear + class PeftType(str, enum.Enum): PROMPT_TUNING = "PROMPT_TUNING" P_TUNING = "P_TUNING" PREFIX_TUNING = "PREFIX_TUNING" LORA = "LORA" + WEIGHTS_NAME = "adapter_model.bin" CONFIG_NAME = "adapter_config.json" @@ -90,14 +97,22 @@ class LoraConfig: lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"}) lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"}) merge_weights: bool = field( - default=False, metadata={"help": "Merge weights of the original model and the Lora model"} + default=False, + metadata={"help": "Merge weights of the original model and the Lora model"}, ) fan_in_fan_out: bool = field( default=False, - metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + metadata={ + "help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)" + }, + ) + enable_lora: Optional[List[bool]] = field( + default=None, metadata={"help": "Used with `lora.MergedLinear`."} + ) + bias: str = field( + default="none", + metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}, ) - enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."}) - bias: str = field(default="none", metadata={"help": "Bias type for Lora. 
Can be 'none', 'all' or 'lora_only'"}) modules_to_save: Optional[List[str]] = field( default=None, metadata={ @@ -231,7 +246,9 @@ def __init__(self, config, model): def _find_and_replace(self): loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) - is_gtq_quantized = getattr(self.model, "gptq", False) # Step 1: Check if the model is GTQ quantized + is_gtq_quantized = getattr( + self.model, "gptq", False + ) # Step 1: Check if the model is GTQ quantized if loaded_in_8bit and not is_bnb_available(): raise ImportError( @@ -245,7 +262,9 @@ def _find_and_replace(self): "lora_alpha": self.peft_config.lora_alpha, "lora_dropout": self.peft_config.lora_dropout, "fan_in_fan_out": self.peft_config.fan_in_fan_out, - "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode) + "merge_weights": ( + self.peft_config.merge_weights or self.peft_config.inference_mode + ) and not is_hf_device_map_available, } key_list = [key for key, _ in self.model.named_modules()] @@ -253,7 +272,10 @@ def _find_and_replace(self): if isinstance(self.peft_config.target_modules, str): target_module_found = re.fullmatch(self.peft_config.target_modules, key) else: - target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules) + target_module_found = any( + key.endswith(target_key) + for target_key in self.peft_config.target_modules + ) if target_module_found: if not is_target_modules_in_base_model: is_target_modules_in_base_model = True @@ -269,10 +291,14 @@ def _find_and_replace(self): } ) if self.peft_config.enable_lora is None: - new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs) + new_module = Linear8bitLt( + target.in_features, target.out_features, bias=bias, **kwargs + ) else: kwargs.update({"enable_lora": self.peft_config.enable_lora}) - new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs) + new_module = MergedLinear8bitLt( + target.in_features, target.out_features, bias=bias, **kwargs + ) elif is_gptq_available() and isinstance(target, QuantLinear): kwargs.update( { @@ -281,7 +307,9 @@ def _find_and_replace(self): } ) if self.peft_config.enable_lora is None: - new_module = LinearqbitLt(target.infeatures, target.outfeatures, bias=bias, **kwargs) + new_module = LinearqbitLt( + target.infeatures, target.outfeatures, bias=bias, **kwargs + ) new_module.scales = target.scales new_module.qzeros = target.qzeros new_module.g_idx = target.g_idx @@ -289,29 +317,45 @@ def _find_and_replace(self): new_module.bias = target.bias else: kwargs.update({"enable_lora": self.peft_config.enable_lora}) - new_module = MergedLinearqbitLt(target.infeatures, target.outfeatures, bias=bias, **kwargs) + new_module = MergedLinearqbitLt( + target.infeatures, target.outfeatures, bias=bias, **kwargs + ) new_module.scales = target.scales new_module.qzeros = target.qzeros new_module.g_idx = target.g_idx if target.bias: new_module.bias = target.bias - elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None: - new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs) + elif ( + isinstance(target, torch.nn.Linear) + and self.peft_config.enable_lora is None + ): + new_module = Linear( + target.in_features, target.out_features, bias=bias, **kwargs + ) elif self.peft_config.enable_lora is not None: kwargs.update({"enable_lora": self.peft_config.enable_lora}) if isinstance(target, Conv1D): in_features, out_features = ( - target.weight.ds_shape if 
hasattr(target.weight, "ds_shape") else target.weight.shape + target.weight.ds_shape + if hasattr(target.weight, "ds_shape") + else target.weight.shape ) else: - in_features, out_features = target.in_features, target.out_features + in_features, out_features = ( + target.in_features, + target.out_features, + ) if kwargs["fan_in_fan_out"]: warnings.warn( "fan_in_fan_out is set to True but the target module is not a Conv1D. " "Setting fan_in_fan_out to False." ) - kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False - new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs) + kwargs[ + "fan_in_fan_out" + ] = self.peft_config.fan_in_fan_out = False + new_module = MergedLinear( + in_features, out_features, bias=bias, **kwargs + ) self._replace_module(parent, target_name, new_module, target) if not is_target_modules_in_base_model: raise ValueError( @@ -364,7 +408,10 @@ def modules_to_save(self): return None def get_peft_config_as_dict(self, inference: bool = False): - config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()} + config = { + k: v.value if isinstance(v, Enum) else v + for k, v in asdict(self.peft_config).items() + } if inference: config["inference_mode"] = True return config @@ -512,7 +559,13 @@ def __init__( **kwargs, ): nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + LoraLayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights, + ) self.fan_in_fan_out = fan_in_fan_out # Actual trainable parameters @@ -541,14 +594,20 @@ def train(self, mode: bool = True): # Merge the weights and mark it if self.r > 0: self.weight.data += ( - transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling + transpose( + self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out + ) + * self.scaling ) self.merged = True elif self.merge_weights and self.merged: # Make sure that the weights are not merged if self.r > 0: self.weight.data -= ( - transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling + transpose( + self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out + ) + * self.scaling ) self.merged = False @@ -561,19 +620,30 @@ def forward(self, x: torch.Tensor): if self.disable_adapters: if self.r > 0 and self.merged: self.weight.data -= ( - transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling + transpose( + self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out + ) + * self.scaling ) self.merged = False - return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + return F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) elif self.r > 0 and not self.merged: - result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) if self.r > 0: - loraoutput = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + loraoutput = ( + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + ) result = result + loraoutput return result else: - return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + return F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) class MergedLinear(nn.Linear, LoraLayer): @@ -591,7 +661,13 @@ def __init__( 
**kwargs, ): nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + LoraLayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights, + ) if out_features % len(enable_lora) != 0: raise ValueError("The length of enable_lora must divide out_features") self.enable_lora = enable_lora @@ -610,7 +686,9 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.requires_grad = False # Compute the indices - self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind = self.weight.new_zeros( + (out_features,), dtype=torch.bool + ).view(len(enable_lora), -1) self.lora_ind[enable_lora, :] = True self.lora_ind = self.lora_ind.view(-1) self.reset_parameters() @@ -627,7 +705,9 @@ def reset_parameters(self): def zero_pad(self, x): result = x.new_zeros((*x.shape[:-1], self.out_features)) result = result.view(-1, self.out_features) - result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)) + result[:, self.lora_ind] = x.reshape( + -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) + ) return result.view((*x.shape[:-1], self.out_features)) def train(self, mode: bool = True): @@ -646,7 +726,9 @@ def train(self, mode: bool = True): .squeeze(0) .transpose(-2, -1) ) - self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) + self.weight.data += transpose( + self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out + ) self.merged = True elif self.merge_weights and self.merged: # Make sure that the weights are not merged @@ -660,7 +742,9 @@ def train(self, mode: bool = True): .squeeze(0) .transpose(-2, -1) ) - self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) + self.weight.data -= transpose( + self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out + ) self.merged = False def eval(self): @@ -680,13 +764,21 @@ def forward(self, x: torch.Tensor): .squeeze(0) .transpose(-2, -1) ) - self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) + self.weight.data -= transpose( + self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out + ) self.merged = False - return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + return F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) elif self.merged: - return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + return F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) else: - result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) if self.r > 0: after_A = self.lora_A(self.lora_dropout(x)) after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) @@ -713,11 +805,19 @@ def __init__( out_features, bias=kwargs.get("bias", True), has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get("memory_efficient_backward", False), + memory_efficient_backward=kwargs.get( + "memory_efficient_backward", False + ), threshold=kwargs.get("threshold", 0.0), index=kwargs.get("index", None), ) - LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + 
LoraLayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=False, + ) # Actual trainable parameters if r > 0: self.lora_A = nn.Linear(in_features, r, bias=False) @@ -744,10 +844,17 @@ def forward(self, x: torch.Tensor): if x.dtype != torch.float32: x = x.float() - output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling + output = ( + self.lora_B(self.lora_A(self.lora_dropout(x))).to( + expected_dtype + ) + * self.scaling + ) result += output else: - output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + output = ( + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + ) result += output return result @@ -769,11 +876,19 @@ def __init__( out_features, bias=kwargs.get("bias", True), has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get("memory_efficient_backward", False), + memory_efficient_backward=kwargs.get( + "memory_efficient_backward", False + ), threshold=kwargs.get("threshold", 0.0), index=kwargs.get("index", None), ) - LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + LoraLayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=False, + ) if out_features % len(enable_lora) != 0: raise ValueError("The length of enable_lora must divide out_features") self.enable_lora = enable_lora @@ -791,7 +906,9 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.requires_grad = False # Compute the indices - self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind = self.weight.new_zeros( + (out_features,), dtype=torch.bool + ).view(len(enable_lora), -1) self.lora_ind[enable_lora, :] = True self.lora_ind = self.lora_ind.view(-1) self.reset_parameters() @@ -830,7 +947,9 @@ def forward(self, x: torch.Tensor): result += output return result + if is_gptq_available(): + class LinearqbitLt(QuantLinear, LoraLayer): # Lora implemented in a dense layer def __init__( @@ -842,17 +961,22 @@ def __init__( lora_dropout: float = 0.0, **kwargs, ): - QuantLinear.__init__( self, - kwargs.get('bits', 4), - kwargs.get('groupsize', 128), + kwargs.get("bits", 4), + kwargs.get("groupsize", 128), in_features, out_features, - kwargs.get('bias', False), + kwargs.get("bias", False), ) - LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + LoraLayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=False, + ) # Actual trainable parameters if r > 0: self.lora_A = nn.Linear(in_features, r, bias=False) @@ -868,16 +992,20 @@ def reset_parameters(self): if hasattr(self, "lora_A"): # initialize A the same way as the default for nn.Linear and B to zero # nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) - self.lora_A.weight = torch.nn.Parameter(torch.nn.init.kaiming_uniform(self.lora_A.weight, a=math.sqrt(5))) + self.lora_A.weight = torch.nn.Parameter( + torch.nn.init.kaiming_uniform(self.lora_A.weight, a=math.sqrt(5)) + ) nn.init.zeros_(self.lora_B.weight) def forward(self, x: torch.Tensor): # x = x.detach() custom_layer_output = super().forward(x) - + dtype = custom_layer_output.dtype x = x.float() - lora_output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(dtype) * self.scaling + lora_output = ( + self.lora_B(self.lora_A(self.lora_dropout(x))).to(dtype) * self.scaling + ) result = custom_layer_output + 
lora_output return result @@ -895,12 +1023,18 @@ def __init__( ): QuantLinear.__init__( self, - kwargs.get('bits', 4), - kwargs.get('groupsize', 128), + kwargs.get("bits", 4), + kwargs.get("groupsize", 128), in_features, out_features, ) - LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + LoraLayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=False, + ) if out_features % len(enable_lora) != 0: raise ValueError("The length of enable_lora must divide out_features") self.enable_lora = enable_lora @@ -918,7 +1052,9 @@ def __init__( # Freezing the pre-trained weight matrix self.qweight.requires_grad = False # Compute the indices - self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind = self.weight.new_zeros( + (out_features,), dtype=torch.bool + ).view(len(enable_lora), -1) self.lora_ind[enable_lora, :] = True self.lora_ind = self.lora_ind.view(-1) self.reset_parameters() @@ -938,7 +1074,7 @@ def zero_pad(self, x): return result.view((*x.shape[:-1], self.out_features)) def forward(self, x: torch.Tensor): - result = super().forward(x)#.detach() + result = super().forward(x) # .detach() if self.disable_adapters: return result elif self.r > 0: @@ -1019,3 +1155,81 @@ def forward(self, x): ) return model + + +def make_quant(module, names, bits, groupsize, name=""): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + "." + attr if name != "" else attr + if name1 in names: + delattr(module, attr) + setattr( + module, + attr, + QuantLinear( + bits, + groupsize, + tmp.in_features, + tmp.out_features, + tmp.bias is not None, + ), + ) + for name1, child in module.named_children(): + make_quant( + child, names, bits, groupsize, name + "." + name1 if name != "" else name1 + ) + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." 
+ name1 if name != "" else name1 + ) + ) + return res + + +def load_quant( + model, checkpoint, wbits, groupsize=128, warmup_autotune=True, model_seqlen=2048 +): + config = LlamaConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + model = model.eval() + layers = find_layers(model) + for name in ["lm_head"]: + if name in layers: + del layers[name] + make_quant(model, layers, wbits, groupsize) + + del layers + + print("Loading model ...") + if checkpoint.endswith(".safetensors"): + from safetensors.torch import load_file as safe_load + + model.load_state_dict(safe_load(checkpoint), strict=False) + else: + model.load_state_dict(torch.load(checkpoint), strict=False) + if warmup_autotune: + autotune_warmup(model) + model.seqlen = model_seqlen + print("Done.") + return model diff --git a/src/xturing/engines/quant_utils/cachedistillationoutputs.py b/src/xturing/engines/quant_utils/cachedistillationoutputs.py new file mode 100644 index 0000000..0cf54db --- /dev/null +++ b/src/xturing/engines/quant_utils/cachedistillationoutputs.py @@ -0,0 +1,80 @@ +import argparse +import os +import random + +import torch +from transformers import AutoModelForCausalLM, LlamaTokenizer + +from xturing.engines.quant_utils.qerdataloading import ( + create_random_trainloader, + create_random_valenc, + load_c4_datasets, +) + +HF_CACHE_DIR = ... ###ANONYMIZED### + + +def cache_distillation_outputs( + base_model, seqlen, n_samples, train_cache_dir, val_cache_dir +): + os.makedirs(train_cache_dir, exist_ok=True) + os.makedirs(val_cache_dir, exist_ok=True) + + tokenizer = LlamaTokenizer.from_pretrained(base_model, use_fast=False) + fp_model = AutoModelForCausalLM.from_pretrained( + base_model, torch_dtype=torch.float16, cache_dir=HF_CACHE_DIR, device_map="auto" + ) + fp_model.eval() + + traindata, valdata = load_c4_datasets() + trainloader = create_random_trainloader(traindata, tokenizer, seqlen, n_samples) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + valenc = create_random_valenc(valdata, tokenizer, seqlen) + + for idx, (inp, tar) in enumerate(trainloader): + inp, tar = inp.squeeze(1).to("cuda"), tar.squeeze(1).to("cuda") + with torch.no_grad(): + target = fp_model(input_ids=inp, labels=tar).logits # .log_softmax(dim=-1) + torch.save(target, os.path.join(train_cache_dir, f"target_{idx}.pt")) + + for idx, (inp, tar) in enumerate(valenc): + inp, tar = inp.squeeze(1).to("cuda"), tar.squeeze(1).to("cuda") + if idx < 8: + torch.save(inp, os.path.join(val_cache_dir, f"input_{idx}.pt")) + torch.save(tar, os.path.join(val_cache_dir, f"label_{idx}.pt")) + with torch.no_grad(): + target = fp_model(input_ids=inp, labels=tar).logits # .log_softmax(dim=-1) + torch.save(target, os.path.join(val_cache_dir, f"target_{idx}.pt")) + + +if __name__ == "__main__": + base_model = "decapoda-research/llama-7b-hf" + seqlen = 2048 + n_samples = 10000 + train_cache_dir = ... ###ANONYMIZED### + val_cache_dir = ... 
###ANONYMIZED### + seed = 1 + + parser = argparse.ArgumentParser() + parser.add_argument("--base_model", type=str, default=base_model) + parser.add_argument("--seqlen", type=int, default=seqlen) + parser.add_argument("--n_samples", type=int, default=n_samples) + parser.add_argument("--train_cache_dir", type=str, default=train_cache_dir) + parser.add_argument("--val_cache_dir", type=str, default=val_cache_dir) + parser.add_argument("--seed", type=int, default=seed) + args = parser.parse_args() + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + + cache_distillation_outputs( + args.base_model, + args.seqlen, + args.n_samples, + args.train_cache_dir, + args.val_cache_dir, + ) diff --git a/src/xturing/engines/quant_utils/lrec.py b/src/xturing/engines/quant_utils/lrec.py new file mode 100644 index 0000000..ff4c4d4 --- /dev/null +++ b/src/xturing/engines/quant_utils/lrec.py @@ -0,0 +1,435 @@ +import argparse +import os +import random +from concurrent.futures import ThreadPoolExecutor + +import torch +import torch.nn.functional as F +import wandb +from datasets import load_dataset +from torch.nn import KLDivLoss +from torch.utils.checkpoint import checkpoint +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModelForCausalLM, LlamaTokenizer + +from xturing.engines.lora_engine.lora import LoraConfig, LoraModel, load_quant +from xturing.engines.quant_utils.qerdataloading import get_c4 + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a model with given parameters.") + parser.add_argument( + "--base_model", + type=str, + default="decapoda-research/llama-7b-hf", + help="The base model.", + ) + parser.add_argument( + "--intq_checkpoint", + type=str, + default="llama7b-2bit-128g.pt", + help="The intq checkpoint.", + ) + parser.add_argument( + "--wbits", + type=int, + default=2, + help="The number of bits for weight quantization.", + ) + parser.add_argument("--groupsize", type=int, default=128, help="The group size.") + parser.add_argument( + "--lora_alpha", type=int, default=128, help="The Lora alpha value." + ) + parser.add_argument("--lora_r", type=int, default=32, help="The Lora r value.") + parser.add_argument( + "--lora_dropout", type=float, default=0.05, help="The Lora dropout rate." + ) + parser.add_argument( + "--lora_target_modules", + nargs="+", + default=[ + "q_proj", + "v_proj", + "k_proj", + "o_proj", + "up_proj", + "down_proj", + "gate_proj", + ], + help="List of target modules for Lora.", + ) + parser.add_argument( + "--n_samples", type=int, default=2048, help="The number of samples." + ) + parser.add_argument("--lr", type=float, default=3e-4, help="The learning rate.") + parser.add_argument("--batch_size", type=int, default=4, help="The batch size.") + parser.add_argument( + "--num_epochs", type=int, default=20, help="The number of epochs." + ) + parser.add_argument("--kl_weight", type=float, default=1.0, help="The KL weight.") + parser.add_argument("--ce_weight", type=float, default=200.0, help="The CE weight.") + parser.add_argument( + "--trainable_kl_weight", + action="store_true", + help="Whether to learn the KL weight.", + default=False, + ) + parser.add_argument( + "--trainable_ce_weight", + action="store_true", + help="Whether to learn the CE weight.", + default=False, + ) + parser.add_argument( + "--weight_decay", type=float, default=1e-5, help="The weight decay." 
+ ) + parser.add_argument( + "--save_freq", + type=int, + default=1, + help="The frequency (period) of saving checkpoints in epochs.", + ) + parser.add_argument( + "--intra_save_freq", + type=int, + default=200, + help="The period (in num_batches) of saving checkpoints within an epoch.", + ) + parser.add_argument("--seed", type=int, default=0, help="The random seed.") + parser.add_argument("--seqlen", type=int, default=2048, help="The sequence length.") + parser.add_argument( + "--cache", + action="store_true", + default=True, + help="Use cached distillation outputs.", + ) + parser.add_argument( + "--train_cache_dir", + type=str, + default="###ANONYMIZED###/train_cache/", + help="Training cache directory.", + ) + parser.add_argument( + "--val_cache_dir", + type=str, + default="###ANONYMIZED###/val_cache/", + help="Validation cache directory.", + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default="", + help="The directory for saving and loading checkpoints.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="", + help="The directory for saving the final model.", + ) + + return parser.parse_args() + + +def get_lora_model(model, config): + return LoraModel(config, model) + + +def prepare_models(args): + if not args.cache: + fp_model = AutoModelForCausalLM.from_pretrained( + args.base_model, torch_dtype=torch.float16 + ).to("cuda") + model = load_quant( + args.base_model, + args.intq_checkpoint, + args.wbits, + args.groupsize, + ).to("cuda") + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + target_modules=args.lora_target_modules, + lora_dropout=args.lora_dropout, + bias="none", + peft_type="CAUSAL_LM", + ) + model = get_lora_model(model, config) + # model = get_peft_model(model, config) + return model, fp_model if not args.cache else model + + +def reduce_loss(pointwise_loss, reduction="batchmean"): + if reduction == "batchmean": + if pointwise_loss.dtype == torch.float16: + # If the loss sum is larger than 65536, it will overflow during the mean computation. + # If that's the case, we'll compute the iterative mean across the batch dimension + # new average = old average * (n-1)/n + (new_loss)/n). + if torch.isinf(pointwise_loss.sum()): + loss = torch.tensor( + [0], dtype=torch.float16, device=pointwise_loss.device + ) + for i in range(pointwise_loss.size(0)): + sample_loss = pointwise_loss[i].sum() + # Divide first to avoid overflow. 
+ loss = (loss / (i + 1)) * i + sample_loss / (i + 1) + return loss + return pointwise_loss.sum() / pointwise_loss.size(0) + elif reduction == "none": + return pointwise_loss + else: + raise NotImplementedError(f"Unknown reduction {reduction}") + + +def train_model(args, model, fp_model, trainloader, valenc): + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + kl_lossfn = KLDivLoss(reduction="none", log_target=True) + base_model_name = ( + args.intq_checkpoint + + "-qer-r" + + str(args.lora_r) + + "-tm" + + str(args.lora_target_modules) + + "-ce" + + str(args.ce_weight) + + "-kl" + + str(args.kl_weight) + + "-lr" + + str(args.lr) + + "-bs" + + str(args.batch_size) + + "-wd" + + str(args.weight_decay) + + "-dstl" + + str(args.train_cache_dir.split("/")[-2]) + ) + wandb.init(project="qer", name=base_model_name, config=args) + if args.ckpt_dir: + os.makedirs(args.ckpt_dir, exist_ok=True) + ckpt_path = os.path.join(args.ckpt_dir, f"{base_model_name}.pt") + if os.path.exists(ckpt_path): + ckpt = torch.load(ckpt_path) + m_class = model.__class__ + model = m_class.from_pretrained(model.base_model, ckpt_path).to("cuda") + print(f"Loaded checkpoint from {ckpt_path}") + optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + start_epoch = ckpt["epoch"] + 1 + else: + start_epoch = 0 + + if args.trainable_kl_weight: + kl_weight = torch.nn.Parameter(torch.tensor(args.kl_weight)) + optimizer.add_param_group({"params": kl_weight}) + else: + kl_weight = args.kl_weight + if args.trainable_ce_weight: + ce_weight = torch.nn.Parameter(torch.tensor(args.ce_weight)) + optimizer.add_param_group({"params": ce_weight}) + else: + ce_weight = args.ce_weight + + def get_cached_targets(batch_idx, batch_size, cache_dir): + with ThreadPoolExecutor() as executor: + targets = list( + executor.map( + lambda idx: torch.load(os.path.join(cache_dir, f"target_{idx}.pt")), + range(batch_idx * batch_size, (batch_idx + 1) * batch_size), + ) + ) + return torch.cat( + targets, dim=0 + ) # Concatenate targets along the batch dimension + + def custom_forward(inp, labels): + model_out = model(input_ids=inp, labels=labels) + return model_out.logits + + def eval_model(model, valenc, train_flag=False, max_samples=256): + val_pbar = tqdm(enumerate(valenc), total=len(valenc)) + model.eval() + if not args.cache: + fp_model.eval() + with torch.no_grad(): + samples = 0 + total_loss = 0 + total_kl_loss = 0 + total_ce_loss = 0 + for batch_idx, (inp, labels) in val_pbar: + if samples > max_samples: + break + + inp, labels = inp.squeeze(1).to("cuda"), labels.squeeze(1).to("cuda") + + model_out = model(input_ids=inp, labels=labels).logits + + model_out_logsoftmax = model_out.add(1e-7).log_softmax(dim=-1) + + if args.cache: + pre_target = get_cached_targets( + batch_idx, + args.batch_size, + args.val_cache_dir if not train_flag else args.train_cache_dir, + ) + target = pre_target.add(1e-7).log_softmax(dim=-1) + else: + target = fp_model(input_ids=inp, labels=labels).logits.log_softmax( + dim=-1 + ) + + kl_loss_pointwise = kl_lossfn(model_out_logsoftmax, target) + kl_loss = reduce_loss(kl_loss_pointwise, reduction="batchmean") + ce_target = inp[:, 1:].contiguous().view(-1) + ce_logits = ( + model_out[:, :-1, :].contiguous().view(-1, model_out.size(-1)) + ) + ce_loss = F.cross_entropy(ce_logits, ce_target) + total_loss += kl_weight * kl_loss + ce_weight * ce_loss + total_kl_loss += kl_loss.item() + total_ce_loss += ce_loss.item() + val_pbar.set_description( + f"KL loss: {kl_loss.item():.3f} | CE 
loss: {ce_loss.item()}" + ) + samples += args.batch_size + denominator = samples / args.batch_size + val_loss = total_loss / denominator + total_kl_loss /= denominator + total_ce_loss /= denominator + if train_flag: + print(f"Training loss: {val_loss}") + print(f"Training KL loss: {total_kl_loss}") + print(f"Training CE loss: {total_ce_loss}") + else: + print(f"Validation loss: {val_loss}") + print(f"Validation KL loss: {total_kl_loss}") + print(f"Validation CE loss: {total_ce_loss}") + return val_loss, total_kl_loss, total_ce_loss + + model.train() + + for epoch in range(start_epoch, args.num_epochs): + val_loss, val_kl_loss, val_ce_loss = eval_model(model, valenc) + train_loss, train_kl_loss, train_ce_loss = eval_model( + model, trainloader, train_flag=True + ) + wandb.log( + { + "Epoch": epoch, + "Train loss": train_loss, + "Validation loss": val_loss, + "Train KL loss": train_kl_loss, + "Validation KL loss": val_kl_loss, + "Train CE loss": train_ce_loss, + "Validation CE loss": val_ce_loss, + }, + step=epoch * len(trainloader), + ) + pbar = tqdm(enumerate(trainloader), total=len(trainloader)) + print(f"Epoch {epoch} validation loss: {val_loss}") + model.train() + fp_model.eval() if not args.cache else None + for batch_idx, (inp, labels) in pbar: + inp, labels = inp.squeeze(1).to("cuda"), labels.squeeze(1).to("cuda") + + model_out = checkpoint(custom_forward, inp, labels, use_reentrant=False) + model_out_logsoftmax = model_out.add(1e-7).log_softmax(dim=-1) + if args.cache: + target = ( + get_cached_targets(batch_idx, args.batch_size, args.train_cache_dir) + .add(1e-7) + .log_softmax(dim=-1) + ) + else: + with torch.no_grad(): + target = fp_model(input_ids=inp, labels=labels).logits.log_softmax( + dim=-1 + ) + + ce_target = inp[:, 1:].contiguous().view(-1) + ce_logits = model_out[:, :-1, :].contiguous().view(-1, model_out.size(-1)) + ce_loss = F.cross_entropy(ce_logits, ce_target) + kl_loss_pointwise = kl_lossfn(model_out_logsoftmax, target) + kl_loss = reduce_loss(kl_loss_pointwise, reduction="batchmean") + loss = kl_weight * kl_loss + ce_weight * ce_loss + # Check if loss is nan + loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0) + loss.backward() + # Double check if loss is nan + if torch.isnan(loss): + continue + optimizer.step() + description = f"Epoch {epoch} batch {batch_idx} ce loss: {ce_loss.item():.3f} | kl loss: {kl_loss.item():.2f} | total loss: {loss.item():.3f}" + if args.trainable_ce_weight: + description += f" | ce weight: {ce_weight}" + if args.trainable_kl_weight: + description += f" | kl weight: {kl_weight}" + pbar.set_description(description) + wandb.log( + { + "Epoch": epoch, + "Batch": batch_idx, + "CE loss": ce_loss.item(), + "KL loss": kl_loss.item(), + "Total loss": loss.item(), + }, + step=epoch * len(trainloader) + batch_idx, + ) + optimizer.zero_grad() + + if batch_idx % args.intra_save_freq == 0 and batch_idx != 0: + with torch.no_grad(): + ckpt = { + "optimizer_state_dict": optimizer.state_dict(), + "epoch": epoch, + } + torch.save( + ckpt, + f"{args.save_dir}/tmp/ckpts/" + + base_model_name + + f"-{epoch}-{batch_idx}.pt", + ) + model.save_pretrained( + f"{args.save_dir}/tmp/models/{base_model_name}-{epoch}-{batch_idx}" + ) + + if epoch % args.save_freq == 0: + ckpt = { + "optimizer_state_dict": optimizer.state_dict(), + "epoch": epoch, + } + torch.save(ckpt, f"{args.save_dir}/tmp/ckpts/" + base_model_name + ".pt") + model.save_pretrained( + f"{args.save_dir}/tmp/models/{base_model_name}-{epoch}" + ) + with 
open(f"{args.save_dir}/tmp/logs/{base_model_name}.txt", "a") as f: + f.write(f"Epoch {epoch - 1} val loss: {val_loss}\n") + f.write(f"Epoch {epoch - 1} train loss: {train_loss}\n") + + model.save_pretrained(base_model_name + ".pt") + + +def main(): + args = parse_args() + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + wandb.login() + if args.cache: + model = prepare_models(args)[0] + else: + model, fp_model = prepare_models(args) + trainloader, valenc = get_c4( + args.base_model, args.seqlen, args.n_samples, args.batch_size, args.seed + ) + if args.cache: + train_model(args, model, None, trainloader, valenc) # Pass None for fp_model + else: + train_model(args, model, fp_model, trainloader, valenc) + + +if __name__ == "__main__": + main() diff --git a/src/xturing/engines/quant_utils/peft_utils.py b/src/xturing/engines/quant_utils/peft_utils.py new file mode 100644 index 0000000..7ea20b6 --- /dev/null +++ b/src/xturing/engines/quant_utils/peft_utils.py @@ -0,0 +1,338 @@ +import enum +import inspect +import json +import os +from dataclasses import asdict, dataclass, field +from typing import List, Optional, Union + +import torch +from huggingface_hub import hf_hub_download +from transformers.utils import PushToHubMixin + +CONFIG_NAME = "adapter_config.json" + + +class PeftType(str, enum.Enum): + PROMPT_TUNING = "PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" + LORA = "LORA" + ADALORA = "ADALORA" + ADAPTION_PROMPT = "ADAPTION_PROMPT" + + +class TaskType(str, enum.Enum): + SEQ_CLS = "SEQ_CLS" + SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" + CAUSAL_LM = "CAUSAL_LM" + TOKEN_CLS = "TOKEN_CLS" + QUESTION_ANS = "QUESTION_ANS" + + +@dataclass +class PeftConfigMixin(PushToHubMixin): + r""" + This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all + PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to + push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a + directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. + + Args: + peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. + """ + peft_type: Optional[PeftType] = field( + default=None, metadata={"help": "The type of PEFT model."} + ) + + @property + def __dict__(self): + return asdict(self) + + def to_dict(self): + return self.__dict__ + + def save_pretrained(self, save_directory, **kwargs): + r""" + This method saves the configuration of your adapter model in a directory. + + Args: + save_directory (`str`): + The directory where the configuration will be saved. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`] + method. + """ + if os.path.isfile(save_directory): + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + + os.makedirs(save_directory, exist_ok=True) + + output_dict = self.__dict__ + output_path = os.path.join(save_directory, CONFIG_NAME) + + # save it + with open(output_path, "w") as writer: + writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs): + r""" + This method loads the configuration of your adapter model from a directory. 
+ + Args: + pretrained_model_name_or_path (`str`): + The directory or the Hub repository id where the configuration is saved. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the child class initialization. + """ + path = ( + os.path.join(pretrained_model_name_or_path, subfolder) + if subfolder is not None + else pretrained_model_name_or_path + ) + + hf_hub_download_kwargs, class_kwargs, other_kwargs = cls._split_kwargs(kwargs) + + if os.path.isfile(os.path.join(path, CONFIG_NAME)): + config_file = os.path.join(path, CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + pretrained_model_name_or_path, + CONFIG_NAME, + subfolder=subfolder, + **hf_hub_download_kwargs, + ) + except Exception: + raise ValueError( + f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'" + ) + + loaded_attributes = cls.from_json_file(config_file) + + config = cls(**class_kwargs) + + for key, value in loaded_attributes.items(): + if hasattr(config, key): + setattr(config, key, value) + + return config + + @classmethod + def from_json_file(cls, path_json_file, **kwargs): + r""" + Loads a configuration file from a json file. + + Args: + path_json_file (`str`): + The path to the json file. + """ + with open(path_json_file, "r") as file: + json_object = json.load(file) + + return json_object + + @classmethod + def _split_kwargs(cls, kwargs): + hf_hub_download_kwargs = {} + class_kwargs = {} + other_kwargs = {} + + for key, value in kwargs.items(): + if key in inspect.signature(hf_hub_download).parameters: + hf_hub_download_kwargs[key] = value + elif key in list(cls.__annotations__): + class_kwargs[key] = value + else: + other_kwargs[key] = value + + return hf_hub_download_kwargs, class_kwargs, other_kwargs + + @classmethod + def _get_peft_type( + cls, + model_id, + subfolder: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + ): + path = os.path.join(model_id, subfolder) if subfolder is not None else model_id + + if os.path.isfile(os.path.join(path, CONFIG_NAME)): + config_file = os.path.join(path, CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + model_id, + CONFIG_NAME, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + ) + except Exception: + raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'") + + loaded_attributes = cls.from_json_file(config_file) + return loaded_attributes["peft_type"] + + +@dataclass +class PeftConfig(PeftConfigMixin): + """ + This is the base configuration class to store the configuration of a [`PeftModel`]. + + Args: + peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. + task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. + inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. + """ + + base_model_name_or_path: str = field( + default=None, metadata={"help": "The name of the base model to use."} + ) + revision: str = field( + default=None, metadata={"help": "The specific model version to use."} + ) + peft_type: Union[str, PeftType] = field( + default=None, metadata={"help": "Peft type"} + ) + task_type: Union[str, TaskType] = field( + default=None, metadata={"help": "Task type"} + ) + inference_mode: bool = field( + default=False, metadata={"help": "Whether to use inference mode"} + ) + + +@dataclass +class LoraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`LoraModel`]. 
+ + Args: + r (`int`): Lora attention dimension. + target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. + lora_alpha (`int`): The alpha parameter for Lora scaling. + lora_dropout (`float`): The dropout probability for Lora layers. + fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). + For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.: + bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' + modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable + and saved in the final checkpoint. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the LoRA transformations on + the layer indexes that are specified in this list. If a single integer is passed, it will apply the LoRA + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + r: int = field(default=8, metadata={"help": "Lora attention dimension"}) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with Lora." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) + lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"}) + fan_in_fan_out: bool = field( + default=False, + metadata={ + "help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)" + }, + ) + bias: str = field( + default="none", + metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + init_lora_weights: bool = field( + default=True, + metadata={"help": "Whether to initialize the weights of the Lora layers."}, + ) + layers_to_transform: Optional[Union[List, int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.LORA + + +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True): + r""" + This method wraps the entire protocol for preparing a model before running a training. 
This includes: + 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm + head to fp32 + + Args: + model, (`transformers.PreTrainedModel`): + The loaded model from `transformers` + """ + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr( + model, "is_loaded_in_4bit", False + ) + + for name, param in model.named_parameters(): + # freeze base model's layers + param.requires_grad = False + + # cast all non INT8 parameters to fp32 + for param in model.parameters(): + if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): + param.data = param.data.to(torch.float32) + + if loaded_in_kbit and use_gradient_checkpointing: + # For backward compatibility + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # enable gradient checkpointing for memory efficiency + model.gradient_checkpointing_enable() + + return model + + +# def get_peft_model(model, peft_config, adapter_name="default") -> PeftModel: +# """ +# Returns a Peft model object from a model and a config. + +# Args: +# model ([`transformers.PreTrainedModel`]): Model to be wrapped. +# peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. +# """ +# model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config +# peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) +# if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( +# peft_config, PromptLearningConfig +# ): +# return PeftModel(model, peft_config, adapter_name=adapter_name) +# if isinstance(peft_config, PromptLearningConfig): +# peft_config = _prepare_prompt_learning_config(peft_config, model_config) +# return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) diff --git a/src/xturing/engines/quant_utils/qerdataloading.py b/src/xturing/engines/quant_utils/qerdataloading.py new file mode 100644 index 0000000..6839da1 --- /dev/null +++ b/src/xturing/engines/quant_utils/qerdataloading.py @@ -0,0 +1,88 @@ +import random + +import torch +from datasets import load_dataset +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import LlamaTokenizer + +# Set all seeds to 0 +torch.manual_seed(0) +torch.cuda.manual_seed(0) +random.seed(0) + + +def get_c4(base_model, seqlen, n_samples, batch_size, seed=0): + traindata, valdata = load_c4_datasets() + + tokenizer = LlamaTokenizer.from_pretrained(base_model, use_fast=False) + + trainloader = create_random_trainloader( + traindata, tokenizer, seqlen, n_samples, seed + ) + valenc = create_random_valenc(valdata, tokenizer, seqlen, seed) + + trainloader = DataLoader( + trainloader, batch_size=batch_size, shuffle=False, drop_last=True + ) + valenc = DataLoader(valenc, batch_size=batch_size, shuffle=False, drop_last=True) + + return trainloader, valenc + + +def load_c4_datasets(): + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + use_auth_token=False, + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + use_auth_token=False, + ) + return traindata, valdata + 
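The `get_c4` helper above builds matched train/validation DataLoaders over C4 shards for the `qer` distillation script earlier in this change. A minimal sketch of calling it directly, assuming a LLaMA-family checkpoint; the checkpoint id and sample counts below are illustrative, not values fixed by this diff:

```python
# Hypothetical smoke test for the new C4 calibration loaders (illustrative values).
# get_c4 tokenizes with LlamaTokenizer, so base_model must point at a LLaMA-family checkpoint.
from xturing.engines.quant_utils.qerdataloading import get_c4

trainloader, valenc = get_c4(
    base_model="decapoda-research/llama-7b-hf",  # assumed checkpoint id
    seqlen=2048,
    n_samples=128,
    batch_size=1,
    seed=0,
)

inp, labels = next(iter(trainloader))
# Each sample is stored as a (1, seqlen) tensor, so the batched shape is (batch_size, 1, seqlen);
# the training loop squeezes dim 1 before the forward pass.
print(inp.shape, labels.shape)
```

Note that `create_random_valenc` (below) always collects 256 validation sequences, so the validation loader size is fixed regardless of `n_samples`.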
+ +def create_random_trainloader(traindata, tokenizer, seqlen, n_samples, seed=0): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + trainloader = [] + pbar = tqdm(total=n_samples, desc="Creating trainloader") + while len(trainloader) < n_samples: + i = random.randint(0, len(traindata) - 1) + text = traindata[i]["text"] + enc = tokenizer(text, return_tensors="pt") + if enc.input_ids.shape[1] >= seqlen + 1: + start = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) + end = start + seqlen + inp = enc.input_ids[:, start:end] + tar = torch.cat([torch.tensor([-100]), inp.squeeze()[:-1]]).unsqueeze(0) + trainloader.append((inp, tar)) + pbar.update(1) + return trainloader + + +def create_random_valenc(valdata, tokenizer, seqlen, seed=0): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + valenc = [] + pbar = tqdm(total=256, desc="Creating valenc") + while len(valenc) < 256: + i = random.randint(0, len(valdata) - 1) + text = valdata[i]["text"] + enc = tokenizer(text, return_tensors="pt") + if enc.input_ids.shape[1] >= seqlen + 1: + start = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) + end = start + seqlen + inp = enc.input_ids[:, start:end] + tar = torch.cat([torch.tensor([-100]), inp.squeeze()[:-1]]).unsqueeze(0) + valenc.append((inp, tar)) + pbar.update(1) + return valenc diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py index 69299ac..cdc4f39 100644 --- a/src/xturing/models/__init__.py +++ b/src/xturing/models/__init__.py @@ -2,57 +2,60 @@ from .bloom import Bloom, BloomInt8, BloomLora, BloomLoraInt8 from .cerebras import Cerebras, CerebrasInt8, CerebrasLora, CerebrasLoraInt8 from .distilgpt2 import DistilGPT2, DistilGPT2Lora -from .falcon import Falcon, FalconInt8, FalconLora, FalconLoraInt8 +from .falcon import Falcon, FalconInt8, FalconLora, FalconLoraInt8, FalconLoraKbit from .galactica import Galactica, GalacticaInt8, GalacticaLora, GalacticaLoraInt8 from .generic import ( GenericInt8Model, GenericLoraInt8Model, + GenericLoraKbitModel, GenericLoraModel, GenericModel, ) from .gpt2 import GPT2, GPT2Int8, GPT2Lora, GPT2LoraInt8 from .gptj import GPTJ, GPTJInt8, GPTJLora, GPTJLoraInt8 -from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt4, LlamaLoraInt8 +from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt8, LlamaLoraKbit from .opt import OPT, OPTInt8, OPTLora, OPTLoraInt8 from .stable_diffusion import StableDiffusion +BaseModel.add_to_registry(Bloom.config_name, Bloom) +BaseModel.add_to_registry(BloomInt8.config_name, BloomInt8) +BaseModel.add_to_registry(BloomLora.config_name, BloomLora) +BaseModel.add_to_registry(BloomLoraInt8.config_name, BloomLoraInt8) +BaseModel.add_to_registry(Cerebras.config_name, Cerebras) +BaseModel.add_to_registry(CerebrasInt8.config_name, CerebrasInt8) +BaseModel.add_to_registry(CerebrasLora.config_name, CerebrasLora) +BaseModel.add_to_registry(CerebrasLoraInt8.config_name, CerebrasLoraInt8) BaseModel.add_to_registry(DistilGPT2.config_name, DistilGPT2) BaseModel.add_to_registry(DistilGPT2Lora.config_name, DistilGPT2Lora) -BaseModel.add_to_registry(GPT2.config_name, GPT2) -BaseModel.add_to_registry(GPT2Lora.config_name, GPT2Lora) -BaseModel.add_to_registry(GPT2Int8.config_name, GPT2Int8) -BaseModel.add_to_registry(GPT2LoraInt8.config_name, GPT2LoraInt8) +BaseModel.add_to_registry(Falcon.config_name, Falcon) +BaseModel.add_to_registry(FalconInt8.config_name, FalconInt8) +BaseModel.add_to_registry(FalconLora.config_name, FalconLora) 
+BaseModel.add_to_registry(FalconLoraInt8.config_name, FalconLoraInt8)
+BaseModel.add_to_registry(FalconLoraKbit.config_name, FalconLoraKbit)
+BaseModel.add_to_registry(Galactica.config_name, Galactica)
+BaseModel.add_to_registry(GalacticaInt8.config_name, GalacticaInt8)
+BaseModel.add_to_registry(GalacticaLora.config_name, GalacticaLora)
+BaseModel.add_to_registry(GalacticaLoraInt8.config_name, GalacticaLoraInt8)
+BaseModel.add_to_registry(GenericModel.config_name, GenericModel)
+BaseModel.add_to_registry(GenericInt8Model.config_name, GenericInt8Model)
+BaseModel.add_to_registry(GenericLoraModel.config_name, GenericLoraModel)
+BaseModel.add_to_registry(GenericLoraInt8Model.config_name, GenericLoraInt8Model)
+BaseModel.add_to_registry(GenericLoraKbitModel.config_name, GenericLoraKbitModel)
 BaseModel.add_to_registry(GPTJ.config_name, GPTJ)
-BaseModel.add_to_registry(GPTJLora.config_name, GPTJLora)
 BaseModel.add_to_registry(GPTJInt8.config_name, GPTJInt8)
+BaseModel.add_to_registry(GPTJLora.config_name, GPTJLora)
 BaseModel.add_to_registry(GPTJLoraInt8.config_name, GPTJLoraInt8)
+BaseModel.add_to_registry(GPT2.config_name, GPT2)
+BaseModel.add_to_registry(GPT2Int8.config_name, GPT2Int8)
+BaseModel.add_to_registry(GPT2Lora.config_name, GPT2Lora)
+BaseModel.add_to_registry(GPT2LoraInt8.config_name, GPT2LoraInt8)
 BaseModel.add_to_registry(Llama.config_name, Llama)
-BaseModel.add_to_registry(LlamaLora.config_name, LlamaLora)
 BaseModel.add_to_registry(LlamaInt8.config_name, LlamaInt8)
+BaseModel.add_to_registry(LlamaLora.config_name, LlamaLora)
 BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8)
-BaseModel.add_to_registry(LlamaLoraInt4.config_name, LlamaLoraInt4)
-BaseModel.add_to_registry(Galactica.config_name, Galactica)
-BaseModel.add_to_registry(GalacticaLora.config_name, GalacticaLora)
-BaseModel.add_to_registry(GalacticaInt8.config_name, GalacticaInt8)
-BaseModel.add_to_registry(GalacticaLoraInt8.config_name, GalacticaLoraInt8)
+BaseModel.add_to_registry(LlamaLoraKbit.config_name, LlamaLoraKbit)
 BaseModel.add_to_registry(OPT.config_name, OPT)
-BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
+BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
 BaseModel.add_to_registry(OPTLoraInt8.config_name, OPTLoraInt8)
-BaseModel.add_to_registry(Cerebras.config_name, Cerebras)
-BaseModel.add_to_registry(CerebrasLora.config_name, CerebrasLora)
-BaseModel.add_to_registry(CerebrasInt8.config_name, CerebrasInt8)
-BaseModel.add_to_registry(CerebrasLoraInt8.config_name, CerebrasLoraInt8)
-BaseModel.add_to_registry(Bloom.config_name, Bloom)
-BaseModel.add_to_registry(BloomLora.config_name, BloomLora)
-BaseModel.add_to_registry(BloomInt8.config_name, BloomInt8)
-BaseModel.add_to_registry(BloomLoraInt8.config_name, BloomLoraInt8)
 BaseModel.add_to_registry(StableDiffusion.config_name, StableDiffusion)
-BaseModel.add_to_registry(GenericModel.config_name, GenericModel)
-BaseModel.add_to_registry(GenericLoraModel.config_name, GenericLoraModel)
-BaseModel.add_to_registry(GenericInt8Model.config_name, GenericInt8Model)
-BaseModel.add_to_registry(GenericLoraInt8Model.config_name, GenericLoraInt8Model)
-BaseModel.add_to_registry(Falcon.config_name, Falcon)
-BaseModel.add_to_registry(FalconLora.config_name, FalconLora)
-BaseModel.add_to_registry(FalconInt8.config_name, FalconInt8)
-BaseModel.add_to_registry(FalconLoraInt8.config_name, FalconLoraInt8)
diff --git a/src/xturing/models/base.py b/src/xturing/models/base.py
index e705cb8..698c298
100644 --- a/src/xturing/models/base.py +++ b/src/xturing/models/base.py @@ -53,7 +53,12 @@ def load_from_local(cls, weights_dir_path): cls.registry.get(model_name) is not None ), "The model_name {} is not valid".format(model_name) - model = cls.create(model_name, weights_path=weights_dir_path) + if "generic" in model_name: + model = cls.create( + model_name, model_name=model_name, weights_path=weights_dir_path + ) + else: + model = cls.create(model_name, weights_path=weights_dir_path) return model diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index 24d0c75..40457d2 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -31,11 +31,13 @@ def __init__( weights_path: Optional[str] = None, model_name: Optional[str] = None, target_modules: Optional[List[str]] = None, + **kwargs, ): arguments = dict( weights_path=weights_path, model_name=model_name, target_modules=target_modules, + **kwargs, ) self.engine = BaseEngine.create( @@ -211,9 +213,12 @@ def __init__( engine: str, weights_path: Optional[str] = None, model_name: Optional[str] = None, + **kwargs, ): assert_not_cpu_int8() - super().__init__(engine, weights_path=weights_path, model_name=model_name) + super().__init__( + engine, weights_path=weights_path, model_name=model_name, **kwargs + ) class CausalLoraModel(CausalModel): @@ -223,12 +228,14 @@ def __init__( weights_path: Optional[str] = None, model_name: Optional[str] = None, target_modules: Optional[List[str]] = None, + **kwargs, ): super().__init__( engine, weights_path=weights_path, model_name=model_name, target_modules=target_modules, + **kwargs, ) def _make_trainer( @@ -258,6 +265,7 @@ def __init__( weights_path: Optional[str] = None, model_name: Optional[str] = None, target_modules: Optional[List[str]] = None, + **kwargs, ): assert_not_cpu_int8() super().__init__( @@ -265,4 +273,11 @@ def __init__( weights_path=weights_path, model_name=model_name, target_modules=target_modules, + **kwargs, ) + + +class CausalLoraKbitModel(CausalLoraModel): + def __init__(self, engine: str, weights_path: Optional[str] = None): + assert_not_cpu_int8() + super().__init__(engine, weights_path) diff --git a/src/xturing/models/falcon.py b/src/xturing/models/falcon.py index 0f68508..a9e7c54 100644 --- a/src/xturing/models/falcon.py +++ b/src/xturing/models/falcon.py @@ -5,10 +5,12 @@ FalconInt8Engine, FalconLoraEngine, FalconLoraInt8Engine, + FalconLoraKbitEngine, ) from xturing.models.causal import ( CausalInt8Model, CausalLoraInt8Model, + CausalLoraKbitModel, CausalLoraModel, CausalModel, ) @@ -40,3 +42,10 @@ class FalconLoraInt8(CausalLoraInt8Model): def __init__(self, weights_path: Optional[str] = None): super().__init__(FalconLoraInt8Engine.config_name, weights_path) + + +class FalconLoraKbit(CausalLoraKbitModel): + config_name: str = "falcon_lora_kbit" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(FalconLoraKbitEngine.config_name, weights_path) diff --git a/src/xturing/models/generic.py b/src/xturing/models/generic.py index ea8400c..05b295a 100644 --- a/src/xturing/models/generic.py +++ b/src/xturing/models/generic.py @@ -1,14 +1,18 @@ -from typing import List, Optional +import json +from pathlib import Path +from typing import List, Optional, Union from xturing.engines.generic_engine import ( GenericEngine, GenericInt8Engine, GenericLoraEngine, GenericLoraInt8Engine, + GenericLoraKbitEngine, ) from xturing.models.causal import ( CausalInt8Model, CausalLoraInt8Model, + CausalLoraKbitModel, CausalLoraModel, 
CausalModel, ) @@ -17,8 +21,22 @@ class GenericModel(CausalModel): config_name: str = "generic" - def __init__(self, model_name: str, weights_path: Optional[str] = None): - super().__init__(GenericEngine.config_name, weights_path, model_name=model_name) + def __init__(self, model_name: str, weights_path: Optional[str] = None, **kwargs): + super().__init__( + GenericEngine.config_name, weights_path, model_name=model_name, **kwargs + ) + + def _save_config(self, path: Union[str, Path]): + xturing_config_path = Path(path) / "xturing.json" + xturing_config = { + "model_name": self.model_name, + "engine_name": self.engine.model_name, + "finetuning_config": self.finetuning_args.dict(), + "generation_config": self.generation_args.dict(), + } + + with open(str(xturing_config_path), "w", encoding="utf-8") as f: + json.dump(xturing_config, f, ensure_ascii=False, indent=4) class GenericLoraModel(CausalLoraModel): @@ -29,21 +47,23 @@ def __init__( model_name: str, target_modules: List[str] = ["c_attn"], weights_path: Optional[str] = None, + **kwargs, ): super().__init__( GenericLoraEngine.config_name, weights_path, model_name=model_name, target_modules=target_modules, + **kwargs, ) class GenericInt8Model(CausalInt8Model): config_name: str = "generic_int8" - def __init__(self, model_name: str, weights_path: Optional[str] = None): + def __init__(self, model_name: str, weights_path: Optional[str] = None, **kwargs): super().__init__( - GenericInt8Engine.config_name, weights_path, model_name=model_name + GenericInt8Engine.config_name, weights_path, model_name=model_name, **kwargs ) @@ -55,10 +75,29 @@ def __init__( model_name: str, target_modules: List[str] = ["c_attn"], weights_path: Optional[str] = None, + **kwargs, ): super().__init__( GenericLoraInt8Engine.config_name, weights_path, model_name=model_name, target_modules=target_modules, + **kwargs, + ) + + +class GenericLoraKbitModel(CausalLoraKbitModel): + config_name: str = "generic_lora_kbit" + + def __init__( + self, + model_name: str, + target_modules: List[str] = ["c_attn"], + weights_path: Optional[str] = None, + ): + super().__init__( + GenericLoraKbitEngine.config_name, + weights_path, + model_name=model_name, + target_modules=target_modules, ) diff --git a/src/xturing/models/llama.py b/src/xturing/models/llama.py index 85278b4..0624393 100644 --- a/src/xturing/models/llama.py +++ b/src/xturing/models/llama.py @@ -1,22 +1,24 @@ from typing import Iterable, List, Optional, Union + from pytorch_lightning.loggers import Logger +from xturing.datasets.instruction_dataset import InstructionDataset +from xturing.datasets.text_dataset import TextDataset from xturing.engines.llama_engine import ( LLamaEngine, LLamaInt8Engine, LlamaLoraEngine, LlamaLoraInt8Engine, - LlamaLoraInt4Engine, + LlamaLoraKbitEngine, ) from xturing.models.causal import ( CausalInt8Model, CausalLoraInt8Model, + CausalLoraKbitModel, CausalLoraModel, CausalModel, ) from xturing.trainers.base import BaseTrainer -from xturing.datasets.instruction_dataset import InstructionDataset -from xturing.datasets.text_dataset import TextDataset from xturing.trainers.lightning_trainer import LightningTrainer @@ -48,25 +50,28 @@ def __init__(self, weights_path: Optional[str] = None): super().__init__(LlamaLoraInt8Engine.config_name, weights_path) -class LlamaLoraInt4(CausalLoraInt8Model): - config_name: str = "llama_lora_int4" - - def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset], - logger: Union[Logger, Iterable[Logger], bool] = True): - return BaseTrainer.create( - 
LightningTrainer.config_name, - self.engine, - dataset, - self._make_collate_fn(dataset), - int(self.finetuning_args.num_train_epochs), - int(self.finetuning_args.batch_size), - float(self.finetuning_args.learning_rate), - self.finetuning_args.optimizer_name, - True, - True, - lora_type=32, - logger=logger, - ) +class LlamaLoraKbit(CausalLoraKbitModel): + config_name: str = "llama_lora_kbit" + + # def _make_trainer( + # self, + # dataset: Union[TextDataset, InstructionDataset], + # logger: Union[Logger, Iterable[Logger], bool] = True, + # ): + # return BaseTrainer.create( + # LightningTrainer.config_name, + # self.engine, + # dataset, + # self._make_collate_fn(dataset), + # int(self.finetuning_args.num_train_epochs), + # int(self.finetuning_args.batch_size), + # float(self.finetuning_args.learning_rate), + # self.finetuning_args.optimizer_name, + # True, + # True, + # lora_type=32, + # logger=logger, + # ) def __init__(self, weights_path: Optional[str] = None): - super().__init__(LlamaLoraInt4Engine.config_name, weights_path) + super().__init__(LlamaLoraKbitEngine.config_name, weights_path)
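With the registrations and model classes above in place, the new k-bit LoRA variants are reachable through the same `BaseModel` registry as the existing models. A minimal usage sketch follows; the dataset path and prompt are placeholders, and the `finetune`/`generate` calls assume the usual xturing workflow rather than anything introduced in this diff:

```python
# Hypothetical end-to-end sketch for the *_lora_kbit variants registered above.
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel, GenericLoraKbitModel

# Registry route: "llama_lora_kbit" resolves to LlamaLoraKbit and picks up the
# llama_lora_kbit entries in finetuning_config.yaml and generation_config.yaml.
model = BaseModel.create("llama_lora_kbit")

# Generic route (alternative): wrap an arbitrary checkpoint; the default
# target_modules=["c_attn"] suits GPT-2-style attention blocks.
# model = GenericLoraKbitModel("distilgpt2")

dataset = InstructionDataset("./alpaca_data")  # placeholder path
model.finetune(dataset=dataset)
print(model.generate(texts=["Explain LoRA fine-tuning in one sentence."]))
```

The `load_from_local` change in `models/base.py` above is what allows a saved generic model to be re-created with its original `model_name` when it is loaded back from disk.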