Skip to content

Commit

Permalink
various bugfixes (#856)
Browse files Browse the repository at this point in the history
* various bugfixes

use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing

* add fix for when eval size is estimated to be too small

* should be len 1 for dataset length

* add catchall kwargs
  • Loading branch information
winglian authored Nov 15, 2023
1 parent 501b4d1 commit 1470650
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 37 deletions.
2 changes: 1 addition & 1 deletion examples/llama-2/tiny-llama.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
Expand Down
8 changes: 4 additions & 4 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,16 +543,16 @@ def build(self, total_num_steps):
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor

if self.cfg.eval_steps:
if self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
elif self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def sdp_attention_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def xformers_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _build_result(self, instruction, input_text, output):
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_prompt
if self.system_no_input_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
if output:
Expand Down
21 changes: 12 additions & 9 deletions src/axolotl/utils/samplers/multipack.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,16 @@ def _len_est(self):
)

# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return (
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.batch_max_len
)
- 1
return min(
1,
(
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.batch_max_len
)
- 1
),
)
45 changes: 23 additions & 22 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,31 +142,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):


def calculate_total_num_steps(cfg, train_dataset):
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
cfg.total_num_tokens = total_num_tokens

if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
cfg.total_supervised_tokens = total_supervised_tokens

if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
cfg.total_num_tokens = total_num_tokens

if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
cfg.total_supervised_tokens = total_supervised_tokens

if cfg.sample_packing_eff_est:
total_num_steps = (
Expand Down

0 comments on commit 1470650

Please sign in to comment.