feat(guardasajudge): allow customizing post-processing functions and add checks for prompt template existence
ThePyProgrammer committed Aug 2, 2024
1 parent c82c028 commit 33754b2
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions walledeval/judge/llm/guard.py
@@ -1,6 +1,7 @@
 # walledeval/judge/llm/guard.py
 
 from enum import Enum
+from typing import Callable
 
 from walledeval.llm import LLM, HF_LLM
 from walledeval.types import LLMType, Messages
@@ -37,30 +38,6 @@ def process_llm_output(self, response: str) -> LLMGuardOutput:
             return LLMGuardOutput.SAFE
         return LLMGuardOutput.UNKNOWN
 
-    # @classmethod
-    # def load(cls, model_id: str,
-    #          prompt_name: str,
-    #          model_kwargs = None,
-    #          device_map = "auto",
-    #          use_chat_template: bool = True,
-    #          **kwargs):
-    #     llm = HF_LLM(
-    #         model_id,
-    #         type=(LLMType.INSTRUCT if use_chat_template else LLMType.BASE),
-    #         model_kwargs=model_kwargs,
-    #         device_map=device_map,
-    #         **kwargs
-    #     )
-
-    #     template = PromptTemplate.from_preset(f"judges/{prompt_name}")
-
-    #     return cls(
-    #         model_id + " as a Judge",
-    #         llm=llm,
-    #         template=template,
-    #         llm_instruct=use_chat_template
-    #     )
-
     def score(self, output: LLMGuardOutput) -> bool:
         return output == "safe"
 
@@ -72,9 +49,15 @@ def __init__(self, name: str = None):
         self.template: PromptTemplate = None
         self.llm: LLM = None
         self.use_chat_template: bool = None
+        self.postprocess_func = None
 
     def load_prompt_preset(self, name: str):
-        self.template = PromptTemplate.from_preset(f"judges/{name}")
+        if PromptTemplate.exists_preset(f"judges/{name}"):
+            self.template = PromptTemplate.from_preset(f"judges/{name}")
+        elif PromptTemplate.exists_preset(name):
+            self.template = PromptTemplate.from_preset(name)
+        else:
+            raise NameError(f"Preset '{name}' not found in Prompt Database")
         return self
 
     def load_prompt_yaml(self, filename: str):
@@ -143,6 +126,10 @@ def load_huggingface_llm(self, model_id: str,
         )
         self.use_chat_template = use_chat_template
         return self
 
+    def set_postprocess_func(self, func: Callable[[str,], str]):
+        self.postprocess_func = func
+        return self
+
     def create(self) -> LLMGuardJudge:
         if self.template is None:
@@ -153,9 +140,18 @@ def create(self) -> LLMGuardJudge:
         if self.name is None:
             self.name = self.llm.name + " as a Judge"
 
-        return LLMGuardJudge(
+        judge = LLMGuardJudge(
             name = self.name,
             llm = self.llm,
             template = self.template,
             use_chat_template = self.use_chat_template
-        )
+        )
+
+        if self.postprocess_func is not None:
+            def postprocess(response: str) -> LLMGuardOutput:
+                output = self.postprocess_func(response)
+                return LLMGuardOutput(output)
+
+            judge.process_llm_output = postprocess
+
+        return judge
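
A minimal usage sketch of the two features this commit adds: the preset-existence fallback in load_prompt_preset and the custom post-processing hook installed by create(). The builder class name LLMGuardJudgeBuilder, its import path, the preset name "llamaguard", and the model id are assumptions (the class definition and preset list sit in collapsed parts of the file); only the method names visible in the diff are taken from the source.

    # Sketch only: LLMGuardJudgeBuilder, its import path, the preset name,
    # and the model id are assumed rather than confirmed by the visible diff.
    from walledeval.judge.llm.guard import LLMGuardJudgeBuilder

    def to_guard_label(response: str) -> str:
        # Custom post-processing: map a raw completion onto values the
        # LLMGuardOutput enum accepts ("safe" and "unknown" both appear in
        # the diff above). create() coerces the result via LLMGuardOutput(...).
        text = response.strip().lower()
        return "safe" if text.startswith("safe") else "unknown"

    builder = (
        LLMGuardJudgeBuilder("Llama Guard Judge")
        .load_prompt_preset("llamaguard")  # tries judges/llamaguard, then llamaguard, else NameError
        .load_huggingface_llm("meta-llama/LlamaGuard-7b")  # assumed model id
        .set_postprocess_func(to_guard_label)
    )
    judge = builder.create()

Each builder method returns self, so the calls chain; create() then swaps the judge's default process_llm_output for the wrapped custom function when one was set.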
