Skip to content

Commit

Permalink
support configuring more parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Sep 29, 2024
1 parent aeea516 commit 271384f
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions angle_emb/angle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
parser.add_argument('--valid_split_name', type=str, default='train',
help='Specify huggingface datasets split name for valid set, default `train`')
parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None,
help='Specify huggingface datasets name or local file path for valid set for callback use, default None.')
help='Specify huggingface datasets name or local file path for callback valid set, default None.')
parser.add_argument('--valid_subset_name_for_callback', type=str, default=None,
help='Specify huggingface datasets subset name for valid set for callback use, default None')
parser.add_argument('--valid_split_name_for_callback', type=str, default='train',
Expand Down Expand Up @@ -92,6 +92,11 @@
parser.add_argument('--max_steps', type=int, default=-1,
help='Specify max steps, default -1 (Automatically calculated from epochs)')
parser.add_argument('--save_steps', type=int, default=100, help='Specify save_steps, default 1000')
parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch'],
help='Specify save_strategy, default steps')
parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000')
parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch'],
help='Specify evaluation_strategy, default steps')
parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32')
parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512')
parser.add_argument('--streaming', action='store_true', default=False,
Expand Down Expand Up @@ -234,7 +239,7 @@ def main():
valid_ds = valid_ds[args.valid_split_name or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
num_proc=args.workers)

valid_ds_for_callback = None
if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None:
logger.info('Validation for callback detected, processing validation...')
Expand Down Expand Up @@ -287,6 +292,9 @@ def main():
epochs=args.epochs,
learning_rate=args.learning_rate,
save_steps=args.save_steps,
save_strategy=args.save_strategy,
eval_steps=args.eval_steps,
evaluation_strategy=args.evaluation_strategy,
warmup_steps=args.warmup_steps,
logging_steps=args.logging_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
Expand Down

0 comments on commit 271384f

Please sign in to comment.