From 6c81c61bc4b52296fb0e6ac01c69a71c28756790 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 23 Oct 2023 17:38:41 -0400 Subject: [PATCH] refactor setup trainer so we can add more hooks (#773) * refactor setup trainer so we can add more hooks * Remove stray comma --- src/axolotl/core/__init__.py | 0 src/axolotl/core/trainer_builder.py | 689 ++++++++++++++++++++++++++++ src/axolotl/utils/callbacks.py | 2 +- src/axolotl/utils/trainer.py | 540 +--------------------- 4 files changed, 699 insertions(+), 532 deletions(-) create mode 100644 src/axolotl/core/__init__.py create mode 100644 src/axolotl/core/trainer_builder.py diff --git a/src/axolotl/core/__init__.py b/src/axolotl/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py new file mode 100644 index 0000000000..00a1a0c670 --- /dev/null +++ b/src/axolotl/core/trainer_builder.py @@ -0,0 +1,689 @@ +""" +Builder for the training args and trainer +""" + +import abc +import importlib +import logging +import math +import os +import sys +from abc import abstractmethod +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Optional, Union + +import torch +import transformers +from datasets import Dataset +from torch.optim.lr_scheduler import OneCycleLR +from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler +from transformers import EarlyStoppingCallback, Trainer, TrainingArguments +from transformers.trainer_pt_utils import SequentialDistributedSampler + +from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler +from axolotl.utils.callbacks import ( + EvalFirstStepCallback, + GPUStatsCallback, + SaveAxolotlConfigtoWandBCallback, + SaveBetterTransformerModelCallback, + bench_eval_callback_factory, + log_prediction_callback_factory, +) +from axolotl.utils.collators import DataCollatorForSeq2Seq +from axolotl.utils.dataloader import MultipackDistributedDataloader +from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup + +try: + import torch._dynamo # pylint: disable=ungrouped-imports +except ImportError: + pass + +LOG = logging.getLogger("axolotl.core.trainer_builder") + + +@dataclass +class AxolotlTrainingArguments(TrainingArguments): + """ + Extend the base TrainingArguments for axolotl helpers + """ + + lr_quadratic_warmup: bool = field( + default=False, + metadata={"help": "Use quadratic warmup for cosine scheduling."}, + ) + sample_packing: bool = field( + default=False, + metadata={"help": "Use sample packing for efficient training."}, + ) + eval_sample_packing: Optional[bool] = field( + default=None, + metadata={"help": "Use sample packing for efficient evals."}, + ) + sample_packing_efficiency: float = field( + default=1.0, + metadata={"help": "Sample packing efficiency for calculating batch length."}, + ) + max_seq_length: int = field( + default=2048, + metadata={"help": "The maximum sequence length the model can handle"}, + ) + sample_packing_seq_len_multiplier: int = field( + default=1, + metadata={"help": "the multiplier for the max len for packed sequences"}, + ) + relora_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for ReLoRA"}, + ) + relora_warmup_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + bench_split: Optional[str] = field( + default="eval", metadata={"help": "The benchmark split to run on"} + ) + bench_dataset: Optional[str] = field( + default="pharaouk/dharma-1/dharma_1_mini.json", + metadata={ + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" + }, + ) + do_bench_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Benchmark evaluation."} + ) + max_bench_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." + }, + ) + bench_source_max_len: int = field( + default=2048, metadata={"help": "Maximum source sequence length for bench."} + ) + + +class AxolotlTrainer(Trainer): + """ + Extend the base Trainer for axolotl helpers + """ + + args = None # type: AxolotlTrainingArguments + + def __init__(self, *args, bench_data_collator=None, **kwargs): + self.bench_data_collator = bench_data_collator + super().__init__(*args, **kwargs) + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ): + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + else: + return super().create_scheduler(num_training_steps, optimizer) + return self.lr_scheduler + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size > 1 and self.args.sample_packing: + return DistributedSampler( + self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=self.args.seed, + ) + return super()._get_train_sampler() + + def _get_eval_sampler( + self, eval_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if ( + self.args.world_size > 1 + and self.args.sample_packing + and self.args.eval_sample_packing is not False + ): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + batch_size=self.args.per_device_eval_batch_size, + ) + return super()._get_eval_sampler(eval_dataset) + + def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: + if self.args.sample_packing: + train_sampler = self._get_train_sampler() + return self.accelerator.prepare( + MultipackDistributedDataloader( + self.train_dataset, + batch_size=self._train_batch_size, + seq_max_length=self.args.max_seq_length, + collate_fn=self.data_collator, + sampler=train_sampler, + packing_efficiency_estimate=self.args.sample_packing_efficiency, + sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, + device_count=int(os.environ.get("WORLD_SIZE", 1)), + ) + ) + return super().get_train_dataloader() + + def get_eval_dataloader( + self, eval_dataset: Optional[Dataset] = None + ) -> Union[DataLoader, MultipackDistributedDataloader]: + if self.args.sample_packing and self.args.eval_sample_packing is not False: + eval_dataset = ( + eval_dataset if eval_dataset is not None else self.eval_dataset + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + return self.accelerator.prepare( + MultipackDistributedDataloader( + eval_dataset, + batch_size=self.args.eval_batch_size, + seq_max_length=self.args.max_seq_length, + collate_fn=self.data_collator, + sampler=eval_sampler, + packing_efficiency_estimate=self.args.sample_packing_efficiency, + sample_packing_seq_len_multiplier=self.args.eval_batch_size, + device_count=int(os.environ.get("WORLD_SIZE", 1)), + ) + ) + return super().get_eval_dataloader(eval_dataset) + + def _get_bench_sampler( + self, bench_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size <= 1: + return SequentialSampler(bench_dataset) + return None + + def get_bench_dataloader( + self, + bench_dataset: Dataset, + ) -> Union[DataLoader, MultipackDistributedDataloader]: + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": self.bench_data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + + if not isinstance(bench_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + return DataLoader(bench_dataset, **dataloader_params) + # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) + + def compute_loss(self, model, inputs, return_outputs=False): + # use one's weighted cross entropy loss calc + # if self.args.sample_packing: + # labels = inputs.pop("labels") + # outputs = model(**inputs) + # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) + # return (loss, outputs) if return_outputs else loss + return super().compute_loss(model, inputs, return_outputs=return_outputs) + + +class OneCycleLRSchedulerTrainer(AxolotlTrainer): + """ + Trainer subclass that uses the OneCycleLR scheduler + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lr_scheduler = None + + def create_scheduler( + self, + num_training_steps: int, + optimizer: Optional[torch.optim.Optimizer] = None, + ): + optimizer = self.optimizer if optimizer is None else optimizer + num_warmup_steps = self.args.get_warmup_steps(num_training_steps) + pct_start = num_warmup_steps / num_training_steps + + self.lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + pct_start=pct_start, + div_factor=6, + ) + + return self.lr_scheduler + + +class ReLoRATrainer(AxolotlTrainer): + """ + Trainer subclass that uses the OneCycleLR scheduler + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lr_scheduler = None + + def create_scheduler( + self, + num_training_steps: int, + optimizer: Optional[torch.optim.Optimizer] = None, + ): + optimizer = self.optimizer if optimizer is None else optimizer + lr_scheduler = super().create_scheduler(num_training_steps, optimizer) + + if self.args.relora_steps: + warmup_steps = ( + self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 + ) + self.lr_scheduler = ReLoRAScheduler( + optimizer, + lr_scheduler, + self.args.relora_steps, + warmup_steps, + ) + else: + self.lr_scheduler = lr_scheduler + + return self.lr_scheduler + + +class TrainerBuilderBase(abc.ABC): + """ + Base class for trainer builder + """ + + _train_dataset = None + _eval_dataset = None + + def __init__(self, cfg, model, tokenizer): + self.cfg = cfg + self.model = model + self.tokenizer = tokenizer + + @property + def train_dataset(self): + return self._train_dataset + + @train_dataset.setter + def train_dataset(self, dataset): + self._train_dataset = dataset + + @property + def eval_dataset(self): + return self._eval_dataset + + @eval_dataset.setter + def eval_dataset(self, dataset): + self._eval_dataset = dataset + + @abstractmethod + def build(self, total_num_steps): + pass + + @abstractmethod + def get_callbacks(self): + pass + + @abstractmethod + def get_post_trainer_create_callbacks(self, trainer): + """ + Callbacks added after the trainer is created, usually b/c these need access to the trainer + """ + + +class HFCausalTrainerBuilder(TrainerBuilderBase): + """ + Build the HuggingFace training args/trainer for Causal models + """ + + def hook_pre_create_training_args(self, training_arguments_kwargs): + # TODO + return training_arguments_kwargs + + def hook_post_create_training_args(self, training_arguments): + # TODO + return training_arguments + + def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls): + # TODO + return trainer_kwargs, trainer_cls + + def hook_post_create_trainer(self, trainer): + # TODO + return trainer + + def get_callbacks(self): + callbacks = [] + callbacks.append(GPUStatsCallback(self.cfg)) + callbacks.append(EvalFirstStepCallback) + + if self.cfg.relora_steps: + callbacks.append(ReLoRACallback(self.cfg)) + + if ( + hasattr(self.model, "use_bettertransformer") + and self.model.use_bettertransformer is True + ): + callbacks.append(SaveBetterTransformerModelCallback) + + if self.cfg.use_wandb: + callbacks.append( + SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) + ) + + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + if self.cfg.use_wandb and self.cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory( + trainer, self.tokenizer + ) + callbacks.append(LogPredictionCallback(self.cfg)) + + if self.cfg.do_bench_eval: + callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) + + if self.cfg.early_stopping_patience: + early_stop_cb = EarlyStoppingCallback( + self.cfg.early_stopping_patience, + ) + callbacks.append(early_stop_cb) + + return callbacks + + def _get_trainer_cls(self): + if self.cfg.lr_scheduler == "one_cycle" and ( + self.cfg.fsdp or self.cfg.adapter == "qlora" + ): + return OneCycleLRSchedulerTrainer + if self.cfg.relora_steps: + return ReLoRATrainer + return AxolotlTrainer + + def build(self, total_num_steps): + warmup_steps = ( + self.cfg.warmup_steps + if self.cfg.warmup_steps is not None + else min(int(0.03 * total_num_steps), 100) + ) + logging_steps = ( + self.cfg.logging_steps + if self.cfg.logging_steps is not None + else max(min(int(0.005 * total_num_steps), 10), 1) + ) + + training_arguments_kwargs = {} + if self.cfg.bf16 == "full": + training_arguments_kwargs["bf16_full_eval"] = True + else: + training_arguments_kwargs["bf16"] = self.cfg.bf16 + training_arguments_kwargs["fp16"] = ( + self.cfg.fp16 and not self.cfg.bf16 + ) or False + training_arguments_kwargs["tf32"] = self.cfg.tf32 + training_arguments_kwargs["warmup_steps"] = warmup_steps + training_arguments_kwargs["logging_steps"] = logging_steps + + if self.cfg.seed: + training_arguments_kwargs["seed"] = self.cfg.seed + + if self.cfg.gradient_checkpointing: + training_arguments_kwargs[ + "gradient_checkpointing" + ] = self.cfg.gradient_checkpointing + if self.cfg.fsdp: + training_arguments_kwargs["fsdp"] = self.cfg.fsdp + if self.cfg.fsdp_config: + training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) + + # deepspeed + if self.cfg.deepspeed: + training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed + + if self.cfg.lr_quadratic_warmup is not None: + training_arguments_kwargs[ + "lr_quadratic_warmup" + ] = self.cfg.lr_quadratic_warmup + + if self.cfg.adam_beta1: + training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 + if self.cfg.adam_beta2: + training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2 + if self.cfg.adam_epsilon: + training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon + if self.cfg.max_grad_norm: + training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm + + if self.cfg.hub_model_id: + training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id + training_arguments_kwargs["push_to_hub"] = True + training_arguments_kwargs["hub_private_repo"] = True + + if self.cfg.hub_strategy: + training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy + + if self.cfg.save_safetensors: + training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors + + if self.cfg.sample_packing_eff_est: + training_arguments_kwargs[ + "sample_packing_efficiency" + ] = self.cfg.sample_packing_eff_est + + if self.cfg.eval_steps: + training_arguments_kwargs["evaluation_strategy"] = "steps" + training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps + elif self.cfg.evaluation_strategy: + training_arguments_kwargs[ + "evaluation_strategy" + ] = self.cfg.evaluation_strategy + elif self.cfg.val_set_size == 0: + # no eval set, so don't eval + training_arguments_kwargs["evaluation_strategy"] = "no" + else: + # we have an eval set, but no steps defined, default to use epoch + training_arguments_kwargs["evaluation_strategy"] = "epoch" + + if self.cfg.save_steps: + training_arguments_kwargs["save_strategy"] = "steps" + training_arguments_kwargs["save_steps"] = self.cfg.save_steps + elif self.cfg.save_strategy: + training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy + else: + # default to saving each epoch if not defined + training_arguments_kwargs["save_strategy"] = "epoch" + + if self.cfg.do_bench_eval: + training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval + if self.cfg.bench_dataset: + training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset + if self.cfg.metric_for_best_model: + training_arguments_kwargs[ + "metric_for_best_model" + ] = self.cfg.metric_for_best_model + if self.cfg.greater_is_better: + training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better + + if self.cfg.torch_compile: + if torch.__version__ < "2.1.0": # pylint: disable=protected-access + LOG.warning("torch>=2.1.0 required for torch_compile to work properly") + elif torch._dynamo: # pylint: disable=protected-access + torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access + True + ) + training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile + if self.cfg.torch_compile_backend: + training_arguments_kwargs[ + "torch_compile_backend" + ] = self.cfg.torch_compile_backend + + # DDP Config + if self.cfg.ddp_timeout: + training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout + # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + if self.cfg.ddp_bucket_cap_mb: + training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb + if self.cfg.ddp_broadcast_buffers is not None: + training_arguments_kwargs[ + "ddp_broadcast_buffers" + ] = self.cfg.ddp_broadcast_buffers + + # these are all the "standard" kwargs that are def used + training_arguments_kwargs["max_steps"] = ( + total_num_steps if self.cfg.max_steps else -1 + ) + training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len + training_arguments_kwargs[ + "per_device_train_batch_size" + ] = self.cfg.micro_batch_size + training_arguments_kwargs[ + "per_device_eval_batch_size" + ] = self.cfg.eval_batch_size + training_arguments_kwargs[ + "gradient_accumulation_steps" + ] = self.cfg.gradient_accumulation_steps + training_arguments_kwargs[ + "eval_accumulation_steps" + ] = self.cfg.gradient_accumulation_steps + training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs + training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate + training_arguments_kwargs["output_dir"] = self.cfg.output_dir + training_arguments_kwargs["save_total_limit"] = ( + self.cfg.save_total_limit if self.cfg.save_total_limit else 4 + ) + training_arguments_kwargs["load_best_model_at_end"] = ( + ( + self.cfg.load_best_model_at_end is not False + or self.cfg.early_stopping_patience + ) + and self.cfg.val_set_size > 0 + and self.cfg.save_steps + and self.cfg.eval_steps + and self.cfg.save_steps % self.cfg.eval_steps == 0 + ) or False + training_arguments_kwargs["ddp_find_unused_parameters"] = ( + False if self.cfg.ddp else None + ) + training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length + training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None + training_arguments_kwargs["run_name"] = ( + self.cfg.wandb_run_id if self.cfg.use_wandb else None + ) + training_arguments_kwargs["optim"] = ( + self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" + ) + training_arguments_kwargs["lr_scheduler_type"] = ( + self.cfg.lr_scheduler + if self.cfg.lr_scheduler + and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") + else "cosine" + ) + training_arguments_kwargs["weight_decay"] = ( + self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 + ) + training_arguments_kwargs["sample_packing"] = ( + self.cfg.sample_packing if self.cfg.sample_packing else False + ) + training_arguments_kwargs["eval_sample_packing"] = ( + self.cfg.sample_packing if self.cfg.sample_packing else False + ) + training_arguments_kwargs[ + "sample_packing_seq_len_multiplier" + ] = self.cfg.micro_batch_size + training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps + training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps + training_arguments_kwargs = self.hook_pre_create_training_args( + training_arguments_kwargs + ) + training_args = ( + AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg + **training_arguments_kwargs, + ) + ) + training_args = self.hook_post_create_training_args(training_args) + trainer_kwargs = {} + + if self.cfg.optimizer == "adamw_anyprecision": + if Path(self.cfg.torchdistx_path).exists(): + sys.path.append(self.cfg.torchdistx_path) + importlib.import_module("torchdistx") + + data_collator_kwargs = { + "padding": True, # True/"longest" is the default + } + if self.cfg.pad_to_sequence_len: + data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil( + self.cfg.sequence_len / 64 + ) + else: + # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check + # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html + data_collator_kwargs["pad_to_multiple_of"] = 64 + + if self.cfg.is_llama_derived_model and self.cfg.landmark_attention: + from axolotl.monkeypatch.llama_landmark_attn import ( + add_mem_tokens, + get_mem_id, + set_model_mem_id, + ) + + set_model_mem_id(self.model, self.tokenizer) + + LOG.info("Adding landmark attention tokens to dataset") + + for dataset in [self.train_dataset, self.eval_dataset]: + dataset = dataset.map( + partial( + add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer) + ), + batched=False, + num_proc=32, + ) + + trainer_cls = self._get_trainer_cls() + trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( + trainer_kwargs, trainer_cls + ) + trainer = trainer_cls( + model=self.model, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + args=training_args, + data_collator=DataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + **data_collator_kwargs, + ), + bench_data_collator=transformers.DataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + **data_collator_kwargs, + ), + callbacks=self.get_callbacks(), + **trainer_kwargs, + ) + trainer = self.hook_post_create_trainer(trainer) + for callback in self.get_post_trainer_create_callbacks(trainer): + trainer.add_callback(callback) + + return trainer diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 458e537c63..4191bcf164 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -37,7 +37,7 @@ ) if TYPE_CHECKING: - from axolotl.utils.trainer import AxolotlTrainingArguments + from axolotl.core.trainer_builder import AxolotlTrainingArguments LOG = logging.getLogger("axolotl.callbacks") IGNORE_INDEX = -100 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 820202b80b..0d275cbf55 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,40 +1,19 @@ """Module containing the Trainer class and related functions""" -import importlib import logging import math import os -import sys from contextlib import contextmanager -from dataclasses import dataclass, field from functools import partial -from pathlib import Path -from typing import List, Optional, Union +from typing import List import numpy as np import torch import torch.cuda import torch.distributed as dist -import transformers -from datasets import Dataset, set_caching_enabled -from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import ( - DataLoader, - DistributedSampler, - RandomSampler, - SequentialSampler, -) -from transformers import EarlyStoppingCallback, Trainer, TrainingArguments -from transformers.trainer_pt_utils import SequentialDistributedSampler - -from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler -from axolotl.utils.callbacks import ( - EvalFirstStepCallback, - GPUStatsCallback, - SaveAxolotlConfigtoWandBCallback, - SaveBetterTransformerModelCallback, - bench_eval_callback_factory, - log_prediction_callback_factory, -) +from datasets import set_caching_enabled +from torch.utils.data import DistributedSampler, RandomSampler + +from axolotl.core.trainer_builder import HFCausalTrainerBuilder from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.distributed import ( @@ -43,7 +22,6 @@ reduce_and_broadcast, zero_first, ) -from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup LOG = logging.getLogger("axolotl") @@ -110,269 +88,6 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True): return weighted_cross_entropy(logits, labels, weights) -@dataclass -class AxolotlTrainingArguments(TrainingArguments): - """ - Extend the base TrainingArguments for axolotl helpers - """ - - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - sample_packing_seq_len_multiplier: int = field( - default=1, - metadata={"help": "the multiplier for the max len for packed sequences"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." - }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - - -class AxolotlTrainer(Trainer): - """ - Extend the base Trainer for axolotl helpers - """ - - args = None # type: AxolotlTrainingArguments - - def __init__(self, *args, bench_data_collator=None, **kwargs): - self.bench_data_collator = bench_data_collator - super().__init__(*args, **kwargs) - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ): - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - else: - return super().create_scheduler(num_training_steps, optimizer) - return self.lr_scheduler - - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size > 1 and self.args.sample_packing: - return DistributedSampler( - self.train_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - seed=self.args.seed, - ) - return super()._get_train_sampler() - - def _get_eval_sampler( - self, eval_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if ( - self.args.world_size > 1 - and self.args.sample_packing - and self.args.eval_sample_packing is not False - ): - return SequentialDistributedSampler( - eval_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - batch_size=self.args.per_device_eval_batch_size, - ) - return super()._get_eval_sampler(eval_dataset) - - def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: - if self.args.sample_packing: - train_sampler = self._get_train_sampler() - return self.accelerator.prepare( - MultipackDistributedDataloader( - self.train_dataset, - batch_size=self._train_batch_size, - seq_max_length=self.args.max_seq_length, - collate_fn=self.data_collator, - sampler=train_sampler, - packing_efficiency_estimate=self.args.sample_packing_efficiency, - sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, - device_count=int(os.environ.get("WORLD_SIZE", 1)), - ) - ) - return super().get_train_dataloader() - - def get_eval_dataloader( - self, eval_dataset: Optional[Dataset] = None - ) -> Union[DataLoader, MultipackDistributedDataloader]: - if self.args.sample_packing and self.args.eval_sample_packing is not False: - eval_dataset = ( - eval_dataset if eval_dataset is not None else self.eval_dataset - ) - - eval_sampler = self._get_eval_sampler(eval_dataset) - return self.accelerator.prepare( - MultipackDistributedDataloader( - eval_dataset, - batch_size=self.args.eval_batch_size, - seq_max_length=self.args.max_seq_length, - collate_fn=self.data_collator, - sampler=eval_sampler, - packing_efficiency_estimate=self.args.sample_packing_efficiency, - sample_packing_seq_len_multiplier=self.args.eval_batch_size, - device_count=int(os.environ.get("WORLD_SIZE", 1)), - ) - ) - return super().get_eval_dataloader(eval_dataset) - - def _get_bench_sampler( - self, bench_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size <= 1: - return SequentialSampler(bench_dataset) - return None - - def get_bench_dataloader( - self, - bench_dataset: Dataset, - ) -> Union[DataLoader, MultipackDistributedDataloader]: - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": self.bench_data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - - if not isinstance(bench_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - return DataLoader(bench_dataset, **dataloader_params) - # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) - - def compute_loss(self, model, inputs, return_outputs=False): - # use one's weighted cross entropy loss calc - # if self.args.sample_packing: - # labels = inputs.pop("labels") - # outputs = model(**inputs) - # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) - # return (loss, outputs) if return_outputs else loss - return super().compute_loss(model, inputs, return_outputs=return_outputs) - - -class OneCycleLRSchedulerTrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - pct_start = num_warmup_steps / num_training_steps - - self.lr_scheduler = OneCycleLR( - optimizer, - max_lr=self.args.learning_rate, - total_steps=num_training_steps, - pct_start=pct_start, - div_factor=6, - ) - - return self.lr_scheduler - - -class ReLoRATrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler = super().create_scheduler(num_training_steps, optimizer) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - self.lr_scheduler = ReLoRAScheduler( - optimizer, - lr_scheduler, - self.args.relora_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler - - return self.lr_scheduler - - def add_position_ids(sample): sample_len = len(sample["input_ids"]) sample["position_ids"] = torch.arange(len(sample["input_ids"])) @@ -550,245 +265,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ elif cfg.deepspeed: os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - warmup_steps = ( - cfg.warmup_steps - if cfg.warmup_steps is not None - else min(int(0.03 * total_num_steps), 100) - ) - logging_steps = ( - cfg.logging_steps - if cfg.logging_steps is not None - else max(min(int(0.005 * total_num_steps), 10), 1) - ) - - training_arguments_kwargs = {} - if cfg.bf16 == "full": - training_arguments_kwargs["bf16_full_eval"] = True - else: - training_arguments_kwargs["bf16"] = cfg.bf16 - training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False - training_arguments_kwargs["tf32"] = cfg.tf32 - training_arguments_kwargs["warmup_steps"] = warmup_steps - training_arguments_kwargs["logging_steps"] = logging_steps - - if cfg.seed: - training_arguments_kwargs["seed"] = cfg.seed - - if cfg.gradient_checkpointing: - training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing - if cfg.fsdp: - training_arguments_kwargs["fsdp"] = cfg.fsdp - if cfg.fsdp_config: - training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) - - # deepspeed - if cfg.deepspeed: - training_arguments_kwargs["deepspeed"] = cfg.deepspeed - - if cfg.lr_quadratic_warmup is not None: - training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup - - if cfg.adam_beta1: - training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1 - if cfg.adam_beta2: - training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2 - if cfg.adam_epsilon: - training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon - if cfg.max_grad_norm: - training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm - - if cfg.hub_model_id: - training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id - training_arguments_kwargs["push_to_hub"] = True - training_arguments_kwargs["hub_private_repo"] = True - - if cfg.hub_strategy: - training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy - - if cfg.save_safetensors: - training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors - - if cfg.sample_packing_eff_est: - training_arguments_kwargs[ - "sample_packing_efficiency" - ] = cfg.sample_packing_eff_est - - if cfg.eval_steps: - training_arguments_kwargs["evaluation_strategy"] = "steps" - training_arguments_kwargs["eval_steps"] = cfg.eval_steps - elif cfg.evaluation_strategy: - training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy - elif cfg.val_set_size == 0: - # no eval set, so don't eval - training_arguments_kwargs["evaluation_strategy"] = "no" - else: - # we have an eval set, but no steps defined, default to use epoch - training_arguments_kwargs["evaluation_strategy"] = "epoch" - - if cfg.save_steps: - training_arguments_kwargs["save_strategy"] = "steps" - training_arguments_kwargs["save_steps"] = cfg.save_steps - elif cfg.save_strategy: - training_arguments_kwargs["save_strategy"] = cfg.save_strategy - else: - # default to saving each epoch if not defined - training_arguments_kwargs["save_strategy"] = "epoch" - - if cfg.do_bench_eval: - training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval - if cfg.bench_dataset: - training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset - if cfg.metric_for_best_model: - training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model - if cfg.greater_is_better: - training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better - - if cfg.torch_compile: - if torch.__version__ < "2.1.0": # pylint: disable=protected-access - LOG.warning("torch>=2.1.0 required for torch_compile to work properly") - else: - import torch._dynamo # pylint: disable=redefined-outer-name - - torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access - True - ) - training_arguments_kwargs["torch_compile"] = cfg.torch_compile - if cfg.torch_compile_backend: - training_arguments_kwargs[ - "torch_compile_backend" - ] = cfg.torch_compile_backend - - # DDP Config - if cfg.ddp_timeout: - training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout - # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html - if cfg.ddp_bucket_cap_mb: - training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb - if cfg.ddp_broadcast_buffers is not None: - training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers - - training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg - max_steps=total_num_steps if cfg.max_steps else -1, - max_seq_length=cfg.sequence_len, - per_device_train_batch_size=cfg.micro_batch_size, - per_device_eval_batch_size=cfg.eval_batch_size, - gradient_accumulation_steps=cfg.gradient_accumulation_steps, - eval_accumulation_steps=cfg.gradient_accumulation_steps, - num_train_epochs=cfg.num_epochs, - learning_rate=cfg.learning_rate, - output_dir=cfg.output_dir, - save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4, - load_best_model_at_end=( - (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience) - and cfg.val_set_size > 0 - and cfg.save_steps - and cfg.eval_steps - and cfg.save_steps % cfg.eval_steps == 0 - ) - or False, - ddp_find_unused_parameters=False if cfg.ddp else None, - group_by_length=cfg.group_by_length, - report_to="wandb" if cfg.use_wandb else None, - run_name=cfg.wandb_run_id if cfg.use_wandb else None, - optim=cfg.optimizer if cfg.optimizer else "adamw_hf", - lr_scheduler_type=cfg.lr_scheduler - if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") - else "cosine", - weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, - sample_packing=cfg.sample_packing if cfg.sample_packing else False, - eval_sample_packing=cfg.eval_sample_packing, - sample_packing_seq_len_multiplier=cfg.micro_batch_size, - relora_steps=cfg.relora_steps, - relora_warmup_steps=cfg.relora_warmup_steps, - **training_arguments_kwargs, - ) - - trainer_kwargs = {} - - if cfg.optimizer == "adamw_anyprecision": - if Path(cfg.torchdistx_path).exists(): - sys.path.append(cfg.torchdistx_path) - importlib.import_module("torchdistx") - - callbacks = [] - callbacks.append(GPUStatsCallback(cfg)) - callbacks.append(EvalFirstStepCallback) - - if cfg.relora_steps: - callbacks.append(ReLoRACallback(cfg)) - - if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: - callbacks.append(SaveBetterTransformerModelCallback) - - data_collator_kwargs = { - "padding": True, # True/"longest" is the default - } - if cfg.pad_to_sequence_len: - data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil( - cfg.sequence_len / 64 - ) - else: - # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check - # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html - data_collator_kwargs["pad_to_multiple_of"] = 64 - - if cfg.is_llama_derived_model and cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( - add_mem_tokens, - get_mem_id, - set_model_mem_id, - ) - - set_model_mem_id(model, tokenizer) - - LOG.info("Adding landmark attention tokens to dataset") - - for dataset in [train_dataset, eval_dataset]: - dataset = dataset.map( - partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)), - batched=False, - num_proc=32, - ) - - trainer_cls = AxolotlTrainer - if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"): - trainer_cls = OneCycleLRSchedulerTrainer - elif cfg.relora_steps: - trainer_cls = ReLoRATrainer - trainer = trainer_cls( - model=model, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - args=training_args, - data_collator=DataCollatorForSeq2Seq( - tokenizer, - return_tensors="pt", - **data_collator_kwargs, - ), - bench_data_collator=transformers.DataCollatorForSeq2Seq( - tokenizer, - return_tensors="pt", - **data_collator_kwargs, - ), - callbacks=callbacks, - **trainer_kwargs, - ) - - if cfg.use_wandb and cfg.eval_table_size > 0: - LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) - trainer.add_callback(LogPredictionCallback(cfg)) - - if cfg.use_wandb: - trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path)) - - if cfg.do_bench_eval: - trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) - - # TODO on_save callback to sync checkpoints to GCP/AWS in background - if cfg.early_stopping_patience: - early_stop_cb = EarlyStoppingCallback( - cfg.early_stopping_patience, - ) - trainer.add_callback(early_stop_cb) + trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) + trainer_builder.train_dataset = train_dataset + trainer_builder.eval_dataset = eval_dataset - return trainer + return trainer_builder.build(total_num_steps)