Skip to content

Commit

Permalink
Added custom model example for google translate.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelNiklaus committed Dec 12, 2024
1 parent a7e1fe5 commit 24b8bd3
Showing 1 changed file with 151 additions and 0 deletions.
151 changes: 151 additions & 0 deletions examples/custom_models/google_translate_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
def __init__(self, config, env_config) -> None:
self.model = config.model
self.model_definition_file_path = config.model_definition_file_path

self.model_info = ModelInfo(
model_name=config.model,
model_sha="",
model_dtype=None,
model_size="",
)

self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility

import httpcore

# Needed to fix some googletrans bug
# https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618
setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy")
from googletrans import Translator

self.translator = Translator()

def greedy_until(
self,
requests: list[GreedyUntilRequest],
override_bs: Optional[int] = None,
) -> list[GenerativeResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
Returns:
list[GenerativeResponse]: list of generated responses.
"""
for request in requests:
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
results = []

for _ in tqdm(
dataset.splits_start_end_iterator(),
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=False, # self.disable_tqdm,
):
for r in tqdm(dataset, desc="Batch", position=1, disable=False):
context = r.context.replace("French phrase: ", "")
# TODO: Get src and dest from request
translation = self.translator.translate(context, src="fr", dest="de")

result = translation.text
cur_response = GenerativeResponse(
result=result,
logits=None,
generated_tokens=[],
input_tokens=[],
)
results.append(cur_response)

return dataset.get_original_order(results)

@property
def tokenizer(self):
return self._tokenizer

def tok_encode(self, text: str):
return self.tokenizer.encode(text)

@property
def add_special_tokens(self) -> bool:
return False

@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
return 4096

def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError

def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
raise NotImplementedError

def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError

0 comments on commit 24b8bd3

Please sign in to comment.