diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py index 1608ac0..8936010 100644 --- a/angle_emb/angle_trainer.py +++ b/angle_emb/angle_trainer.py @@ -36,7 +36,7 @@ parser.add_argument('--valid_split_name', type=str, default='train', help='Specify huggingface datasets split name for valid set, default `train`') parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None, - help='Specify huggingface datasets name or local file path for valid set for callback use, default None.') + help='Specify huggingface datasets name or local file path for callback valid set, default None.') parser.add_argument('--valid_subset_name_for_callback', type=str, default=None, help='Specify huggingface datasets subset name for valid set for callback use, default None') parser.add_argument('--valid_split_name_for_callback', type=str, default='train', @@ -92,6 +92,11 @@ parser.add_argument('--max_steps', type=int, default=-1, help='Specify max steps, default -1 (Automatically calculated from epochs)') parser.add_argument('--save_steps', type=int, default=100, help='Specify save_steps, default 1000') +parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch'], + help='Specify save_strategy, default steps') +parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000') +parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch'], + help='Specify evaluation_strategy, default steps') parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32') parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512') parser.add_argument('--streaming', action='store_true', default=False, @@ -234,7 +239,7 @@ def main(): valid_ds = valid_ds[args.valid_split_name or 'train'].map( AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), num_proc=args.workers) - + valid_ds_for_callback = None if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None: logger.info('Validation for callback detected, processing validation...') @@ -287,6 +292,9 @@ def main(): epochs=args.epochs, learning_rate=args.learning_rate, save_steps=args.save_steps, + save_strategy=args.save_strategy, + eval_steps=args.eval_steps, + evaluation_strategy=args.evaluation_strategy, warmup_steps=args.warmup_steps, logging_steps=args.logging_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,