Skip to content

Commit

Permalink
fix model specific args (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
bluecoconut authored May 22, 2023
1 parent 97951f4 commit 39d96c7
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions lambdaprompt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,28 +219,32 @@ class Parameters(Backend.Parameters):
repetition_penalty: float = 1.1
stop: Optional[Union[str, List[str]]]

def __init__(self, model_name, torch_dtype=None, trust_remote_code=True, use_auth_token=None, **param_override):
def __init__(self, model_name, torch_dtype=None, trust_remote_code=True, use_auth_token=None, use_device_map=True, load_config=True, **param_override):
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
torch_dtype = torch_dtype or torch.bfloat16
super().__init__(**param_override)
config = AutoConfig.from_pretrained(
model_name,
trust_remote_code=True
)
init_kwargs = {
"torch_dtype": torch_dtype,
"trust_remote_code": trust_remote_code,
"use_auth_token": use_auth_token,
}
if load_config:
init_kwargs['config'] = AutoConfig.from_pretrained(
model_name,
trust_remote_code=True
)
if use_device_map:
init_kwargs['device_map'] = "auto"
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
config=config,
device_map="auto"
**init_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
device_map="auto"
**({"device_map":"auto"} if use_device_map else {})
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
Expand Down Expand Up @@ -290,12 +294,12 @@ def stop_on_any(input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs)

class MPT7BInstructCompletion(HuggingFaceBackend):
def __init__(self, **kwargs):
super().__init__("mosaicml/mpt-7b-instruct", **kwargs)
super().__init__("mosaicml/mpt-7b-instruct", use_device_map=False, **kwargs)


class StarCoderCompletion(HuggingFaceBackend):
def __init__(self, hf_access_token=None, **kwargs):
hf_access_token = hf_access_token or os.environ.get("HF_ACCESS_TOKEN")
if not hf_access_token:
raise Exception("No HuggingFace access token found (envvar HF_ACCESS_TOKEN))")
super().__init__("bigcode/starcoder", use_auth_token=hf_access_token, **kwargs)
super().__init__("bigcode/starcoder", use_auth_token=hf_access_token, load_config=False, **kwargs)

0 comments on commit 39d96c7

Please sign in to comment.