diff --git a/scripts/evaluator/evaluate_utils/llm_async_processor.py b/scripts/evaluator/evaluate_utils/llm_async_processor.py
index 5c7e35e..9a85ebd 100644
--- a/scripts/evaluator/evaluate_utils/llm_async_processor.py
+++ b/scripts/evaluator/evaluate_utils/llm_async_processor.py
@@ -57,6 +57,8 @@ def _invoke(self, messages: Messages, **kwargs) -> Tuple[AIMessage, float]:
                     print(f"Retrying request due to empty content. Retry attempt {i+1} of {n}.")
         elif self.api_type == "amazon_bedrock":
             response = self.llm.invoke(messages, **kwargs)
+        elif self.api_type == "None":
+            response = self.llm.invoke(messages, **kwargs)
         else:
             raise NotImplementedError(
                 "Synchronous invoke is only implemented for Google API"
@@ -67,7 +69,7 @@ def _invoke(self, messages: Messages, **kwargs) -> Tuple[AIMessage, float]:
     @backoff.on_exception(backoff.expo, Exception, max_tries=MAX_TRIES)
     async def _ainvoke(self, messages: Messages, **kwargs) -> Tuple[AIMessage, float]:
         await asyncio.sleep(self.inference_interval)
-        if self.api_type in ["google", "amazon_bedrock"]:
+        if self.api_type in ["google", "amazon_bedrock", "None"]:
             return await asyncio.to_thread(self._invoke, messages, **kwargs)
         else:
             if self.model_name == "tokyotech-llm/Swallow-7b-instruct-v0.1":
diff --git a/scripts/evaluator/mtbench.py b/scripts/evaluator/mtbench.py
index b26b073..7e7f96c 100644
--- a/scripts/evaluator/mtbench.py
+++ b/scripts/evaluator/mtbench.py
@@ -72,6 +72,8 @@ def evaluate():
             answer_file=answer_file,
             num_worker=cfg.mtbench.parallel,
         )
+    elif cfg.api == "None":
+        pass
 
     # 2. evaluate outputs
     questions = load_questions(question_file, None, None)
diff --git a/scripts/llm_inference_adapter.py b/scripts/llm_inference_adapter.py
index 9dee7a7..c5c2629 100644
--- a/scripts/llm_inference_adapter.py
+++ b/scripts/llm_inference_adapter.py
@@ -16,9 +16,12 @@
 # from langchain_cohere import Cohere
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline  # PreTrainedTokenizerBase,
+from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+import torch
 
 
 @dataclass
-class BedrockResponse:
+class TextHolder:
     content: str
 
 
@@ -61,11 +64,78 @@ def _invoke(
     def invoke(self, messages, max_tokens: int):
         response = self._invoke(messages=messages, max_tokens=max_tokens)
         if response["content"]:
-            content = content = response["content"][0]["text"]
+            content = response["content"][0]["text"]
         else:
             content = ""
-        return BedrockResponse(content=content)
+        return TextHolder(content=content)
+
+
+class ChatTransformers:
+    def __init__(self, cfg) -> None:
+        # self.cfg = cfg
+        self.model_id = cfg.model.pretrained_model_name_or_path
+        self.ignore_keys = ["max_tokens"]
+        self.generator_config = {
+            k: v for k, v in cfg.generator.items() if not k in self.ignore_keys
+        }
+        self.model_params = dict(
+            trust_remote_code=cfg.model.trust_remote_code,
+            device_map=cfg.model.device_map,
+            load_in_8bit=cfg.model.load_in_8bit,
+            load_in_4bit=cfg.model.load_in_4bit,
+            torch_dtype=torch.float16,
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            **self.model_params,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_id,
+            trust_remote_code=self.model_params["trust_remote_code"],
+        )
+
+    def _invoke(
+        self,
+        messages: list[dict[str, str]],
+        max_tokens: int,
+    ):
+        self.model.eval()
+
+        pipe = pipeline(
+            "text-generation",
+            model=self.model,
+            tokenizer=self.tokenizer,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            max_new_tokens=max_tokens,
+            device_map=self.model_params["device_map"],
+        )
+
+        print("="*50)
+        print("messages")
+        print("="*50)
+        print(messages)
+
+        prompt = " ".join([msg["content"] for msg in messages])
+        print("="*50)
+        print("prompt")
+        print("="*50)
+        print(prompt)
+
+        generated_text = pipe(prompt, **self.generator_config)[0]["generated_text"]
+        print("="*50)
+        print("generated_text")
+        print("="*50)
+        print(generated_text)
+
+        return generated_text
+
+    def invoke(self, messages, max_tokens: int):
+        content = self._invoke(messages=messages, max_tokens=max_tokens)
+        # content = response["content"][0]["text"]
+
+        return TextHolder(content=content)
 
 
 def get_llm_inference_engine():
@@ -135,7 +205,7 @@ def get_llm_inference_engine():
             api_key=os.environ["ANTHROPIC_API_KEY"],
             **cfg.generator,
         )
-    
+
     elif api_type == "upstage":
         # LangChainのOpenAIインテグレーションを使用
         llm = ChatOpenAI(
@@ -145,6 +215,10 @@
             **cfg.generator,
         )
 
+    elif api_type == "None":
+        llm = ChatTransformers(cfg=cfg)
+
+
     # elif api_type == "azure-openai":
    #     llm = AzureChatOpenAI(
     #         api_key=os.environ["OPENAI_API_KEY"],
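
Not part of the patch: a minimal usage sketch of the new api == "None" (local Transformers) path, assuming an OmegaConf-style config exposing the fields ChatTransformers reads (model.pretrained_model_name_or_path, model.trust_remote_code, model.device_map, model.load_in_8bit, model.load_in_4bit, generator.*). The model id, generator values, and import path below are illustrative only.

    from omegaconf import OmegaConf
    from llm_inference_adapter import ChatTransformers  # import path assumed; adjust to how scripts/ is run

    # Illustrative config; field names mirror those accessed in ChatTransformers.__init__
    cfg = OmegaConf.create({
        "api": "None",
        "model": {
            "pretrained_model_name_or_path": "example-org/example-7b-instruct",  # hypothetical model id
            "trust_remote_code": False,
            "device_map": "auto",
            "load_in_8bit": False,
            "load_in_4bit": False,
        },
        "generator": {
            "do_sample": False,   # forwarded to the transformers pipeline call
            "max_tokens": 256,    # filtered out by ignore_keys; max_new_tokens comes from invoke()
        },
    })

    llm = ChatTransformers(cfg=cfg)
    response = llm.invoke(
        [{"role": "user", "content": "Hello"}],
        max_tokens=256,
    )
    print(response.content)  # TextHolder.content holds the generated text

Note that _invoke joins message contents with spaces rather than applying the tokenizer's chat template, so instruction-tuned models receive a plain concatenated prompt, and the pipeline's generated_text includes that prompt as a prefix by default.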