Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a new parameter fix_data #103

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ class AngleDataTokenizer:
:param dataset_format: Optional[str]. Specify dataset_format from DatasetFormats. Default None.
It will automatically detect the dataset format.
:param end_with_eos: bool. Specify whether ends with the eos token. Default False.
:param fix_data: bool. Specify whether to fix the data. Only works when prompt_template is not None. Default True.

Example::

Expand All @@ -415,14 +416,16 @@ def __init__(self,
template_placeholders: Optional[List[str]] = None,
extra_columns: Optional[List[str]] = None,
dataset_format: Optional[str] = None,
end_with_eos: bool = False):
end_with_eos: bool = False,
fix_data: bool = True):
self.tokenizer = tokenizer
self.max_length = max_length
self.prompt_template = prompt_template
self.prompt_template_tok = None
self.extra_columns = extra_columns
self.dataset_format = dataset_format
self.end_with_eos = end_with_eos
self.fix_data = fix_data
if template_placeholders is None:
template_placeholders = ['condition', 'text']
if prompt_template is not None:
Expand Down Expand Up @@ -492,7 +495,7 @@ def __call__(self, data: Dict) -> Dict:
for text_column in text_columns:
toks.append(self.tokenizer(data[text_column], max_length=self.max_length, truncation=True))

if self.prompt_template_tok is not None:
if self.prompt_template_tok is not None and self.fix_data:
for tok in toks:
if tok['input_ids'][-1] != self.prompt_template_tok['input_ids'][-1]:
logger.info(f"data data: token ids={tok['input_ids']}, prompt_token_ids={self.prompt_template_tok['input_ids']}") # NOQA
Expand Down
18 changes: 14 additions & 4 deletions angle_emb/angle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
'This prompt will be applied for all text columns.'
'If you want to specify different prompts for different text columns,'
'please handle it in the preprocessing step.')
parser.add_argument('--fix_data', type=int, default=1, choices=[0, 1],
                    help='Whether to fix data (only works when prompt_template is not None), choices [0, 1], default 1')
parser.add_argument('--filter_duplicate', type=int, default=1, choices=[0, 1],
help='Specify filter_duplicate, choices [0, 1], defaut 1')
parser.add_argument('--save_dir', type=str, default=None,
Expand Down Expand Up @@ -221,11 +223,15 @@ def main():
logger.info('Processing train...')
if args.streaming:
train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)
else:
train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)

valid_ds = None
Expand All @@ -239,7 +245,9 @@ def main():
else:
valid_ds = load_dataset(args.valid_name_or_path, num_proc=args.workers)
valid_ds = valid_ds[args.valid_split_name or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)

valid_ds_for_callback = None
Expand All @@ -258,7 +266,9 @@ def main():
valid_ds_for_callback = load_dataset(
args.valid_name_or_path_for_callback, num_proc=args.workers)
valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)

argument_kwargs = {}
Expand Down
Loading