diff --git a/angle_emb/angle.py b/angle_emb/angle.py index e63a29c..6baae43 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -1431,7 +1431,7 @@ def fit(self, warmup_steps: int = 1000, logging_steps: int = 10, eval_steps: int = 1000, - evaluation_strategy: str = 'steps', + eval_strategy: str = 'steps', save_steps: int = 100, save_strategy: str = 'steps', save_total_limit: int = 1, @@ -1462,7 +1462,7 @@ def fit(self, :param warmup_steps: int. Default 1000. :param logging_steps: int. Default 10. :param eval_steps: int. Default 1000. - :param evaluation_strategy: str. Default 'steps'. + :param eval_strategy: str. Default 'steps'. :param save_steps: int. Default 100. :param save_strategy: str. Default steps. :param save_total_limit: int. Default 10. @@ -1549,7 +1549,7 @@ def fit(self, logging_steps=logging_steps, save_steps=save_steps, save_strategy=save_strategy, - evaluation_strategy=evaluation_strategy if valid_ds is not None else 'no', + eval_strategy=eval_strategy if valid_ds is not None else 'no', eval_steps=eval_steps, output_dir=output_dir, save_total_limit=save_total_limit, diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py index 9c54f47..c52f7a6 100644 --- a/angle_emb/angle_trainer.py +++ b/angle_emb/angle_trainer.py @@ -99,8 +99,8 @@ parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch', 'no'], help='Specify save_strategy, default steps') parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000') -parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch', 'no'], - help='Specify evaluation_strategy, default steps') +parser.add_argument('--eval_strategy', type=str, default='steps', choices=['steps', 'epoch', 'no'], + help='Specify eval_strategy, default steps') parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32') parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512') parser.add_argument('--streaming', action='store_true', default=False, @@ -307,7 +307,7 @@ def main(): save_strategy=args.save_strategy, save_total_limit=args.save_total_limit, eval_steps=args.eval_steps, - evaluation_strategy=args.evaluation_strategy, + eval_strategy=args.eval_strategy, warmup_steps=args.warmup_steps, logging_steps=args.logging_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,