Skip to content

Commit

Permalink
drop empty token from beginning if tokenizer has no bos_token (in the…
Browse files Browse the repository at this point in the history
… case of qwen) (#1490)
  • Loading branch information
winglian authored Apr 7, 2024
1 parent bda48f0 commit 934fc85
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
Expand Down Expand Up @@ -802,6 +803,15 @@ def push_to_hub(self, *args, **kwargs) -> str:

return super().push_to_hub(*args, **kwargs)

def tokenize_row(
    self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
    """
    Tokenize a single row via the parent class, then drop an empty leading token.

    When the tokenizer defines no BOS token (``bos_token_id is None``, e.g. Qwen)
    and the first prompt token id comes back as ``None``, the first element of
    every sequence in the result is removed so that no ``None`` placeholder is
    left at the start of the tokenized row.

    Args:
        feature: The raw row to tokenize; passed through unchanged to the
            parent implementation.
        model: Optional model forwarded to the parent implementation.

    Returns:
        The mapping produced by the parent ``tokenize_row``, with the leading
        element stripped from every value when the condition above holds.
    """
    res = super().tokenize_row(feature, model=model)

    # Tokenizers that do define a BOS token need no post-processing.
    if self.tokenizer.bos_token_id is not None:
        return res

    prompt_ids = res["prompt_input_ids"]
    # Check the length first so an empty prompt does not raise IndexError.
    if len(prompt_ids) > 0 and prompt_ids[0] is None:
        # Only existing keys are reassigned, so mutating while iterating is safe.
        # NOTE(review): assumes every value in `res` is a sequence carrying the
        # same leading placeholder as `prompt_input_ids` -- confirm against the
        # parent implementation.
        for key in res:
            res[key] = res[key][1:]
    return res


class TrainerBuilderBase(abc.ABC):
"""
Expand Down

0 comments on commit 934fc85

Please sign in to comment.