Skip to content

Commit

Permalink
Generate baseline with any given dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 14, 2024
1 parent 4a536c7 commit bd7da22
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions poc/dsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,13 +786,13 @@ async def run(
print("Main: All servers are closed")


def generate(model_name: str, prompt: str, max_new_tokens: int) -> str:
def generate(model_name: str, dtype: torch.dtype, prompt: str, max_new_tokens: int) -> str:
setup_hf_cache()
tokenizer = AutoTokenizer.from_pretrained(model_name)
tok_ids = tokenizer.encode(prompt, return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
torch_dtype=dtype,
cache_dir=os.environ["TRANSFORMERS_CACHE"]
)
model.eval()
Expand Down Expand Up @@ -822,12 +822,12 @@ def generate(model_name: str, prompt: str, max_new_tokens: int) -> str:
if __name__ == "__main__":
print("Script started")

verifier_name: str = "lmsys/vicuna-13b-v1.3"
verifier_name: str = "lmsys/vicuna-7b-v1.3"
drafter_name: str = "double7/vicuna-68m"
verifier_dtype: torch.dtype = torch.float32
drafter_dtype: torch.dtype = torch.float16
vocab_size: int = 32000
lookahead: int = 1
lookahead: int = 3
max_new_tokens: int = 100
prompt: str = """Below is an instruction that describes a
task, paired with an input that provides
Expand All @@ -852,6 +852,6 @@ def generate(model_name: str, prompt: str, max_new_tokens: int) -> str:
max_new_tokens=max_new_tokens,
)
)
# print(generate(model_name=verifier_name, prompt=prompt, max_new_tokens=max_new_tokens))
# print(generate(model_name=verifier_name, dtype=verifier_dtype, prompt=prompt, max_new_tokens=max_new_tokens))

print("Script completed")

0 comments on commit bd7da22

Please sign in to comment.