From 666aab4e4fe028bd9e54ce9e74e2c6bc364b9db6 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Sat, 2 Mar 2024 14:22:03 +0800 Subject: [PATCH] bugfix: validation --- angle_emb/train_cli.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/angle_emb/train_cli.py b/angle_emb/train_cli.py index 17839e6..f3c4f4e 100644 --- a/angle_emb/train_cli.py +++ b/angle_emb/train_cli.py @@ -25,8 +25,10 @@ help='Specify huggingface datasets subset name for train set') parser.add_argument('--train_split_name', type=str, default='train', help='Specify huggingface datasets split name for train set, Default `train`') -parser.add_argument('--valid_split_name', type=str, default=None, - help='Specify huggingface datasets split name for valid set, Default None') +parser.add_argument('--valid_name_or_path', type=str, default=None, + help='Specify huggingface datasets name or local file path for valid set.') +parser.add_argument('--valid_subset_name', type=str, default=None, + help='Specify huggingface datasets subset name for valid set') parser.add_argument('--prompt_template', type=str, default=None, help='Specify prompt_template like "Instruct: xxx\nInput: {text}", default None') parser.add_argument('--save_dir', type=str, default=None, @@ -150,10 +152,15 @@ def main(): train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map( AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), num_proc=args.workers) + valid_ds = None - if args.valid_split_name is not None: + if valid_ds is None and args.valid_name_or_path is not None: logger.info('Validation detected, processing validation...') - valid_ds = ds[args.valid_split_name].shuffle(args.dataset_seed).map( + if os.path.exists(args.valid_name_or_path): + valid_ds = load_dataset('json', data_files=[args.valid_name_or_path]) + else: + valid_ds = load_dataset(args.valid_name_or_path, args.valid_subset_name) + valid_ds = valid_ds[args.valid_subset_name or 'train'].map( AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), num_proc=args.workers)