diff --git a/README.md b/README.md index c7c8f5c..4e6e254 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf', pooling_strategy='last', is_llm=True, apply_billm=True, - billm_model_class='LlamaForCausalMask', + billm_model_class='LlamaForCausalLM', torch_dtype='float16').cuda() doc_vecs = angle.encode([ diff --git a/docs/notes/quickstart.rst b/docs/notes/quickstart.rst index 348d9fe..776eda2 100644 --- a/docs/notes/quickstart.rst +++ b/docs/notes/quickstart.rst @@ -103,7 +103,7 @@ Specify `apply_billm` and `billm_model_class` to load and infer billm models pooling_strategy='last', is_llm=True, apply_billm=True, - billm_model_class='LlamaForCausalMask', + billm_model_class='LlamaForCausalLM', torch_dtype='float16').cuda() doc_vecs = angle.encode([