Skip to content

Commit

Permalink
bugfix: validation
Browse files (browse the repository at this point in the history)
  • Loading branch information
SeanLee97 committed Mar 2, 2024
1 parent 2f0303f commit 666aab4
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions in angle_emb/train_cli.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,8 +25,10 @@
help='Specify huggingface datasets subset name for train set')
parser.add_argument('--train_split_name', type=str, default='train',
help='Specify huggingface datasets split name for train set, Default `train`')
parser.add_argument('--valid_split_name', type=str, default=None,
help='Specify huggingface datasets split name for valid set, Default None')
parser.add_argument('--valid_name_or_path', type=str, default=None,
help='Specify huggingface datasets name or local file path for valid set.')
parser.add_argument('--valid_subset_name', type=str, default=None,
help='Specify huggingface datasets subset name for valid set')
parser.add_argument('--prompt_template', type=str, default=None,
help='Specify prompt_template like "Instruct: xxx\nInput: {text}", default None')
parser.add_argument('--save_dir', type=str, default=None,
Expand Down Expand Up @@ -150,10 +152,15 @@ def main():
train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map(
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template), num_proc=args.workers)

valid_ds = None
if args.valid_split_name is not None:
if valid_ds is None and args.valid_name_or_path is not None:
logger.info('Validation detected, processing validation...')
valid_ds = ds[args.valid_split_name].shuffle(args.dataset_seed).map(
if os.path.exists(args.valid_name_or_path):
valid_ds = load_dataset('json', data_files=[args.valid_name_or_path])
else:
valid_ds = load_dataset(args.valid_name_or_path, args.valid_subset_name)
valid_ds = valid_ds[args.valid_subset_name or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template), num_proc=args.workers)

Expand Down

0 comments on commit 666aab4

Please sign in to comment.