Skip to content

Commit

Permalink
added chat template
Browse files — browse the repository at this point in the history
  • Loading branch information
Bhardwaj-Rishabh committed May 15, 2024
1 parent 5d24275 commit 8d1a11a
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions walledeval/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,41 @@ class HF_LLM(LLM):
def __init__(self, id: str, system_prompt: str = "", **kwargs):
    """Build an HF_LLM backed by a HuggingFace text-generation pipeline.

    Args:
        id: HuggingFace model identifier, passed both to the base LLM
            and as the pipeline's ``model``.
        system_prompt: System prompt stored on the base LLM and used
            when generating.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``transformers.pipeline``.
    """
    super().__init__(id, system_prompt)
    # trust_remote_code=True lets models that ship custom code load;
    # any caller-supplied pipeline options are forwarded unchanged.
    self.pipeline = pipeline(
        task="text-generation",
        model=id,
        trust_remote_code=True,
        **kwargs,
    )

def generate(self, text: str) -> str:
    """Generate a reply to *text* using the model's chat template.

    The user input is paired with the configured system prompt,
    rendered through the tokenizer's chat template, and fed to the
    text-generation pipeline.

    Args:
        text: The user message to respond to.

    Returns:
        The generated completion with the prompt prefix and
        surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": self.system_prompt},
        # BUG FIX: use the caller-supplied text. This was previously the
        # hard-coded string "Who are you?" (left over from the Llama-3
        # chat-template example), so every call ignored its input.
        {"role": "user", "content": text},
    ]

    # Render to a single prompt string (not token ids) and append the
    # assistant generation prompt so the model continues as the assistant.
    prompt = self.pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Stop on the model's EOS token or the end-of-turn token.
    # NOTE(review): "<|eot_id|>" is Llama-3-specific; for other model
    # families convert_tokens_to_ids may return None/unk — confirm.
    terminators = [
        self.pipeline.tokenizer.eos_token_id,
        self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    # The pipeline echoes the prompt in "generated_text", so slice it
    # off before stripping whitespace.
    outputs = self.pipeline(
        prompt,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )[0]["generated_text"][len(prompt):].strip()

    return outputs

class Claude(LLM):
def __init__(self, api_key: str, system_prompt: str = ""):
super().__init__("Claude 3 Opus", system_prompt)
Expand Down

0 comments on commit 8d1a11a

Please sign in to comment.