diff --git a/walledeval/judge/llm/guard.py b/walledeval/judge/llm/guard.py
index fbdc003..b13af36 100644
--- a/walledeval/judge/llm/guard.py
+++ b/walledeval/judge/llm/guard.py
@@ -1,6 +1,7 @@
 # walledeval/judge/llm/guard.py
 
 from enum import Enum
+from typing import Callable
 
 from walledeval.llm import LLM, HF_LLM
 from walledeval.types import LLMType, Messages
@@ -37,30 +38,6 @@ def process_llm_output(self, response: str) -> LLMGuardOutput:
             return LLMGuardOutput.SAFE
         return LLMGuardOutput.UNKNOWN
 
-    # @classmethod
-    # def load(cls, model_id: str,
-    #          prompt_name: str,
-    #          model_kwargs = None,
-    #          device_map = "auto",
-    #          use_chat_template: bool = True,
-    #          **kwargs):
-    #     llm = HF_LLM(
-    #         model_id,
-    #         type=(LLMType.INSTRUCT if use_chat_template else LLMType.BASE),
-    #         model_kwargs=model_kwargs,
-    #         device_map=device_map,
-    #         **kwargs
-    #     )
-
-    #     template = PromptTemplate.from_preset(f"judges/{prompt_name}")
-
-    #     return cls(
-    #         model_id + " as a Judge",
-    #         llm=llm,
-    #         template=template,
-    #         llm_instruct=use_chat_template
-    #     )
-
     def score(self, output: LLMGuardOutput) -> bool:
         return output == "safe"
 
@@ -72,9 +49,15 @@ def __init__(self, name: str = None):
         self.template: PromptTemplate = None
         self.llm: LLM = None
         self.use_chat_template: bool = None
+        self.postprocess_func = None
 
     def load_prompt_preset(self, name: str):
-        self.template = PromptTemplate.from_preset(f"judges/{name}")
+        if PromptTemplate.exists_preset(f"judges/{name}"):
+            self.template = PromptTemplate.from_preset(f"judges/{name}")
+        elif PromptTemplate.exists_preset(name):
+            self.template = PromptTemplate.from_preset(name)
+        else:
+            raise NameError(f"Preset '{name}' not found in Prompt Database")
         return self
 
     def load_prompt_yaml(self, filename: str):
@@ -143,6 +126,10 @@ def load_huggingface_llm(self, model_id: str,
         )
         self.use_chat_template = use_chat_template
         return self
+
+    def set_postprocess_func(self, func: Callable[[str], str]):
+        self.postprocess_func = func
+        return self
 
     def create(self) -> LLMGuardJudge:
         if self.template is None:
@@ -153,9 +140,18 @@ def create(self) -> LLMGuardJudge:
         if self.name is None:
             self.name = self.llm.name + " as a Judge"
 
-        return LLMGuardJudge(
+        judge = LLMGuardJudge(
             name = self.name,
             llm = self.llm,
             template = self.template,
             use_chat_template = self.use_chat_template
-        )
\ No newline at end of file
+        )
+
+        if self.postprocess_func is not None:
+            def postprocess(response: str) -> LLMGuardOutput:
+                output = self.postprocess_func(response)
+                return LLMGuardOutput(output)
+
+            judge.process_llm_output = postprocess
+
+        return judge
\ No newline at end of file
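
Usage note for review: a minimal sketch of the new set_postprocess_func hook. The builder's class name is not visible in these hunks, so "LLMGuardBuilder" is an assumption, as are the preset name and model id; the enum values "safe"/"unsafe"/"unknown" are inferred from score() comparing against "safe".

    from walledeval.judge.llm.guard import LLMGuardBuilder  # assumed export name

    def extract_verdict(response: str) -> str:
        # Custom postprocessing: map a free-form judge response onto a value
        # accepted by LLMGuardOutput(...) in create(). Check "unsafe" before
        # "safe" since the former contains the latter as a substring.
        text = response.strip().lower()
        if "unsafe" in text:
            return "unsafe"
        if "safe" in text:
            return "safe"
        return "unknown"

    judge = (
        LLMGuardBuilder()
        .load_prompt_preset("llamaguard")                   # "judges/llamaguard" is tried first, then the bare name
        .load_huggingface_llm("meta-llama/LlamaGuard-7b")   # illustrative judge model
        .set_postprocess_func(extract_verdict)              # wrapped and bound as process_llm_output by create()
        .create()
    )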