diff --git a/angle_emb/angle.py b/angle_emb/angle.py
index 6053a7c..f674681 100644
--- a/angle_emb/angle.py
+++ b/angle_emb/angle.py
@@ -1077,6 +1077,7 @@ class AnglE:
     AnglE. Everything is here👋

     :param model_name_or_path: str, model name or path.
+    :param tokenizer_name_or_path: Optional[str]. Default None. When set to None, the tokenizer is loaded from `model_name_or_path`.
     :param max_length: int. Default 512
     :param model_kwargs: Optional[Dict]. kwargs for model.
     :param lora_config_kwargs: Optional[Dict]. kwargs for peft lora_config.
@@ -1101,6 +1102,7 @@

     def __init__(self,
                  model_name_or_path: str,
+                 tokenizer_name_or_path: Optional[str] = None,
                  max_length: int = 512,
                  model_kwargs: Optional[Dict] = None,
                  lora_config_kwargs: Optional[Dict] = None,
@@ -1173,7 +1175,8 @@ def __init__(self,
        if train_mode:
            logger.info(f'lora_config={lora_config}')

-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path or model_name_or_path, trust_remote_code=True)
         if tokenizer_padding_side is not None and self.tokenizer.padding_side != tokenizer_padding_side:
             self.tokenizer.padding_side = tokenizer_padding_side
         if self.is_llm and self.tokenizer.pad_token_id is None:
diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py
index 7cabb52..5b8a4d8 100644
--- a/angle_emb/angle_trainer.py
+++ b/angle_emb/angle_trainer.py
@@ -15,6 +15,8 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--model_name_or_path', type=str, required=True,
                     help='Specify model name or path to set transformer backbone, required')
+parser.add_argument('--tokenizer_name_or_path', type=str, default=None,
+                    help='Specify tokenizer name or path. Default None, which falls back to model_name_or_path')
 parser.add_argument('--pretrained_model_path', type=str, default=None,
                     help='Specify pretrained model path to load pretrained model, default None')
 parser.add_argument('--pretrained_lora_path', type=str, default=None,
@@ -159,6 +161,7 @@ def main():
     model = AnglE(args.model_name_or_path,
+                  tokenizer_name_or_path=args.tokenizer_name_or_path,
                   max_length=args.maxlen,
                   pretrained_model_path=args.pretrained_model_path,
                   pretrained_lora_path=args.pretrained_lora_path,
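
A minimal usage sketch of the new parameter, matching the `AnglE.__init__` signature in the diff above. The checkpoint names are hypothetical placeholders, not real models:

```python
from angle_emb import AnglE

# Load the backbone from one checkpoint and the tokenizer from another,
# e.g. when a converted or fine-tuned model ships without tokenizer files.
# Both checkpoint names here are hypothetical placeholders.
angle = AnglE('your-org/converted-backbone',
              tokenizer_name_or_path='your-org/base-model',
              max_length=512)

# Passing None (or omitting the argument) keeps the previous behavior:
# the tokenizer is loaded from `model_name_or_path`.
angle = AnglE('your-org/base-model', max_length=512)
```

Note that the fallback is written as `tokenizer_name_or_path or model_name_or_path`, so an empty string also falls back to the model path, consistent with the documented "Default None" behavior.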
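
The trainer wires the new flag straight through to `AnglE`. A sketch of an invocation, assuming the script is run directly (names are hypothetical, and the trainer's other required flags, such as the training data, are omitted for brevity):

```bash
python angle_emb/angle_trainer.py \
    --model_name_or_path your-org/converted-backbone \
    --tokenizer_name_or_path your-org/base-model
```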