Switch to parallel FFD bin packing algorithm.
dsesclei committed Apr 11, 2024
1 parent 5ed2939 commit 739dd5f
Showing 7 changed files with 172 additions and 289 deletions.
9 changes: 5 additions & 4 deletions docs/config.qmd
@@ -182,10 +182,11 @@ pad_to_sequence_len:
 sample_packing:
 # Set to 'false' if getting errors during eval with sample_packing on.
 eval_sample_packing:
-# You can set these packing optimizations AFTER starting a training at least once.
-# The trainer will provide recommended values for these values.
-sample_packing_eff_est:
-total_num_tokens:
+# Increasing the following values helps with packing, but usually only a bit (<1%).
+# The number of samples each process considers during packing.
+sample_packing_group_size: 25000
+# The number of samples which can be packed into one bin.
+sample_packing_bin_size: 200
 
 # Passed through to transformers when loading the model when launched without accelerate
 # Use `sequential` when training w/ model parallelism to limit memory
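These two options tune the first-fit-decreasing (FFD) packer this commit introduces. As a rough illustration only (a minimal sketch, not the actual `MultipackBatchSampler` code; `ffd_pack` and the toy lengths are invented here), each group of up to `sample_packing_group_size` samples is sorted longest-first, and every sample goes into the first bin that still has token room, with at most `sample_packing_bin_size` samples per bin. Groups are independent of one another, which is what makes the packing parallelizable:

```python
# Illustrative first-fit-decreasing (FFD) sketch -- not the axolotl
# MultipackBatchSampler implementation. `ffd_pack` is a hypothetical helper.
def ffd_pack(lengths, batch_max_len, bin_size):
    """Pack sample indices into bins, longest samples first."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins = []       # each bin is a list of sample indices
    free = []       # tokens still available in each bin
    for idx in order:
        for b in range(len(bins)):
            # first fit: the bin needs token room and fewer than bin_size samples
            if lengths[idx] <= free[b] and len(bins[b]) < bin_size:
                bins[b].append(idx)
                free[b] -= lengths[idx]
                break
        else:
            bins.append([idx])
            free.append(batch_max_len - lengths[idx])
    return bins

# Each process would run this over its own group of up to group_size samples;
# larger groups give FFD more candidates and slightly tighter packing.
lengths = [512, 2048, 1024, 256, 1800, 700]
print(ffd_pack(lengths, batch_max_len=2048, bin_size=200))
```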
68 changes: 35 additions & 33 deletions src/axolotl/core/trainer_builder.py
@@ -119,18 +119,22 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "Use sample packing for efficient evals."},
     )
-    sample_packing_efficiency: float = field(
-        default=1.0,
-        metadata={"help": "Sample packing efficiency for calculating batch length."},
+    sample_packing_bin_size: int = field(
+        default=200,
+        metadata={
+            "help": "The max number of samples that a packed sample can contain. Increase for better packing."
+        },
+    )
+    sample_packing_group_size: int = field(
+        default=25000,
+        metadata={
+            "help": "The number of samples to group together for packing. Increase for better packing."
+        },
     )
     max_seq_length: int = field(
         default=2048,
         metadata={"help": "The maximum sequence length the model can handle"},
     )
-    sample_packing_seq_len_multiplier: int = field(
-        default=1,
-        metadata={"help": "the multiplier for the max len for packed sequences"},
-    )
     relora_steps: Optional[int] = field(
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},
@@ -340,11 +344,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             )
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
-                batch_size=batch_size,
-                drop_last=True,
-                batch_max_len=batch_max_len,
                 lengths=get_dataset_lengths(self.train_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_max_len=batch_max_len,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
             )
         return super()._get_train_sampler()

@@ -362,11 +366,11 @@ def _get_eval_sampler(
             )
             return MultipackBatchSampler(
                 SequentialSampler(eval_dataset),
-                batch_size=batch_size,
-                drop_last=True,
+                lengths=get_dataset_lengths(self.eval_dataset),
                 batch_max_len=batch_max_len,
-                lengths=get_dataset_lengths(eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
             )
         return super()._get_eval_sampler(eval_dataset)
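The branch that derives `batch_size` and `batch_max_len` for these samplers is collapsed in this diff. A hedged reconstruction of the pattern implied by the `multipack_real_batches` flag here and the pretraining call below (the helper name and signature are mine, not the repo's):

```python
# Assumed derivation of the sampler inputs -- the exact code is collapsed in
# this view, so treat this as a sketch rather than the real implementation.
def sampler_inputs(per_device_batch_size: int, max_seq_length: int,
                   real_batches: bool) -> tuple[int, int]:
    """Return (batch_size, batch_max_len) for the multipack sampler."""
    if real_batches:
        # one packed sequence per batch slot; batches keep their usual size
        return per_device_batch_size, max_seq_length
    # flash-attention path: a whole micro-batch is concatenated into a single
    # long sequence, so each "batch" is one bin of batch_size * seq_len tokens
    return 1, per_device_batch_size * max_seq_length

print(sampler_inputs(4, 2048, real_batches=False))  # -> (1, 8192)
```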

@@ -1058,11 +1062,6 @@ def build(self, total_num_steps):
         if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
 
-        if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
-
         if self.cfg.dataloader_pin_memory is not None:
             training_arguments_kwargs[
                 "dataloader_pin_memory"
@@ -1232,20 +1231,23 @@ def build(self, total_num_steps):
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
-        training_arguments_kwargs["sample_packing"] = (
-            self.cfg.sample_packing if self.cfg.sample_packing else False
-        )
-        training_arguments_kwargs["multipack_real_batches"] = (
-            self.cfg.flash_attention is not True
-        )
-        training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing
-            if self.cfg.eval_sample_packing is not False
-            else False
-        )
 
+        training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
         training_arguments_kwargs[
-            "sample_packing_seq_len_multiplier"
-        ] = self.cfg.micro_batch_size
+            "multipack_real_batches"
+        ] = not self.cfg.flash_attention
+        training_arguments_kwargs[
+            "eval_sample_packing"
+        ] = bool(self.cfg.eval_sample_packing)
+        if self.cfg.sample_packing_bin_size is not None:
+            training_arguments_kwargs[
+                "sample_packing_bin_size"
+            ] = self.cfg.sample_packing_bin_size
+        if self.cfg.sample_packing_group_size is not None:
+            training_arguments_kwargs[
+                "sample_packing_group_size"
+            ] = self.cfg.sample_packing_group_size
 
         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs[
10 changes: 7 additions & 3 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -24,16 +24,15 @@ class DeprecatedParameters(BaseModel):
     max_packed_sequence_len: Optional[int] = None
     rope_scaling: Optional[Any] = None
     noisy_embedding_alpha: Optional[float] = None
+    sample_packing_eff_est: Optional[float] = None
 
     @field_validator("max_packed_sequence_len")
-    @classmethod
     def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
         if max_packed_sequence_len:
             raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
         return max_packed_sequence_len
 
     @field_validator("rope_scaling")
-    @classmethod
     def validate_rope_scaling(cls, rope_scaling):
         if rope_scaling:
             raise DeprecationWarning(
@@ -42,12 +41,17 @@ def validate_rope_scaling(cls, rope_scaling):
         return rope_scaling
 
     @field_validator("noisy_embedding_alpha")
-    @classmethod
     def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
         if noisy_embedding_alpha:
             LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
         return noisy_embedding_alpha
 
+    @field_validator("sample_packing_eff_est")
+    def validate_sample_packing_eff_est(cls, sample_packing_eff_est):
+        if sample_packing_eff_est:
+            LOG.warning("sample_packing_eff_est is deprecated and no longer necessary")
+        return sample_packing_eff_est
+
 
 class RemappedParameters(BaseModel):
     """parameters that have been remapped to other names"""
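A self-contained sketch of the deprecation shim above (pydantic v2); the real model lives in `src/axolotl/utils/config/models/input/v0_4_1/__init__.py`, and this toy reproduces only the new validator:

```python
# Minimal pydantic v2 sketch of the sample_packing_eff_est deprecation shim.
import logging
from typing import Optional

from pydantic import BaseModel, field_validator

LOG = logging.getLogger(__name__)


class DeprecatedParameters(BaseModel):
    sample_packing_eff_est: Optional[float] = None

    @field_validator("sample_packing_eff_est")
    def validate_sample_packing_eff_est(cls, sample_packing_eff_est):
        if sample_packing_eff_est:
            LOG.warning("sample_packing_eff_est is deprecated and no longer necessary")
        return sample_packing_eff_est


DeprecatedParameters(sample_packing_eff_est=0.95)  # logs a warning, still parses
```

Note the diff also drops the `@classmethod` decorators from these validators; pydantic v2's `field_validator` binds plain functions itself, which is presumably why they could be removed without breaking validation.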
5 changes: 2 additions & 3 deletions src/axolotl/utils/data/pretraining.py
@@ -202,11 +202,10 @@ def encode_packed_pretraining(
     )
 
     sampler = MultipackBatchSampler(
-        RandomSampler(train_dataset),
+        sampler=RandomSampler(train_dataset),
+        lengths=get_dataset_lengths(train_dataset),
         batch_size=1,
-        drop_last=True,
         batch_max_len=batch_size * max_seq_length,
-        lengths=get_dataset_lengths(train_dataset),
     )
 
     chunked_data = defaultdict(list)
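For completeness, a sketch of how packed index bins get consumed downstream: PyTorch's `DataLoader` accepts any iterable of index lists via `batch_sampler=`, which is how a `MultipackBatchSampler`-style sampler plugs in. `ToyDataset` and the hard-coded bins below are illustrative stand-ins, not axolotl code:

```python
# Self-contained toy: feed FFD-style index bins to a DataLoader.
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, i):
        # a fake "tokenized sample" of the requested length
        return torch.full((self.lengths[i],), i, dtype=torch.long)


packed_batches = [[1, 3], [0, 2]]  # index bins, e.g. output of an FFD packer

loader = DataLoader(
    ToyDataset([512, 256, 300, 128]),
    batch_sampler=packed_batches,  # each index list becomes one batch
    collate_fn=lambda batch: pad_sequence(batch, batch_first=True),
)

for batch in loader:
    print(batch.shape)  # (samples in bin, longest sample in bin)
```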
[3 more changed files not loaded in this view]
