Skip to content

Commit

Permalink
Merge pull request #74 from SeanLee97/feature/tokenizer-name-or-path
Browse files Browse the repository at this point in the history
Support specifying tokenizer_name_or_path
  • Loading branch information
SeanLee97 authored May 22, 2024
2 parents 0c4f274 + 241f352 commit e6a0f0b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ class AnglE:
AnglE. Everything is here👋
:param model_name_or_path: str, model name or path.
:param tokenizer_name_or_path: Optional[str]. Default None. When set to None, the tokenizer is loaded from `model_name_or_path`.
:param max_length: int. Default 512
:param model_kwargs: Optional[Dict]. kwargs for model.
:param lora_config_kwargs: Optional[Dict]. kwargs for peft lora_config.
Expand All @@ -1101,6 +1102,7 @@ class AnglE:

def __init__(self,
model_name_or_path: str,
tokenizer_name_or_path: Optional[str] = None,
max_length: int = 512,
model_kwargs: Optional[Dict] = None,
lora_config_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -1173,7 +1175,8 @@ def __init__(self,
if train_mode:
logger.info(f'lora_config={lora_config}')

self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path or model_name_or_path, trust_remote_code=True)
if tokenizer_padding_side is not None and self.tokenizer.padding_side != tokenizer_padding_side:
self.tokenizer.padding_side = tokenizer_padding_side
if self.is_llm and self.tokenizer.pad_token_id is None:
Expand Down
3 changes: 3 additions & 0 deletions angle_emb/angle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, required=True,
help='Specify model name or path to set transformer backbone, required')
parser.add_argument('--tokenizer_name_or_path', type=str, default=None,
help='Specify tokenizer name or path. Default None, will use model_name_or_path')
parser.add_argument('--pretrained_model_path', type=str, default=None,
help='Specify pretrained model path to load pretrained model, default None')
parser.add_argument('--pretrained_lora_path', type=str, default=None,
Expand Down Expand Up @@ -159,6 +161,7 @@

def main():
model = AnglE(args.model_name_or_path,
tokenizer_name_or_path=args.tokenizer_name_or_path,
max_length=args.maxlen,
pretrained_model_path=args.pretrained_model_path,
pretrained_lora_path=args.pretrained_lora_path,
Expand Down

0 comments on commit e6a0f0b

Please sign in to comment.