From d47996de8cbc32df4d2830b81d2910adb704aff2 Mon Sep 17 00:00:00 2001 From: Alain Le Noac'h <47578089+glerzing@users.noreply.github.com> Date: Fri, 23 Jun 2023 23:55:36 +0200 Subject: [PATCH] peft to opendelta migration (#434) + memory optimization (#320) (#486) * Migrate to peft from opendelta for parameter efficient tuning methods (#434) + Collapse reference+learner hydra heads when using LoRa (#320) * fix from_config * Review corrections * ILQL generate when temperature is 0. * revert: guard against experimental 8-bit loading support * format: run `black` --------- Co-authored-by: jon-tow Co-authored-by: maxreciprocate <56548574+maxreciprocate@users.noreply.github.com> --- examples/ppo_sentiments_peft.py | 67 ++++ requirements.txt | 1 + tests/test_peft.py | 489 ++++++++++++++++++++++++ tests/test_trainers.py | 24 ++ tests/test_utils.py | 37 -- trlx/data/configs.py | 25 +- trlx/models/modeling_base.py | 156 +++++++- trlx/models/modeling_ilql.py | 159 ++++++-- trlx/models/modeling_ppo.py | 140 +++++-- trlx/trainer/accelerate_base_trainer.py | 57 ++- trlx/trainer/accelerate_ilql_trainer.py | 1 + trlx/trainer/accelerate_ppo_trainer.py | 11 +- trlx/trainer/accelerate_sft_trainer.py | 24 +- trlx/utils/__init__.py | 5 + trlx/utils/modeling.py | 254 +----------- 15 files changed, 1046 insertions(+), 404 deletions(-) create mode 100644 examples/ppo_sentiments_peft.py create mode 100644 tests/test_peft.py diff --git a/examples/ppo_sentiments_peft.py b/examples/ppo_sentiments_peft.py new file mode 100644 index 000000000..1409a02c7 --- /dev/null +++ b/examples/ppo_sentiments_peft.py @@ -0,0 +1,67 @@ +# Generates positive movie reviews by tuning a pretrained model on IMDB dataset +# with a sentiment reward function +import json +import os +import sys +from typing import List + +import torch +from datasets import load_dataset +from peft import LoraConfig +from peft.utils.config import TaskType +from transformers import pipeline + +import trlx +from trlx.data.default_configs import TRLConfig, default_ppo_config + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + + +def main(hparams={}): + # Merge sweep config with default config if given + config = TRLConfig.update(default_ppo_config().to_dict(), hparams) + + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 + + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device=device, + ) + + # Just insert your peft config here (the type must be an instance of peft.PeftConfig or a dict). 
+ config.model.peft_config = LoraConfig( + r=8, + task_type=TaskType.CAUSAL_LM, + lora_alpha=32, + lora_dropout=0.1, + ) + + def reward_fn(samples: List[str], **kwargs) -> List[float]: + sentiments = list(map(get_positive_score, sentiment_fn(samples))) + return sentiments + + # Take few words off of movies reviews as prompts + imdb = load_dataset("imdb", split="train+test") + prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] + + trlx.train( + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=["I don't know much about Hungarian underground"] * 256, + config=config, + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/requirements.txt b/requirements.txt index 7b220c4d9..9770d2033 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,6 +43,7 @@ numpy==1.24.3 packaging==23.1 pandas==2.0.1 pathtools==0.1.2 +peft==0.3.0 pkgutil_resolve_name==1.3.10 platformdirs==3.5.0 protobuf==4.22.3 diff --git a/tests/test_peft.py b/tests/test_peft.py new file mode 100644 index 000000000..ffe6f8bcb --- /dev/null +++ b/tests/test_peft.py @@ -0,0 +1,489 @@ +import copy +import gc +import os +import sys +import tempfile +import unittest +from typing import Optional + +import numpy as np +import torch +import transformers +from peft import get_peft_config, get_peft_model +from peft.utils.config import PeftType, TaskType +from transformers import AutoConfig, AutoModelForCausalLM + +from trlx.data.configs import TokenizerConfig +from trlx.data.default_configs import ( + ModelConfig, + default_ilql_config, + default_ppo_config, + default_sft_config, +) +from trlx.models.modeling_ilql import ( + AutoModelForCausalLMWithILQLHeads, + AutoModelForSeq2SeqLMWithILQLHeads, +) +from trlx.models.modeling_ppo import ( + AutoModelForCausalLMWithHydraValueHead, + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithHydraValueHead, +) +from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer +from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer +from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer + +PPO = "ppo" +ILQL = "ilql" +SFT = "sft" +TRAINING_TYPES = [PPO, ILQL, SFT] + +CAUSAL = "causal" +SEQ2SEQ = "seq2seq" + +MODEL_TASK_TYPE = { + "gpt2": CAUSAL, + "google/t5-efficient-tiny": SEQ2SEQ, + # "EleutherAI/pythia-160m": CAUSAL, + # "facebook/opt-125m": CAUSAL, +} +MODELS_TO_TEST = list(MODEL_TASK_TYPE.keys()) + +PEFT_CONFIGS_TO_TEST = [PeftType.LORA, PeftType.PROMPT_TUNING, PeftType.PREFIX_TUNING] + +ALL_TEST_COMBINATIONS = [ + [training_type, model_path, peft_type] + for training_type in TRAINING_TYPES + for model_path in MODELS_TO_TEST + for peft_type in PEFT_CONFIGS_TO_TEST + if [training_type, MODEL_TASK_TYPE[model_path]] != [SFT, SEQ2SEQ] # Seq2Seq SFT not implemented + and (MODEL_TASK_TYPE[model_path] != SEQ2SEQ or peft_type == PeftType.LORA) + # Skip some tests due to implementation problems of peft 0.3.0 with Seq2Seq +] + + +class TestPeft(unittest.TestCase): + def setUp(self): + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + + def tearDown(self): + gc.collect() # Try to free up memory + + def _create_model( + self, + training_type: str, + model_path: str, + task_type: str, + peft_type: Optional[str], + create_trainer: bool = False, + ): + self.peft_config = self._get_peft_config(peft_type, task_type) if peft_type else None + if create_trainer: + self.trainer = self._get_trainer(training_type, model_path, task_type, 
self.peft_config) + self.model = self.trainer.model.to("cpu") + else: + # Should be a bit faster to execute than creating a trainer. + if training_type == SFT: + self.model = AutoModelForCausalLM.from_pretrained(model_path) + if self.peft_config: + self.model = get_peft_model(self.model, self.peft_config) + else: + self.model = self._get_auto_model_type(training_type, task_type).from_pretrained( + model_path, + peft_config=self.peft_config, + ) + + self._create_inputs(model_path, task_type) + + def _create_inputs(self, tokenizer_path, task_type): + self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path) + + if task_type == CAUSAL: + self.inputs = self.tokenizer( + "Once upon a time there was a happy goose named Louis. He liked to eat bananas and", + return_tensors="pt", + ) + elif task_type == SEQ2SEQ: + self.encoder_text = "Translate this text to French: Hello, my dog is cute" + self.decoder_text = "Bonjour, mon chien est mignon" + encoder_inputs = self.tokenizer(self.encoder_text, return_tensors="pt") + decoder_inputs = self.tokenizer(self.decoder_text, return_tensors="pt") + self.inputs = { + **encoder_inputs, + "decoder_input_ids": decoder_inputs.input_ids, + "decoder_attention_mask": decoder_inputs.attention_mask, + } + else: + # Classification tasks not implemented + raise NotImplementedError + + def _get_trainer(self, training_type, model_path: str, task_type: str, peft_config, tokenizer_path: str = None): + if training_type == PPO: + config = default_ppo_config() + trainer_type = AcceleratePPOTrainer + elif training_type == ILQL: + config = default_ilql_config() + trainer_type = AccelerateILQLTrainer + elif training_type == SFT: + config = default_sft_config() + trainer_type = AccelerateSFTTrainer + else: + raise ValueError(f"Training type {training_type} not recognized.") + + config.tokenizer = TokenizerConfig(tokenizer_path=tokenizer_path if tokenizer_path else model_path) + config.model = ModelConfig(model_path=model_path, peft_config=peft_config, model_arch_type=task_type) + config.train.tracker = None + + return trainer_type(config) + + def _get_auto_model_type(self, training_type, task_type): + if training_type == PPO: + if task_type == CAUSAL: + return AutoModelForCausalLMWithHydraValueHead + elif task_type == SEQ2SEQ: + return AutoModelForSeq2SeqLMWithHydraValueHead + elif training_type == ILQL: + if task_type == CAUSAL: + return AutoModelForCausalLMWithILQLHeads + elif task_type == SEQ2SEQ: + return AutoModelForSeq2SeqLMWithILQLHeads + elif training_type == SFT and task_type == CAUSAL: + return AutoModelForCausalLM + + raise ValueError(f"Training type {training_type} for the task {task_type} not recognized.") + + def _get_peft_config(self, peft_type: str, task_type: str): + assert task_type in [CAUSAL, SEQ2SEQ] + task_type = TaskType.CAUSAL_LM if task_type == "causal" else TaskType.SEQ_2_SEQ_LM + + if peft_type == PeftType.LORA: + return get_peft_config( + { + "peft_type": peft_type, + "task_type": task_type, + "r": 8, + "lora_alpha": 32, + "lora_dropout": 0.0, + } + ) + elif peft_type == PeftType.PREFIX_TUNING: + return get_peft_config( + { + "peft_type": peft_type, + "task_type": task_type, + "num_virtual_tokens": 10, + } + ) + elif peft_type == PeftType.PROMPT_TUNING: + return get_peft_config( + { + "peft_type": peft_type, + "task_type": task_type, + "prompt_tuning_init": "RANDOM", + "num_virtual_tokens": 10, + } + ) + else: + raise NotImplementedError + + def _backprop(self, model): + output = model(**self.inputs, return_dict=True) + # Just apply an 
arbitrary loss to cause whatever change in the model's parameters. + # This loss doesn't make sense, but it causes a gradient, so it's fine. + loss = torch.nn.functional.binary_cross_entropy_with_logits( + output.logits[0][-1][:1], + torch.tensor([0.53]), + ) + + if hasattr(output, "value"): + loss += torch.nn.functional.binary_cross_entropy_with_logits( + output.value.squeeze()[-1:], + torch.tensor([0.53]), + ) + + loss.backward() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer.step() + + return model + + def _check_that_models_are_equivalent(self, model1, model2, training_type, test_hydra=False): + self.assertTrue( + torch.equal(model1(**self.inputs, return_dict=True).logits, model2(**self.inputs, return_dict=True).logits) + ) + + state_dict1 = model1.state_dict() + state_dict2 = model2.state_dict() + self.assertEqual(state_dict1.keys(), state_dict2.keys()) + for name in state_dict1.keys(): + self.assertTrue(torch.equal(state_dict1[name], state_dict2[name])) + + if training_type != SFT: + self.assertTrue( + torch.equal( + model1(**self.inputs, return_dict=True).value, + model2(**self.inputs, return_dict=True).value, + ) + ) + + if training_type == PPO and test_hydra: + self.assertTrue( + torch.equal( + model1.forward_hydra(**self.inputs, return_dict=True).logits, + model2.forward_hydra(**self.inputs, return_dict=True).logits, + ) + ) + + def test_save_and_load(self): + for training_type in [PPO, ILQL]: + for model_path in MODELS_TO_TEST: + peft_type = PeftType.LORA + task_type = MODEL_TASK_TYPE[model_path] + self._create_model(training_type, model_path, task_type, peft_type) + self._backprop(self.model) + + with tempfile.TemporaryDirectory() as tmp_dir: + self.model.save_pretrained(tmp_dir) + + self.assertTrue(os.path.isfile(f"{tmp_dir}/adapter_model.bin")) + self.assertTrue(os.path.isfile(f"{tmp_dir}/adapter_config.json")) + self.assertTrue(os.path.isfile(f"{tmp_dir}/pytorch_model.bin")) + + # Check that it didn't save the whole model (which weights around 500MB) + # pytorch_model.bin should only contain the other trained parts like the value heads. + # ILQL heads are very big though (around 1.1GB for gpt2). + self.assertLess(os.path.getsize(f"{tmp_dir}/pytorch_model.bin"), 1.3e9 if ILQL else 1e7) + + auto_model_type = self._get_auto_model_type(training_type, task_type) + + loaded_model = auto_model_type.from_pretrained(tmp_dir) + self._check_that_models_are_equivalent(loaded_model, self.model, training_type, True) + + def test_from_config(self): + """Check that from_config will add a peft adapter if given the argument peft_config""" + for training_type in TRAINING_TYPES: + peft_config = self._get_peft_config(PeftType.LORA, CAUSAL) + gpt2_config = AutoConfig.from_pretrained("gpt2") + trainer = self._get_trainer(training_type, gpt2_config, CAUSAL, peft_config, tokenizer_path="gpt2") + state_dict = trainer.model.state_dict() + + self.assertTrue(any(["lora" in layer_name for layer_name in state_dict.keys()])) + + def test_save_and_load_without_peft(self): + """Similar to test_save_load, but with peft not installed. 
Should not raise any error.""" + with unittest.mock.patch.dict(sys.modules, {"peft": None}): + for training_type in [PPO, ILQL]: + for model_path in MODELS_TO_TEST: + task_type = MODEL_TASK_TYPE[model_path] + self._create_model(training_type, model_path, task_type, peft_type=None) + self._backprop(self.model) + + with tempfile.TemporaryDirectory() as tmp_dir: + self.model.save_pretrained(tmp_dir) + auto_model_type = self._get_auto_model_type(training_type, task_type) + + loaded_model = auto_model_type.from_pretrained(tmp_dir) + self._check_that_models_are_equivalent(loaded_model, self.model, training_type) + + def test_backpropagation_and_disabling(self): + for training_type, model_path, peft_type in ALL_TEST_COMBINATIONS: + task_type = MODEL_TASK_TYPE[model_path] + self._create_model(training_type, model_path, task_type, peft_type, create_trainer=True) + old_logits = self.model(**self.inputs, return_dict=True).logits + initial_model_state_dict = copy.deepcopy(self.model.state_dict()) + + self._backprop(self.model) + self._backprop(self.model) + new_logits = self.model(**self.inputs, return_dict=True).logits + new_model_state_dict = self.model.state_dict() + + # Check that the backpropagation affected the predictions + self.assertFalse(torch.equal(old_logits, new_logits)) + + # Check that only the peft adapter layers are modified by the backpropagation + self.assertEqual(initial_model_state_dict.keys(), new_model_state_dict.keys()) + for name in initial_model_state_dict.keys(): + parameters_equal = torch.equal(initial_model_state_dict[name], new_model_state_dict[name]) + if "lora" in name or "prompt" in name or "v_head" in name: + self.assertFalse(parameters_equal) + else: + self.assertTrue(parameters_equal) + + # Check Lora enabling and disabling + if "LORA" in peft_type: + # If disabling the Lora adapter restores the original logits, + # this shows that the backpropagation only affected the Lora adapter + self.lora_model = self.model.base_model if training_type != SFT else self.model + self.lora_model.disable_adapter_layers() + new_logits = self.model(**self.inputs, return_dict=True).logits + self.assertTrue(torch.equal(old_logits, new_logits)) + + # Re-enabling the Lora adapter should make the 2 models different again + self.lora_model.enable_adapter_layers() + new_logits = self.model(**self.inputs, return_dict=True).logits + self.assertFalse(torch.equal(old_logits, new_logits)) + + def test_forward_hydra(self): + """Test that PPO hydra heads work and give similar logits to the model without any fine-tuning.""" + for model_path in MODELS_TO_TEST: + for peft_type in PEFT_CONFIGS_TO_TEST: + task_type = MODEL_TASK_TYPE[model_path] + if task_type == SEQ2SEQ and peft_type != PeftType.LORA: + continue # TODO: pass some tests due to some bugs in peft 0.3.0 with Seq2Seq + + self._create_model(PPO, model_path, task_type, peft_type) + + logits_without_peft = self.model.base_model.base_model(**self.inputs, return_dict=True).logits + logits_before_backpropagation = self.model(**self.inputs, return_dict=True).logits + + self._backprop(self.model) + + # forward_hydra should return the same logits as the original model + new_logits_from_hydra = self.model.forward_hydra(**self.inputs, return_dict=True).logits + self.assertTrue(torch.equal(logits_without_peft, new_logits_from_hydra)) + + if "LORA" in peft_type: + # True because the Lora adapter initially does not modify the output + self.assertTrue(torch.equal(logits_before_backpropagation, new_logits_from_hydra)) + else: + # False because the initial 
prompt before backpropagation + # was used to calculate logits_before_backpropagation, but not for new_logits_from_hydra. + self.assertFalse(torch.equal(logits_before_backpropagation, new_logits_from_hydra)) + + def test_generate(self): + """ + Check that generate works, and that it's deterministic when the temperature is very low. + """ + temperature = 0.0 + + for training_type, model_path, peft_type in ALL_TEST_COMBINATIONS: + task_type = MODEL_TASK_TYPE[model_path] + self._create_model(training_type, model_path, task_type, peft_type) + self._backprop(self.model) + with torch.no_grad(): + output1 = self.model.generate( + **self.inputs, + temperature=temperature, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + output2 = self.model.generate( + **self.inputs, + temperature=temperature, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + self.assertTrue(torch.equal(output1, output2)) + + def test_peft_not_installed_error(self): + """If the argument peft_config is used but peft is not installed, expect a ModuleNotFoundError""" + with unittest.mock.patch.dict(sys.modules, {"peft": None}): + peft_config = {"peft_type": "LORA"} + + with self.assertRaises(ModuleNotFoundError): + self._get_trainer(PPO, "gpt2", CAUSAL, peft_config) + + with self.assertRaises(ModuleNotFoundError): + AutoModelForCausalLMWithHydraValueHead.from_pretrained("gpt2", peft_config=peft_config) + + def test_lora_modules_to_save(self): + """ + Test the special Lora config option 'modules_to_save'. + It allows also train some non-lora modules, and its implementation is a bit tricky. + """ + for training_type in [PPO, ILQL]: + trainable_layer_name = "base_model.model.transformer.h.3.mlp" + + peft_config = { + "peft_type": PeftType.LORA, + "task_type": CAUSAL, + "r": 8, + "lora_alpha": 32, + "lora_dropout": 0.0, + "modules_to_save": [trainable_layer_name], + } + + model = self._get_auto_model_type(training_type, CAUSAL).from_pretrained("gpt2", peft_config=peft_config) + initial_state_dict = copy.deepcopy(model.state_dict()) + self._create_inputs("gpt2", CAUSAL) + # initial_logits = model(**self.inputs, return_dict=True).logits + + self._backprop(model) + self._backprop(model) + new_state_dict = model.state_dict() + + self.assertEqual(initial_state_dict.keys(), new_state_dict.keys()) + for name in initial_state_dict.keys(): + parameters_equal = torch.equal(initial_state_dict[name], new_state_dict[name]) + if trainable_layer_name + ".modules_to_save" in name or "lora" in name or "v_head" in name: + self.assertFalse(parameters_equal) + else: + self.assertTrue(parameters_equal) + + # TODO: deactivated until the issue (https://github.com/huggingface/peft/issues/493) is fixed + # if training_type == PPO: + # forward_hydra_logits = model.forward_hydra(**self.inputs, return_dict=True).logits + # self.assertTrue(torch.equal(initial_logits, forward_hydra_logits)) + + trained_model_logits = model(**self.inputs, return_dict=True).logits + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + loaded_model = self._get_auto_model_type(training_type, CAUSAL).from_pretrained(tmp_dir) + loaded_model_logits = loaded_model(**self.inputs, return_dict=True).logits + self.assertTrue(torch.equal(trained_model_logits, loaded_model_logits)) + + # @unittest.skipUnless( + # importlib.util.find_spec("bitsandbytes") and torch.cuda.is_available(), + # "bitsandbytes and GPU needed to execute test_8bits", + # ) + @unittest.skip("`8-bit` model 
loading support is not yet fully implemented") + def test_8bits(self): + """Test the behaviour of from_pretrained with 8 bits models""" + from bitsandbytes.nn import Linear8bitLt + + # gpt2 uses Conv1D instead of Linear, so use pythia-160m instead. + model_id = "EleutherAI/pythia-160m" + + peft_config = { + "peft_type": PeftType.LORA, + "task_type": TaskType.CAUSAL_LM, + "lora_dropout": 0.0, + "lora_alpha": 32, + } + reference_model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_id, + peft_config=peft_config, + ) + initial_nb_trainable_params = sum(p.numel() for p in reference_model.parameters() if p.requires_grad) + + model_8bit = AutoModelForCausalLMWithValueHead.from_pretrained( + model_id, + peft_config=peft_config, + load_in_8bit=True, + peft_int8_kwargs={"use_gradient_checkpointing": True}, + device_map="auto", + ) + + new_nb_trainable_params = sum(p.numel() for p in model_8bit.parameters() if p.requires_grad) + self.assertEqual(new_nb_trainable_params, initial_nb_trainable_params) + + self.assertIsInstance(reference_model.base_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h, torch.nn.Linear) + self.assertIsInstance(model_8bit.base_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h, Linear8bitLt) + + base_model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + model_8bit = AutoModelForCausalLMWithValueHead.from_pretrained( + base_model, + peft_config=peft_config, + load_in_8bit=True, + peft_int8_kwargs={"use_gradient_checkpointing": False}, + device_map="auto", + ) + + new_nb_trainable_params = sum(p.numel() for p in model_8bit.parameters() if p.requires_grad) + self.assertEqual(new_nb_trainable_params, initial_nb_trainable_params) + + self.assertIsInstance(model_8bit.base_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h, Linear8bitLt) diff --git a/tests/test_trainers.py b/tests/test_trainers.py index 667c1c766..2e8810e3b 100644 --- a/tests/test_trainers.py +++ b/tests/test_trainers.py @@ -134,6 +134,30 @@ def test_save_checkpoint(self): self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}"))) self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint"))) + def test_save_lora_checkpoint(self): + with tempfile.TemporaryDirectory() as tmp_dir: + config = self.get_default_config() + config.train.checkpoint_dir = tmp_dir + config.model.peft_config = { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 32, + "lora_dropout": 0.0, + } + + trainer = self.get_trainer(config) + trainer.learn() + + total_steps = config.train.total_steps + interval = config.train.checkpoint_interval + for i in range(interval, total_steps + 1, interval): + checkpoint_dir = os.path.join(tmp_dir, f"checkpoint_{i}") + self.assertTrue(os.path.isdir(checkpoint_dir)) + if total_steps % interval != 0: + self.assertTrue(os.path.isdir(os.path.join(tmp_dir, f"checkpoint_{total_steps}"))) + self.assertTrue(os.path.isdir(os.path.join(tmp_dir, "best_checkpoint"))) + def test_accumulate_context(self): config = self.get_default_config() trainer = self.get_trainer(config) diff --git a/tests/test_utils.py b/tests/test_utils.py index f3c09c23b..16111375e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -92,43 +92,6 @@ def test_hf_attr_getters(model_name: str): assert False, "Failed to get config attribute with error: " + str(e) -@pytest.mark.parametrize( - "model_name", - [ - "EleutherAI/gpt-j-6B", - "EleutherAI/gpt-neox-20b", - "facebook/opt-1.3b", - "bigscience/bloom-560m", - "google/flan-t5-large", - ], 
-) -def test_parse_delta_kwargs(model_name): - config = transformers.AutoConfig.from_pretrained(model_name) - - modified_modules_dict = modeling_utils.MODIFIED_MODULES_DICT[config.model_type] - for default_modifier, default_modified_modules in modified_modules_dict.items(): - delta_type, delta_kwargs = modeling_utils.parse_delta_kwargs( - delta_kwargs={"delta_type": "lora", "modified_modules": default_modifier}, - config=config, - num_layers_unfrozen=4, - ) - # Ensure the parsed module regex patterns capture the default module names - for kwarg_mod, default_mod in zip(delta_kwargs["modified_modules"], default_modified_modules): - assert kwarg_mod.endswith( - default_mod - ), f"Parsed modified module `{kwarg_mod}` should contain the trlx default `{default_mod}`" - assert delta_type == "lora", "Delta type should be lora" - - # Ensure the defaults don't get used if the user specifies a list of `modified_modules` - delta_type, delta_kwargs = modeling_utils.parse_delta_kwargs( - delta_kwargs={"delta_type": "lora", "modified_modules": ["a", "b"]}, - config=config, - num_layers_unfrozen=2, - ) - for kwarg_mod in delta_kwargs["modified_modules"]: - assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed modified module should contain ['a', 'b']" - - class TestStatistics(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 8b2af9ccb..d2cb621e2 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -49,25 +49,22 @@ class ModelConfig: -1 means all layers are unfrozen. :type num_layers_unfrozen: int - :param delta_kwargs: Keyword arguments for instantiating OpenDelta models for delta-tuning. - Follow the `OpenDelta.AutoDeltaConfig` specification, e.g. for LoRA style tuning, set - the `delta_type` to `lora` and include the model specific hyper-parameters (e.g. `lora_r`) - {"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0} - or in YAML format: - delta_kwargs: - delta_type: lora - modified_modules: "all" - lora_r: 8 - lora_alpha: 16 - lora_dropout: 0.0 - See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig - :type delta_kwargs: Optional[Dict[str, Any]] + :param peft_config: configuration for peft (Parameter Efficient Fine-Tuning library). + Peft is designed to reduce the number of parameters to train and the memory footprint, + without significant performance loss. It supports multiple techniques such as LORA + or prefix tuning (cf. https://github.com/huggingface/peft). 
+ + Here is an example of LORA configuration: + {"peft_type": "LORA", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1} + + (parameter-efficient fine-tuning was previously done in trlx with OpenDelta, but it is no longer supported) + :type peft_config: Union[peft.PeftConfig, Dict[str, Any]] """ model_path: str model_arch_type: str = "causal" num_layers_unfrozen: int = -1 - delta_kwargs: Optional[Dict[str, Any]] = None + peft_config: Any = None @classmethod def from_dict(cls, config: Dict[str, Any]): diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py index 7fa8dfb3e..d0ddbcae5 100644 --- a/trlx/models/modeling_base.py +++ b/trlx/models/modeling_base.py @@ -26,6 +26,20 @@ import transformers from huggingface_hub import hf_hub_download +import trlx.utils.logging as logging +from trlx.utils import is_peft_available + +logger = logging.get_logger(__name__) + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + get_peft_config, + get_peft_model, + prepare_model_for_int8_training, + ) + class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin): """A wrapper around `transformers.PreTrainedModel` @@ -50,13 +64,21 @@ class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin): # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the # specific underlying type similar to how config instances can be used to instantiate # `transformers.PreTrainedModel`s. - _supported_args: List[str] = None + _supported_args: List[str] = [] - def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs): + def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, peft_config=None, **kwargs): super().__init__() self.base_model = base_model # cache `forward` args for general use (avoids incompatible args across architectures) self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args + self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False) + if self.is_loaded_in_8bit: + # TODO(glerzing): Fully test and support loading in 8-bit + raise NotImplementedError( + "`is_loaded_in_8bit` is an experimental feature not yet fully supported. Please do not use it." + ) + self.peft_config = peft_config + self.peft_type = peft_config.peft_type if peft_config else None @classmethod def _split_kwargs(cls, kwargs: Dict[str, Any]): @@ -73,12 +95,13 @@ def _split_kwargs(cls, kwargs: Dict[str, Any]): return supported_kwargs, unsupported_kwargs @classmethod - def from_config(cls, config: transformers.PretrainedConfig, **kwargs): + def from_config(cls, config: transformers.PretrainedConfig, peft_config=None, **kwargs): """Instantiate the pretrained pytorch model from a configuration. Args: config (transformers.PretrainedConfig): The configuration to use to instantiate the base model. + peft_config (peft.PeftConfig or dict, *optional*): Configuration for the peft adapter NOTE: Loading a model from its configuration file does **not** load the model weights. It only affects the model's configuration. 
Use @@ -90,6 +113,12 @@ def from_config(cls, config: transformers.PretrainedConfig, **kwargs): from_config_kwargs = {} wrapped_model_kwargs = {} base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs) + if peft_config: + if isinstance(peft_config, dict): + peft_config = get_peft_config(peft_config) + base_model = get_peft_model(base_model, peft_config) + wrapped_model_kwargs["peft_config"] = peft_config + model = cls(base_model, **wrapped_model_kwargs) return model @@ -98,6 +127,7 @@ def from_pretrained( # noqa: max-complexity cls, pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel], revision=None, + peft_config=None, *model_args, **kwargs, ): @@ -109,6 +139,17 @@ def from_pretrained( # noqa: max-complexity Args: pretrained_model_name_or_path (str or `transformers.PreTrainedModel`): The identifier of the pretrained model to load or the pretrained model itself. + revision (str, *optional*): Optional specific Git branch, tag or commit hash. + peft_config (peft.PeftConfig or dict, *optional*): The peft configuration to create a peft adapter. + This is *only useful when creating a new peft adapter, not when loading an already trained adapter*. + To load an already trained peft adapter, set `pretrained_model_name_or_path` to the directory containing + the trained adapter, which contains at least 2 files: a config file ("adapter_config.json" by default), + and a file containing the weights ("adapter_model.bin" by default). If there is a value head, + it will be loaded from this directory as well. + For additional argument to give to PeftModel.from_pretrained (such as adapter_name or subdir), + use the dict argument `peft_from_pretrained_kwargs`. There is also a dict argument + `peft_int8_kwargs` for specific options with 8-bit models. These arguments will be + retrieved from kwargs. *model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the `_auto_model_parent_class`. **kwargs (dict, *optional*): @@ -119,24 +160,114 @@ def from_pretrained( # noqa: max-complexity NOTE: You must pass in arguments specific to the wrapped model as keyword arguments. """ if kwargs is not None: + peft_from_pretrained_kwargs = kwargs.pop("peft_from_pretrained_kwargs", {}) + peft_int8_kwargs = kwargs.pop("peft_int8_kwargs", {}) wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs) else: + peft_from_pretrained_kwargs = {} + peft_int8_kwargs = {} from_pretrained_kwargs = {} wrapped_model_kwargs = {} if isinstance(pretrained_model_name_or_path, str): - # Load the base model using the `transformers` AutoClass (e.g. AutoModelForCausalLM) - base_model = cls._auto_model_parent_class.from_pretrained( - pretrained_model_name_or_path, *model_args, revision=revision, **from_pretrained_kwargs + is_loaded_in_8bit = ( + from_pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in from_pretrained_kwargs else False + ) + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + + if is_loaded_in_8bit: + # TODO(glerzing): Fully test and support loading in 8-bit + raise NotImplementedError( + "`is_loaded_in_8bit` is an experimental feature not yet fully supported. Please do not use it." 
) + + if peft_config is not None: + if not is_peft_available(): + raise ModuleNotFoundError("To use the argument peft_config, please install `peft`") + if not isinstance(peft_config, PeftConfig): + if isinstance(peft_config, dict): + peft_config = get_peft_config(peft_config) + else: + raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") + + if isinstance(pretrained_model_name_or_path, str): + # Check if there is a local peft adapter + local_peft_adapter = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) + + if local_peft_adapter and not is_peft_available(): + logger.warning("WARNING: peft adapter detected but peft is not installed. Ignoring the adapter...") + + base_model = None + if is_peft_available(): + try: + trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) + except ValueError: + trained_adapter_config = None + + if peft_config is not None: + if trained_adapter_config is not None: + logger.warning( + f"WARNING: peft config file detected in {pretrained_model_name_or_path}" + " but ignored since the argument `peft_config` is provided. Remove the" + " argument `peft_config` to use the trained peft adapter." + ) + + # Create a new peft adapter with the given config + base_model = cls._auto_model_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs + ) + + if is_loaded_in_8bit: + base_model = prepare_model_for_int8_training( + base_model, + **peft_int8_kwargs, + ) + base_model = get_peft_model(base_model, peft_config) + logger.info("peft adapter initialised") + + elif trained_adapter_config is not None: + peft_config = trained_adapter_config + + # Use the pretrained (local or remote) peft adapter file "adapter_config.json" + base_model = cls._auto_model_parent_class.from_pretrained( + trained_adapter_config.base_model_name_or_path, *model_args, **from_pretrained_kwargs + ) + + # Load the peft weights in "adapter_model.bin" and wrap the base model with a PeftModel + base_model = PeftModel.from_pretrained( + base_model, + pretrained_model_name_or_path, + **peft_from_pretrained_kwargs, + ) + logger.info("Trained peft adapter loaded") + + if base_model is None: + # No peft + base_model = cls._auto_model_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs + ) + elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel): base_model = pretrained_model_name_or_path + + if peft_config is not None: + if is_loaded_in_8bit: + base_model = prepare_model_for_int8_training( + base_model, + **peft_int8_kwargs, + ) + base_model = get_peft_model(base_model, peft_config) + logger.info("peft adapter initialised") else: raise ValueError( f"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}" "Expected `str` or `transformers.PreTrainedModel`." 
) + if peft_config is not None: + wrapped_model_kwargs["peft_config"] = peft_config + model = cls(base_model, **wrapped_model_kwargs) if isinstance(pretrained_model_name_or_path, str): @@ -203,6 +334,13 @@ def save_pretrained(self, *args, **kwargs): state_dict = self.state_dict() kwargs["state_dict"] = state_dict + if self.peft_type: + # Save the heads, which are not part of the peft adapter + save_path = os.path.join(args[0], "pytorch_model.bin") + head_state_dict = self.state_dict(heads_only=True) + + torch.save(head_state_dict, save_path) + return self.base_model.save_pretrained(*args, **kwargs) def state_dict(self, *args, **kwargs): @@ -214,7 +352,11 @@ def post_init(self, *args, **kwargs): instantiated and loaded from a checkpoint. It can be used to perform additional operations such as loading the state_dict. """ - raise NotImplementedError + if self.peft_type: + # Don't use the interface of the peft model, + # use the interface of the underlying transformer model instead. + # (peft adds 2 "base_model" layers) + self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.forward).args def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: """Filter out arguments not supported by the specific instance of diff --git a/trlx/models/modeling_ilql.py b/trlx/models/modeling_ilql.py index 07bea7e5a..4e3023218 100644 --- a/trlx/models/modeling_ilql.py +++ b/trlx/models/modeling_ilql.py @@ -1,9 +1,10 @@ import gc import os +from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass from functools import reduce -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import deepspeed # type: ignore import numpy as np @@ -12,6 +13,7 @@ import transformers from torch import nn from torchtyping import TensorType +from transformers.modeling_outputs import ModelOutput from trlx.data.ilql_types import ILQLBatch from trlx.data.method_configs import MethodConfig, register_method @@ -193,8 +195,18 @@ def sync_target_q_heads(self): self._sync_target_q_heads(self.alpha) +@dataclass +class CausalILQLOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + value: Optional[torch.FloatTensor] = None + qs: Optional[Tuple[torch.FloatTensor]] = None + target_qs: Optional[Tuple[torch.FloatTensor]] = None + + class AutoModelForCausalLMWithILQLHeads(PreTrainedModelWrapper): - """An `AutoModel` class wrapper for `transformers` causal models wtih a language + """An `AutoModel` class wrapper for `transformers` causal models with a language modeling head and ILQL heads. 
References: @@ -204,7 +216,7 @@ class AutoModelForCausalLMWithILQLHeads(PreTrainedModelWrapper): _auto_model_parent_class = transformers.AutoModelForCausalLM _supported_modules = ["ilql_heads"] - _supported_args = ["two_qs", "alpha"] + _supported_args = ["two_qs", "alpha", "peft_config"] def __init__( self, @@ -212,8 +224,9 @@ def __init__( *, two_qs: bool = True, alpha: float = 0.99, + peft_config=None, ): - super().__init__(base_model) + super().__init__(base_model, peft_config=peft_config) hidden_size = hf_get_hidden_size(self.base_model.config) vocab_size = self.base_model.config.vocab_size dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype @@ -229,6 +242,8 @@ def forward( past_key_values=None, actions_ixs=None, states_ixs=None, + return_dict=False, + bypass_peft_prompt_adapter=False, ): forward_kwargs = self.get_compatible_forward_kwargs( input_ids=input_ids, @@ -238,9 +253,19 @@ def forward( ) forward_kwargs["output_hidden_states"] = True - outputs = self.base_model(**forward_kwargs) + if self.peft_type == "PREFIX_TUNING" and not bypass_peft_prompt_adapter: + # Peft redefines past_key_values, remove it to avoid an exception. + forward_kwargs.pop("past_key_values", None) + + if bypass_peft_prompt_adapter: + outputs = self.base_model.base_model(**forward_kwargs) + else: + outputs = self.base_model(**forward_kwargs) qs, target_qs, vs = self.ilql_heads(outputs.hidden_states[-1], states_ixs=states_ixs, actions_ixs=actions_ixs) + if return_dict: + return CausalILQLOutput(outputs.logits, outputs.past_key_values, outputs.hidden_states, vs, qs, target_qs) + return outputs.logits, qs, target_qs, vs, outputs.past_key_values def generate( @@ -259,7 +284,7 @@ def generate( eos_token_id=None, ): """ - Generates samples akin to hf's `.generate` but with custom logp prepossessing: + Generates samples akin to hf's `.generate` but with custom logp preprocessing: changing token probabilities as to how advantageous they would be according to value functions estimations. """ @@ -277,12 +302,14 @@ def generate( max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device) - for _ in range(max_new_tokens): + bypass_peft = False + for token in range(max_new_tokens): out = self.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + bypass_peft_prompt_adapter=bypass_peft, ) logits, _, target_qs, vs, past_key_values = out @@ -301,9 +328,13 @@ def generate( adv = qs - vs pi_beta = F.log_softmax(logits, -1) pi_top_k = topk_mask(pi_beta + beta * adv, top_k) - pi = F.softmax(pi_top_k / temperature, -1) - input_ids = torch.multinomial(pi, num_samples=1) + if temperature == 0.0: + input_ids = pi_top_k.argmax(dim=-1, keepdim=True) + else: + pi = F.softmax(pi_top_k / temperature, -1) + input_ids = torch.multinomial(pi, num_samples=1) + input_ids = (1 - finished) * input_ids + finished * eos_token_id finished = (input_ids == eos_token_id).long() @@ -311,6 +342,16 @@ def generate( attention_mask = torch.hstack((attention_mask, (input_ids != eos_token_id).long())) position_ids = (position_ids[:, -1] + 1).view(-1, 1) + # Some peft models add a prefix to the prompt at each forward pass. + # We need to bypass it so that it doesn't add multiple times the prefix. 
+ if self.peft_type and token == 0 and "LORA" not in self.peft_type: + bypass_peft = True + + prefix_attention_mask = torch.ones(input_ids.shape[0], self.peft_config.num_virtual_tokens).to( + input_ids.device + ) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") != "3" and torch.all(finished): break @@ -319,23 +360,29 @@ def generate( def sync_target_q_heads(self): self.ilql_heads.sync_target_q_heads() - def state_dict(self, *args, **kwargs): + def state_dict(self, *args, heads_only=False, **kwargs): """ Returns the state dictionary of the model. We add the state dictionary of the ilql heads to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`. """ - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs) + if heads_only: + model_state_dict = OrderedDict() + else: + model_state_dict = self.base_model.state_dict(*args, **kwargs) + for k, v in ilql_heads_state_dict.items(): - base_model_state_dict[f"ilql_heads.{k}"] = v - return base_model_state_dict + model_state_dict[f"ilql_heads.{k}"] = v + return model_state_dict def post_init(self, state_dict): """ We add the state dictionary of the ilql heads to the state dictionary of the wrapped model - by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the + by prepending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the keys of the value head state dictionary. """ + super().post_init() + for k in list(state_dict.keys()): if "ilql_heads." in k: state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k) @@ -344,12 +391,23 @@ def post_init(self, state_dict): gc.collect() +@dataclass +class Seq2SeqILQLOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + value: Optional[torch.FloatTensor] = None + qs: Optional[Tuple[torch.FloatTensor]] = None + target_qs: Optional[Tuple[torch.FloatTensor]] = None + encoder_outputs: Optional[Tuple[Any]] = None + + class AutoModelForSeq2SeqLMWithILQLHeads(PreTrainedModelWrapper): """This is a wrapper around huggingface AutoModelForSeq2Seq with two additional scalar heads""" _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM _supported_modules = ["ilql_heads"] - _supported_args = ["two_qs", "alpha"] + _supported_args = ["two_qs", "alpha", "peft_config"] def __init__( self, @@ -357,8 +415,9 @@ def __init__( *, two_qs: bool = True, alpha: float = 0.99, + peft_config=None, ): - super().__init__(base_model) + super().__init__(base_model, peft_config=peft_config) hidden_size = hf_get_hidden_size(self.base_model.config) vocab_size = self.base_model.config.vocab_size dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype @@ -369,23 +428,29 @@ def __init__( def sync_target_q_heads(self): self.ilql_heads.sync_target_q_heads() - def state_dict(self, *args, **kwargs): + def state_dict(self, *args, heads_only=False, **kwargs): """ Returns the state dictionary of the model. We add the state dictionary of the ilql heads to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`. 
""" - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs) + if heads_only: + model_state_dict = OrderedDict() + else: + model_state_dict = self.base_model.state_dict(*args, **kwargs) + for k, v in ilql_heads_state_dict.items(): - base_model_state_dict[f"ilql_heads.{k}"] = v - return base_model_state_dict + model_state_dict[f"ilql_heads.{k}"] = v + return model_state_dict def post_init(self, state_dict): """ We add the state dictionary of the ilql heads to the state dictionary of the wrapped model - by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the + by prepending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the keys of the value head state dictionary. """ + super().post_init() + for k in list(state_dict.keys()): if "ilql_heads." in k: state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k) @@ -397,36 +462,56 @@ def forward( self, input_ids, attention_mask=None, + decoder_attention_mask=None, decoder_input_ids=None, past_key_values=None, + decoder_inputs_embeds=None, encoder_outputs=None, actions_ixs=None, states_ixs=None, output_attentions=True, output_hidden_states=True, + return_dict=False, + bypass_peft_prompt_adapter=False, ): forward_kwargs = self.get_compatible_forward_kwargs( input_ids=input_ids, attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, encoder_outputs=encoder_outputs, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) - out = self.base_model(**forward_kwargs) + if self.peft_type == "PREFIX_TUNING" and not bypass_peft_prompt_adapter: + # Peft redefines past_key_values, remove it to avoid an exception. + forward_kwargs.pop("past_key_values", None) + + if bypass_peft_prompt_adapter: + out = self.base_model.base_model(**forward_kwargs) + else: + out = self.base_model(**forward_kwargs) hs = out.decoder_hidden_states[-1] logits = self.base_model.lm_head(hs) qs, target_qs, vs = self.ilql_heads(hs, states_ixs=states_ixs, actions_ixs=actions_ixs) encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions) + + if return_dict: + return Seq2SeqILQLOutput( + logits, out.past_key_values, out.decoder_hidden_states, vs, qs, target_qs, encoder_outputs + ) + return logits, qs, target_qs, vs, out.past_key_values, encoder_outputs def generate( self, input_ids, attention_mask=None, + decoder_attention_mask=None, decoder_input_ids=None, past_key_values=None, encoder_outputs=None, @@ -440,7 +525,7 @@ def generate( eos_token_id=None, ): """ - Generates samples akin to hf's `.generate` but with custom logp prepossessing: + Generates samples akin to hf's `.generate` but with custom logp preprocessing: changing token probabilities as to how advantageous they would be according to value functions estimations. 
""" @@ -449,7 +534,10 @@ def generate( raise ValueError("eos_token_id and pad_token_id must be provided") if attention_mask is None: - attention_mask = input_ids.not_equal(pad_token_id) + if decoder_attention_mask is not None: + attention_mask = decoder_attention_mask + else: + attention_mask = input_ids.not_equal(pad_token_id) samples = input_ids.clone() max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) @@ -457,13 +545,15 @@ def generate( decoder_input_ids = input_ids.new_zeros(input_ids.shape[0], 1) finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device) - for _ in range(max_new_tokens): + bypass_peft = False + for token in range(max_new_tokens): out = self.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids[:, -1].unsqueeze(-1), past_key_values=past_key_values, encoder_outputs=encoder_outputs, + bypass_peft_prompt_adapter=bypass_peft, ) logits, _, target_qs, vs, past_key_values, encoder_outputs = out if self.two_qs: @@ -476,12 +566,27 @@ def generate( adv = qs - vs pi_beta = F.log_softmax(logits, -1) pi_top_k = topk_mask(pi_beta + beta * adv, top_k) - pi = F.softmax(pi_top_k / temperature, -1) - next_tokens = torch.multinomial(pi, num_samples=1) + + if temperature == 0.0: + next_tokens = pi_top_k.argmax(dim=-1, keepdim=True) + else: + pi = F.softmax(pi_top_k / temperature, -1) + next_tokens = torch.multinomial(pi, num_samples=1) next_tokens = (1 - finished) * next_tokens + finished * eos_token_id finished = (next_tokens == eos_token_id).long() | (next_tokens == pad_token_id).long() decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) samples = decoder_input_ids + + # Some peft models add a prefix to the prompt at each forward pass. + # We need to bypass it so that it doesn't add multiple times the prefix. 
+ if self.peft_type and token == 0 and "LORA" not in self.peft_type: + bypass_peft = True + + prefix_attention_mask = torch.ones(input_ids.shape[0], self.peft_config.num_virtual_tokens).to( + input_ids.device + ) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") != "3" and torch.all(finished): break diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 37e20c1ed..49f56f61e 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -1,5 +1,6 @@ import gc import inspect +from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -257,13 +258,14 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): _auto_model_parent_class = transformers.AutoModelForCausalLM _supported_modules = ["v_head"] - _supported_args = [] + _supported_args = ["peft_config"] def __init__( self, base_model: transformers.PreTrainedModel, + peft_config=None, ): - super().__init__(base_model) + super().__init__(base_model, peft_config=peft_config) self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) def forward( @@ -278,6 +280,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + ignore_peft_adapter: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithValue]: forward_kwargs = self.get_compatible_forward_kwargs( input_ids=input_ids, @@ -294,7 +297,22 @@ def forward( forward_kwargs["output_hidden_states"] = True forward_kwargs["return_dict"] = True - outputs = self.base_model(**forward_kwargs) + if self.peft_type == "PREFIX_TUNING": + # In this case peft redefines past_key_values, remove it to avoid an exception. + forward_kwargs.pop("past_key_values", None) + + if self.peft_type and ignore_peft_adapter: + if "LORA" in self.peft_type: + # For LORA, temporarily disable the adapter + lora_model = self.base_model.base_model + lora_model.disable_adapter_layers() + outputs = self.base_model(**forward_kwargs) + lora_model.enable_adapter_layers() + else: + # For prompt or prefix adapters, just use the base model of PeftModel + outputs = self.base_model.base_model(**forward_kwargs) + else: + outputs = self.base_model(**forward_kwargs) value = self.v_head(outputs.hidden_states[-1]).squeeze(-1) if not return_dict: @@ -306,16 +324,21 @@ def forward( def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: return self.base_model.generate(*args, **kwargs) - def state_dict(self, *args, **kwargs): + def state_dict(self, *args, heads_only=False, **kwargs): """ Returns the state dictionary of the model. We add the state dictionary of the value head to the state dictionary of the wrapped model by prepending the key with `v_head.`. """ - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + if heads_only: + model_state_dict = OrderedDict() + else: + model_state_dict = self.base_model.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): - base_model_state_dict[f"v_head.{k}"] = v - return base_model_state_dict + model_state_dict[f"v_head.{k}"] = v + + return model_state_dict def post_init(self, state_dict): """ @@ -323,6 +346,8 @@ def post_init(self, state_dict): by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the keys of the value head state dictionary. 
""" + super().post_init() + for k in list(state_dict.keys()): if "v_head." in k: state_dict[k.replace("v_head.", "")] = state_dict.pop(k) @@ -333,17 +358,19 @@ def post_init(self, state_dict): class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead): _supported_modules = ["v_head", "frozen_head"] - _supported_args = ["num_layers_unfrozen"] + _supported_args = ["num_layers_unfrozen", "peft_config"] def __init__( self, base_model: transformers.PreTrainedModel, *, num_layers_unfrozen: int = -1, + peft_config=None, ): - super().__init__(base_model) + super().__init__(base_model, peft_config=peft_config) self.num_layers_unfrozen = num_layers_unfrozen - if self.num_layers_unfrozen > 0: + + if self.num_layers_unfrozen > 0 and not self.peft_type: config = self.base_model.config branch_class = hf_get_branch_class(config) self.frozen_head = branch_class( @@ -380,14 +407,17 @@ def forward_hydra( forward_kwargs["return_dict"] = True forward_kwargs["output_hidden_states"] = True - outputs = self.forward(**forward_kwargs) - # Select the hidden state before the first branching layer - input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)] + if self.peft_type: + hydra_outputs = self.forward(**forward_kwargs, ignore_peft_adapter=True) + else: + outputs = self.forward(**forward_kwargs) + # Select the hidden state before the first branching layer + input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)] - output_shape = outputs.hidden_states[-1].size() - forward_kwargs.pop("input_ids", None) # Ignore `input_ids` for branch head - forward_kwargs.pop("inputs_embeds", None) # Ignore `inputs_embeds` for branch head - hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) + output_shape = outputs.hidden_states[-1].size() + forward_kwargs.pop("input_ids", None) # Ignore `input_ids` for branch head + forward_kwargs.pop("inputs_embeds", None) # Ignore `inputs_embeds` for branch head + hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) if not return_dict: return hydra_outputs.logits @@ -985,13 +1015,14 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM _supported_modules = ["v_head"] - _supported_args = [] + _supported_args = ["peft_config"] def __init__( self, base_model: transformers.PreTrainedModel, + peft_config=None, ): - super().__init__(base_model) + super().__init__(base_model, peft_config=peft_config) self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) def forward( @@ -1011,6 +1042,7 @@ def forward( output_attentions: Optional[bool] = True, output_hidden_states: Optional[bool] = True, return_dict: Optional[bool] = None, + ignore_peft_adapter: Optional[bool] = None, ) -> Seq2SeqLMOutputWithValue: forward_kwargs = self.get_compatible_forward_kwargs( input_ids=input_ids, @@ -1032,7 +1064,23 @@ def forward( forward_kwargs["output_hidden_states"] = True forward_kwargs["return_dict"] = True - outputs = self.base_model(**forward_kwargs) + if self.peft_type == "PREFIX_TUNING": + # In this case peft redefines past_key_values, remove it to avoid an exception. 
+ forward_kwargs.pop("past_key_values", None) + + if self.peft_type and ignore_peft_adapter: + if "LORA" in self.peft_type: + # For LORA, temporarily disable the adapter + lora_model = self.base_model.base_model + lora_model.disable_adapter_layers() + outputs = self.base_model(**forward_kwargs) + lora_model.enable_adapter_layers() + else: + # For prompt or prefix adapters, just use the base model of PeftModel + outputs = self.base_model.base_model(**forward_kwargs) + else: + outputs = self.base_model(**forward_kwargs) + last_hidden_state = outputs.decoder_hidden_states[-1] value = self.v_head(last_hidden_state).squeeze(-1) @@ -1041,16 +1089,21 @@ def forward( def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: return self.base_model.generate(*args, **kwargs) - def state_dict(self, *args, **kwargs): + def state_dict(self, *args, heads_only=False, **kwargs): """ Returns the state dictionary of the model. We add the state dictionary of the value head to the state dictionary of the wrapped model by prepending the key with `v_head.`. """ - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + if heads_only: + model_state_dict = OrderedDict() + else: + model_state_dict = self.base_model.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): - base_model_state_dict[f"v_head.{k}"] = v - return base_model_state_dict + model_state_dict[f"v_head.{k}"] = v + + return model_state_dict def post_init(self, state_dict): """ @@ -1058,6 +1111,8 @@ def post_init(self, state_dict): by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the keys of the value head state dictionary. """ + super().post_init() + for k in list(state_dict.keys()): if "v_head." 
in k: state_dict[k.replace("v_head.", "")] = state_dict.pop(k) @@ -1068,17 +1123,19 @@ def post_init(self, state_dict): class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead): _supported_modules = ["v_head", "frozen_head"] - _supported_args = ["num_layers_unfrozen"] + _supported_args = ["num_layers_unfrozen", "peft_config"] def __init__( self, base_model: transformers.PreTrainedModel, *, num_layers_unfrozen: int = -1, + peft_config=None, ): - super().__init__(base_model) + super().__init__(base_model, peft_config=peft_config) self.num_layers_unfrozen = num_layers_unfrozen - if self.num_layers_unfrozen > 0: + + if self.num_layers_unfrozen > 0 and not self.peft_type: branch_class = T5Branch # TODO: Add support for other model branches self.frozen_head = branch_class( self.base_model, @@ -1124,19 +1181,22 @@ def forward_hydra( forward_kwargs["output_hidden_states"] = True forward_kwargs["return_dict"] = True - outputs = self.forward(**forward_kwargs) - # Select the hidden state before the first branching layer - input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)] - hydra_outputs = self.frozen_head( - hidden_states=input_hidden_state, - attention_mask=decoder_attention_mask, - encoder_hidden_states=outputs.encoder_last_hidden_state, - encoder_attention_mask=attention_mask, - use_cache=False, - output_attentions=False, - output_hidden_states=True, - return_dict=return_dict, - ) + if self.peft_type: + hydra_outputs = self.forward(**forward_kwargs, ignore_peft_adapter=True) + else: + outputs = self.forward(**forward_kwargs) + # Select the hidden state before the first branching layer + input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)] + hydra_outputs = self.frozen_head( + hidden_states=input_hidden_state, + attention_mask=decoder_attention_mask, + encoder_hidden_states=outputs.encoder_last_hidden_state, + encoder_attention_mask=attention_mask, + use_cache=False, + output_attentions=False, + output_hidden_states=True, + return_dict=return_dict, + ) if not return_dict: return hydra_outputs.logits diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 314eaa80c..4fb74c48d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -32,8 +32,6 @@ freeze_bottom_causal_layers, freeze_bottom_seq2seq_layers, gather_dict, - get_delta_model_class, - parse_delta_kwargs, ) logger = logging.get_logger(__name__) @@ -149,22 +147,21 @@ def setup_model(self): # Retrieves model equipped for ppo, ilql, etc model = self.get_arch(self.config) - if self.config.model.model_arch_type == "seq2seq": - freeze_bottom_seq2seq_layers(model.base_model, self.config.model.num_layers_unfrozen) + + if self.config.model.peft_config is None: + if self.config.model.model_arch_type == "seq2seq": + freeze_bottom_seq2seq_layers(model.base_model, self.config.model.num_layers_unfrozen) + else: + freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen) else: - freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen) - # Set the delta tuning strategies - if self.config.model.delta_kwargs is not None: - delta_type, delta_kwargs = parse_delta_kwargs( - model.base_model.config, - self.config.model.delta_kwargs, - self.config.model.num_layers_unfrozen, - ) - delta_model_class = get_delta_model_class(delta_type) - delta_model = delta_model_class(model.base_model, **delta_kwargs) - 
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True) - if self.accelerator.is_main_process: - delta_model.log() + if self.accelerator.is_main_process and hasattr(model.base_model, "print_trainable_parameters"): + model.base_model.print_trainable_parameters() + if self.config.model.num_layers_unfrozen >= 0: + logger.warning( + "The argument num_layers_unfrozen is ignored when using peft, to prevent unexpected behaviour." + "For Lora, use the `LoraConfig` argument `modules_to_save` instead." + ) + return model def setup_optimizer(self): @@ -306,11 +303,33 @@ def save_pretrained(self, directory: Optional[str] = None, **kwargs): def save(self, directory: Optional[str] = None, **kwargs): """Creates a checkpoint of the optimizer, scheduler and model""" - self.accelerator.save_state(directory or self.config.train.checkpoint_dir, **kwargs) + dst_dir = directory or self.config.train.checkpoint_dir + self.accelerator.save_state(dst_dir, **kwargs) + + if self.config.model.peft_config is not None and self.accelerator.is_main_process: + # Remove "pytorch_model.bin" because it contains more than necessary, + # let save_pretrained recreate it with just the value heads. + model_file = os.path.join(dst_dir, "pytorch_model.bin") + if os.path.exists(model_file): + os.remove(model_file) + self.accelerator.unwrap_model(self.model).save_pretrained(dst_dir) def load(self, directory: Optional[str] = None, **kwargs): """Load checkpoint of optimizer, scheduler and a model""" - self.accelerator.load_state(directory or self.config.train.checkpoint_dir, **kwargs) + if self.config.model.peft_config is not None: + + def load_state_hook(models: List[torch.nn.Module], input_dir: str): + with self.accelerator.main_process_first(): + for model in models: + model.from_pretrained(input_dir) + + self.accelerator.register_load_state_pre_hook(load_state_hook) + + strict = False + else: + strict = True + + self.accelerator.load_state(directory or self.config.train.checkpoint_dir, strict=strict, **kwargs) def add_eval_pipeline(self, eval_pipeline): """Adds pipeline from with validation prompts""" diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py index e41d1f923..2a2dc661e 100644 --- a/trlx/trainer/accelerate_ilql_trainer.py +++ b/trlx/trainer/accelerate_ilql_trainer.py @@ -131,6 +131,7 @@ def get_arch(self, config): config.model.model_path, two_qs=config.method.two_qs, alpha=config.method.alpha, + peft_config=self.config.model.peft_config, ) def post_backward_callback(self): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a9639746e..a3af9aa3f 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -67,13 +67,13 @@ def __init__(self, config: TRLConfig, **kwargs): self.store.clear_history() # Clear the rollout store - # Setup a reference model when hydra heads are not used - if not hasattr(self.model, "frozen_head"): + # Set up a reference model when hydra heads are not used + if not hasattr(self.model, "frozen_head") and not self.model.peft_type: self.ref_model = self.get_arch(self.config) self.ref_model.to(self.accelerator.device) self.ref_model.eval() - # Setup the KL controller + # Set up the KL controller # This helps prevent large divergences in the controller (policy) if config.method.target is not None: self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon) @@ -115,6 +115,7 @@ def get_arch(self, config: TRLConfig): return 
from_fn( config.model.model_path, num_layers_unfrozen=config.model.num_layers_unfrozen, + peft_config=self.config.model.peft_config, ) def loss(self, batch: PPORLBatch): @@ -368,7 +369,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) logits = outputs.logits values = outputs.value - if hasattr(self.model, "frozen_head"): + if hasattr(self.model, "frozen_head") or self.model.peft_type: ref_logits = self.model.forward_hydra( input_ids=prompt_tensors, attention_mask=attention_mask, @@ -394,7 +395,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_tokens, attention_mask=attention_mask, position_ids=position_ids ) # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head"): + if hasattr(self.model, "frozen_head") or self.model.peft_type: ref_logits = self.model.forward_hydra( all_tokens, attention_mask=attention_mask, diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index cb471ca61..d5cbe3ea5 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, PretrainedConfig from trlx.data.configs import TRLConfig from trlx.data.method_configs import MethodConfig, register_method @@ -38,7 +38,27 @@ def __init__(self, config: TRLConfig, **kwargs): ) def get_arch(self, config): - return AutoModelForCausalLM.from_pretrained(config.model.model_path) + from_fn = AutoModelForCausalLM.from_pretrained + if issubclass(type(config.model.model_path), PretrainedConfig): + from_fn = AutoModelForCausalLM.from_config + + model = from_fn(config.model.model_path) + + if config.model.peft_config is not None: + # Initialize the peft adapter + import peft + + peft_config = config.model.peft_config + if not isinstance(peft_config, peft.PeftConfig): + if isinstance(peft_config, dict): + peft_config = peft.get_peft_config(peft_config) + else: + raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") + model = peft.get_peft_model(model, peft_config) + if self.accelerator.is_main_process: + model.print_trainable_parameters() + + return model def loss(self, batch): if "labels" in batch: diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py index 820b66783..784ec4437 100644 --- a/trlx/utils/__init__.py +++ b/trlx/utils/__init__.py @@ -1,3 +1,4 @@ +import importlib import math import os import random @@ -15,6 +16,10 @@ from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR +def is_peft_available(): + return importlib.util.find_spec("peft") is not None + + def print_rank_0(*message): """ Print only once from the main rank diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 44b2dcd91..47688f553 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, List, MutableMapping, Tuple, Union +from typing import Dict, MutableMapping, Tuple, Union import accelerate import numpy as np @@ -9,19 +9,6 @@ import torch.nn.functional as F import transformers -try: - from opendelta import ( - AdapterModel, - BitFitModel, - LoraModel, - PrefixModel, - SoftPromptModel, - ) - - HAS_OPENDELTA = True -except ModuleNotFoundError: - HAS_OPENDELTA = False - def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential: """Returns a generic 
sequential MLP head.""" @@ -311,242 +298,3 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]: self.count = tot_count return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt() - - -# OpenDelta utilities - - -MODIFIED_MODULES_DICT = { - "gptj": { - "attention": ["attn.q_proj", "attn.k_proj", "attn.v_proj"], - "mlp": ["mlp.fc_in", "mlp.fc_out"], - "all": [ - "attn.q_proj", - "attn.k_proj", - "attn.v_proj", - "attn.out_proj", - "mlp.fc_in", - "mlp.fc_out", - ], - }, - "gpt_neox": { - "attention": ["attention.query_key_value"], - "mlp": ["mlp.dense_h_to_4h", "mlp.dense_4h_to_h"], - "all": [ - "attention.query_key_value", - "attention.dense", - "mlp.dense_h_to_4h", - "mlp.dense_4h_to_h", - ], - }, - "opt": { - "attention": [ - "self_attn.k_proj", - "self_attn.v_proj", - "self_attn.q_proj", - "self_attn.out_proj", - ], - "mlp": ["fc1", "fc2"], - "all": [ - "self_attn.k_proj", - "self_attn.v_proj", - "self_attn.q_proj", - "self_attn.out_proj", - "fc1", - "fc2", - ], - }, - "bloom": { - "attention": ["self_attention.query_key_value", "self_attention.dense"], - "mlp": ["mlp.dense_h_to_4h", "mlp.dense_4h_to_h"], - "all": [ - "self_attention.query_key_value", - "self_attention.dense", - "mlp.dense_h_to_4h", - "mlp.dense_4h_to_h", - ], - }, - "t5": { - "attention": [ - "layer.0.SelfAttention.q", - "layer.0.SelfAttention.k", - "layer.0.SelfAttention.v", - "layer.0.SelfAttention.o", - "layer.1.EncDecAttention.q", - "layer.1.EncDecAttention.k", - "layer.1.EncDecAttention.v", - "layer.1.EncDecAttention.o", - ], - "mlp": [ - "layer.2.DenseReluDense.wo", - "layer.2.DenseReluDense.wi_0", - "layer.2.DenseReluDense.wi_1", - ], - "all": [ - "layer.0.SelfAttention.q", - "layer.0.SelfAttention.k", - "layer.0.SelfAttention.v", - "layer.0.SelfAttention.o", - "layer.1.EncDecAttention.q", - "layer.1.EncDecAttention.k", - "layer.1.EncDecAttention.v", - "layer.1.EncDecAttention.o", - "layer.2.DenseReluDense.wo", - "layer.2.DenseReluDense.wi_0", - "layer.2.DenseReluDense.wi_1", - ], - }, -} - - -def generate_layer_regex(config: transformers.PretrainedConfig, num_layers_unfrozen: int = -1) -> str: - """Generates a regex range for the specified number of learnable layers.""" - if num_layers_unfrozen == -1: - return "(\d)+." - num_hidden_layers = hf_get_num_hidden_layers(config) - start_layer = num_hidden_layers - num_layers_unfrozen - if start_layer < 0: - raise Exception("Number of layers unfrozen cannot be greater than number of layers in the model") - pattern = f"(?:{regex_for_range(start_layer, num_hidden_layers - 1)})." - return f"{pattern}" - - -def get_delta_modified_modules( - config: transformers.PretrainedConfig, - modified_modules: List[str], - num_layers_unfrozen: int = -1, -) -> List[str]: - """Returns a list of module names to be modified for a given delta method with - the specified number of learnable layers.""" - unfrozen_layers_pattern = generate_layer_regex(config, num_layers_unfrozen) - - # [r] for regex as per https://github.com/thunlp/OpenDelta/blob/main/opendelta/utils/name_based_addressing.py#L20 - regex_prefix = "[r]" - # TODO (jon-tow): `decoder.block.` is hardcoded to support T5 layer naming. - decoder_prefix = "decoder.block." if config.is_encoder_decoder else "" - module_list = [regex_prefix + decoder_prefix + unfrozen_layers_pattern + module for module in modified_modules] - return module_list - - -def get_delta_model_class(model_type: str): - if not HAS_OPENDELTA: - raise ValueError("OpenDelta package required to train with delta models. 
https://github.com/thunlp/OpenDelta.") - delta_models = { - "bitfit": BitFitModel, - "adapter": AdapterModel, - "prefix": PrefixModel, - "lora": LoraModel, - "softprompt": SoftPromptModel, - } - return delta_models[model_type] - - -def parse_delta_kwargs( - config: transformers.PretrainedConfig, - delta_kwargs: Dict[str, Any], - num_layers_unfrozen: int = -1, -) -> Tuple[str, Dict[str, Any]]: - """Parses through delta kwargs to get delta type and proper modified modules.""" - # This function is needed to parse through the `delta_kwargs` in order to: - # 1) Get the `delta_type` method name to access the correct `delta_model_class` - # 2a) Accept user specified `modified_modules` and if not provided use the `trlx` default mapping - # 2b) Convert the list of `modified_modules` to a range of layers that fit within the range - # of learnable layers as specified by `num_layers_unfrozen` - - # Pop `delta_type` to allow passing the kwargs to the model constructor since - # `delta_type` is not a valid argument of the constructor - delta_type = delta_kwargs.pop("delta_type") - assert delta_type in ["lora"], "Only `LoRA` based delta models are supported" - - # Use `trlx` default modified modules if none are specified - modified_modules = delta_kwargs.get("modified_modules", "all") - if modified_modules in ["all", "attention", "mlp"]: - if config.model_type not in MODIFIED_MODULES_DICT: - raise ValueError( - f"Model type `{config.model_type}` is not currently supported for " - "delta training with default modified modules." - ) - modified_modules = MODIFIED_MODULES_DICT[config.model_type][modified_modules] - # Update the `modified_modules` with the correct layer ranges - delta_kwargs["modified_modules"] = get_delta_modified_modules( - config, modified_modules, num_layers_unfrozen=num_layers_unfrozen - ) - - return delta_type, delta_kwargs - - -def regex_for_range(min_: int, max_: int) -> str: # noqa - """Returns a regex that matches all numbers in the given range. - - Example: regex_for_range(12, 34) -> "1[2-9]|2\d|3[0-4]" - - Copyright (c) 2013, Dmitry Voronin. All rights reserved. 
- Reference: https://github.com/voronind/range-regex - """ - - def split_to_patterns(min_, max_): - subpatterns = [] - start = min_ - for stop in split_to_ranges(min_, max_): - subpatterns.append(range_to_pattern(start, stop)) - start = stop + 1 - return subpatterns - - def split_to_ranges(min_, max_): - stops = {max_} - nines_count = 1 - stop = fill_by_nines(min_, nines_count) - while min_ <= stop < max_: - stops.add(stop) - nines_count += 1 - stop = fill_by_nines(min_, nines_count) - zeros_count = 1 - stop = fill_by_zeros(max_ + 1, zeros_count) - 1 - while min_ < stop <= max_: - stops.add(stop) - zeros_count += 1 - stop = fill_by_zeros(max_ + 1, zeros_count) - 1 - stops = list(stops) - stops.sort() - return stops - - def fill_by_nines(integer, nines_count): - return int(str(integer)[:-nines_count] + "9" * nines_count) - - def fill_by_zeros(integer, zeros_count): - return integer - integer % 10**zeros_count - - def range_to_pattern(start, stop): - pattern = "" - any_digit_count = 0 - for start_digit, stop_digit in zip(str(start), str(stop)): - if start_digit == stop_digit: - pattern += start_digit - elif start_digit != "0" or stop_digit != "9": - pattern += "[{}-{}]".format(start_digit, stop_digit) - else: - any_digit_count += 1 - if any_digit_count: - pattern += r"\d" - if any_digit_count > 1: - pattern += "{{{}}}".format(any_digit_count) - return pattern - - positive_subpatterns = [] - negative_subpatterns = [] - - if min_ < 0: - min__ = 1 - if max_ < 0: - min__ = abs(max_) - max__ = abs(min_) - negative_subpatterns = split_to_patterns(min__, max__) - min_ = 0 - if max_ >= 0: - positive_subpatterns = split_to_patterns(min_, max_) - - negative_only_subpatterns = ["-" + val for val in negative_subpatterns if val not in positive_subpatterns] - positive_only_subpatterns = [val for val in positive_subpatterns if val not in negative_subpatterns] - intersected_subpatterns = ["-?" + val for val in negative_subpatterns if val in positive_subpatterns] - subpatterns = negative_only_subpatterns + intersected_subpatterns + positive_only_subpatterns - return "|".join(subpatterns)
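
Usage sketch (assumes "gpt2" and default LoRA settings; names and values here are illustrative, not part of the patch): this is the mechanism the `forward_hydra` hunks above rely on when a LoRA `peft_config` is set — reference logits are produced by the same model with its adapter layers temporarily disabled, so no separate frozen copy of the base model is needed.

    import torch
    from peft import LoraConfig, get_peft_model
    from peft.utils.config import TaskType
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = get_peft_model(
        AutoModelForCausalLM.from_pretrained("gpt2"),
        LoraConfig(task_type=TaskType.CAUSAL_LM),
    )
    input_ids = tokenizer("The movie was", return_tensors="pt").input_ids

    # Policy logits: adapter layers are active.
    logits = model(input_ids).logits

    # Reference logits: temporarily disable the LoRA layers, analogous to what
    # forward_hydra does above when called with `ignore_peft_adapter=True`.
    lora_model = model.base_model  # the underlying LoraModel
    lora_model.disable_adapter_layers()
    with torch.no_grad():
        ref_logits = model(input_ids).logits
    lora_model.enable_adapter_layers()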
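
As the `get_arch` changes above suggest, `config.model.peft_config` can also be supplied as a plain dict; it is converted with `peft.get_peft_config`, which selects the config class from the `peft_type` key. A minimal sketch, with purely illustrative hyperparameter values:

    from trlx.data.default_configs import TRLConfig, default_sft_config

    config = TRLConfig.update(default_sft_config().to_dict(), {})
    # Equivalent to passing a peft.LoraConfig instance.
    config.model.peft_config = {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
    }

Anything that is neither a `peft.PeftConfig` instance nor a dict raises a `ValueError` in `get_arch`.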