diff --git a/angle_emb/train_cli.py b/angle_emb/train_cli.py index 7fbd8b1..26e29ae 100644 --- a/angle_emb/train_cli.py +++ b/angle_emb/train_cli.py @@ -19,8 +19,6 @@ help='Specify pretrained model path to load pretrained model, default None') parser.add_argument('--pretrained_lora_path', type=str, default=None, help='Specify pretrained lora path to load lora, default None') -parser.add_argument('--bellm_class_name', type=str, default=None, - help='Specify bellm class name, default None') parser.add_argument('--train_name_or_path', type=str, required=True, help='Specify huggingface datasets name or local file path for train set, required') parser.add_argument('--train_subset_name', type=str, default=None, @@ -134,11 +132,8 @@ def main(): 'r': args.lora_r, 'lora_alpha': args.lora_alpha, 'lora_dropout': args.lora_dropout, - 'target_modules': ['fc2', 'Wqkv', 'fc1'] if 'BePhi2Model' == args.bellm_class_name else None, }, load_kbit=args.load_kbit, - bellm_class_name=args.bellm_class_name, - kbit_kwargs={'use_gradient_checkpointing': False} if 'BePhi2Model' == args.bellm_class_name else None, torch_dtype=args.torch_dtype) if args.start_bilayer_index is not None: