From c8e735be880fd1be5e71026c7434fe151c7c172e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 16 Jan 2024 07:29:19 -0500
Subject: [PATCH] set dataloader preload args

---
 src/axolotl/core/trainer_builder.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b0f7fcdcbf..030a8eb8f8 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -949,6 +949,18 @@ def build_training_arguments(self, total_num_steps):
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)
 
+        if self.cfg.hub_model_id:
+            training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
+            training_args_kwargs["push_to_hub"] = True
+            training_args_kwargs["hub_private_repo"] = True
+            training_args_kwargs["hub_always_push"] = True
+
+            if self.cfg.hub_strategy:
+                training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
+
+        if self.cfg.save_safetensors is not None:
+            training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
+
         if self.eval_dataset:
             training_args_kwargs["evaluation_strategy"] = "steps"
             training_args_kwargs["eval_steps"] = self.cfg.eval_steps
@@ -964,6 +976,19 @@ def build_training_arguments(self, total_num_steps):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
 
+        if self.cfg.dataloader_pin_memory is not None:
+            training_args_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_args_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_args_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=total_num_steps,
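
The dataloader kwargs above are forwarded to transformers' TrainingArguments only when they are explicitly set in the axolotl config, so Hugging Face's own defaults still apply otherwise. Below is a minimal sketch of that flow, assuming a transformers release whose TrainingArguments exposes dataloader_prefetch_factor (a relatively new field) and using an inline dict as a stand-in for axolotl's cfg object:

    from transformers import TrainingArguments

    # Stand-in for axolotl's cfg; unset keys behave as None.
    cfg = {
        "dataloader_pin_memory": True,
        "dataloader_num_workers": 4,  # prefetch only applies with workers > 0
        "dataloader_prefetch_factor": 2,
    }

    training_args_kwargs = {}
    for arg in (
        "dataloader_pin_memory",
        "dataloader_num_workers",
        "dataloader_prefetch_factor",
    ):
        # Mirror the patch: only forward explicitly-set values so
        # TrainingArguments keeps its own defaults for the rest.
        if cfg.get(arg) is not None:
            training_args_kwargs[arg] = cfg[arg]

    training_args = TrainingArguments(
        output_dir="./out",  # hypothetical path for this sketch
        max_steps=100,
        per_device_train_batch_size=1,
        **training_args_kwargs,
    )

Note that PyTorch's DataLoader raises a ValueError if a prefetch factor is specified with num_workers=0, so pairing dataloader_prefetch_factor with a nonzero dataloader_num_workers, as above, is the safe combination.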