diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0f62aae9af..d055fd4764 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -227,7 +227,8 @@ def _get_eval_sampler( def get_train_dataloader(self) -> DataLoader: if self.args.sample_packing and not self.args.pretraining: train_dataset = self.train_dataset - train_dataset = train_dataset.remove_columns(["length"]) + if "length" in train_dataset.features.keys(): + train_dataset = train_dataset.remove_columns(["length"]) data_collator = self.data_collator dataloader_params = { "batch_size": self._train_batch_size, diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 09bc31db3b..59cbef15dc 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -204,6 +204,9 @@ def validate_config(cfg): if cfg.max_packed_sequence_len: raise DeprecationWarning("`max_packed_sequence_len` is no longer supported") + if cfg.sample_packing and cfg.rl: + raise ValueError("`sample_packing: true` does not work with RLHF training") + if cfg.sample_packing and not cfg.pad_to_sequence_len: LOG.warning( "`pad_to_sequence_len: true` is recommended when using sample_packing"