Made local mt model example more general to support madlad400 as well.
JoelNiklaus committed Dec 17, 2024
1 parent 1a37f71 commit 2f27645
Showing 1 changed file with 75 additions and 20 deletions.
@@ -25,7 +25,12 @@

import pycountry
from tqdm import tqdm
-from transformers import AutoProcessor, SeamlessM4Tv2ForTextToText
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoProcessor,
+    AutoTokenizer,
+    SeamlessM4Tv2ForTextToText,
+)

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo, TokenSequence
@@ -45,19 +50,61 @@
logger = logging.getLogger(__name__)


-class Seamless4MTClient(LightevalModel):
+class LocalMTClient(LightevalModel):
+    """
+    A custom model implementation for local machine translation models, specifically supporting:
+
+    - SeamlessM4T v2 models from Meta
+    - MADLAD-400 models from Google
+
+    This class provides a unified interface for both model families while handling their different
+    tokenization and generation approaches transparently.
+
+    Args:
+        config (CustomModelConfig): Configuration containing:
+            - model (str): Model identifier/path (e.g. "facebook/seamless-m4t-v2-large" or "google/madlad400-7b-mt")
+            - model_definition_file_path (str): Path to this model definition file
+        env_config: Environment configuration (unused)
+
+    The model automatically detects whether to load SeamlessM4T or MADLAD based on the model identifier
+    and initializes the appropriate tokenizer and model.
+
+    Translation tasks should specify the source and target languages in the format:
+        "{task_name}|{...}:{src}-{tgt}"
+    where src and tgt are ISO language codes (2- or 3-letter codes supported).
+
+    Example:
+        ```
+        lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10
+        ```
+
+    Note:
+        - SeamlessM4T models use the AutoProcessor for tokenization
+        - MADLAD models use the standard AutoTokenizer
+        - Language codes are automatically converted to 3-letter ISO codes for SeamlessM4T
+    """

    def __init__(self, config, env_config) -> None:
        self.model = config.model
        self.model_definition_file_path = config.model_definition_file_path
+        self.batch_size = 32

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )
-        self._tokenizer = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
-        self._model = SeamlessM4Tv2ForTextToText.from_pretrained("facebook/seamless-m4t-v2-large")

+        # Update model initialization to handle both models
+        if "seamless-m4t" in config.model:
+            self._tokenizer = AutoProcessor.from_pretrained(config.model)
+            self._model = SeamlessM4Tv2ForTextToText.from_pretrained(config.model)
+            self.model_type = "seamless-4mt"
+        elif "madlad400" in config.model:
+            self._tokenizer = AutoTokenizer.from_pretrained(config.model)
+            self._model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
+            self.model_type = "madlad400"
+        else:
+            raise ValueError(f"Unsupported model: {config.model}")

    def _convert_to_iso3(self, lang_code: str) -> str:
        """Convert 2-letter ISO code to 3-letter ISO code."""
@@ -86,24 +133,27 @@ def greedy_until(

        def get_langs(task_name: str) -> tuple[str, str]:
            src, tgt = task_name.split("|")[1].split(":")[-1].split("-")
-            return self._convert_to_iso3(src), self._convert_to_iso3(tgt)
+            if self.model_type == "seamless-4mt":
+                return self._convert_to_iso3(src), self._convert_to_iso3(tgt)
+            return src, tgt

-        # Prepare all inputs first
+        # Prepare all inputs first for creating the GenerativeTaskDataset
        prepared_requests = []
        for request in requests:
            src_lang, tgt_lang = get_langs(request.task_name)
            request.context = request.context.replace(f"{src_lang.upper()}: ", "").replace(
                f"\n{tgt_lang.upper()}: ", ""
            )
-            request.tokenized_context = self._tokenizer(
-                text=request.context, src_lang=src_lang, return_tensors="pt", padding=True
-            )
+            if self.model_type == "madlad400":
+                request.context = f"<2{tgt_lang}> {request.context}"
+
+            request.tokenized_context = self.tok_encode(request.context)
            prepared_requests.append(request)

        # Create dataset after preparation
        dataset = GenerativeTaskDataset(requests=prepared_requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []
-        batch_size = override_bs or 32
+        batch_size = override_bs or self.batch_size

        for split_start, split_end in tqdm(
            dataset.splits_start_end_iterator(),
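
Two conventions meet in this hunk: the task-name format documented in the class docstring, and the MADLAD-400 target-language prefix. A small worked example of both (values chosen for illustration):

```python
# 1) Language pair extraction, mirroring get_langs above:
task_name = "lighteval|wmt20:fr-de"
src, tgt = task_name.split("|")[1].split(":")[-1].split("-")
assert (src, tgt) == ("fr", "de")

# 2) MADLAD-400 selects its target language with a "<2{tgt}>" token prepended
# to the source text (SeamlessM4T instead takes tgt_lang at generation time):
context = "Le chat est noir."
madlad_context = f"<2{tgt}> {context}"
assert madlad_context == "<2de> Le chat est noir."
```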
@@ -123,18 +173,25 @@ def get_langs(task_name: str) -> tuple[str, str]:
                batch_texts = [r.context for r in batch]
                src_lang = get_langs(batch[0].task_name)[0]  # All source languages should be the same in a batch

-                # Unpack the tokenizer output into input_ids and attention_mask
-                input_ids, attention_mask = self._tokenizer(
-                    text=batch_texts, src_lang=src_lang, return_tensors="pt", padding=True
-                ).values()
+                # This is the tokenization step that really counts, as it actually gets used
+                tokenizer_kwargs = {"text": batch_texts, "return_tensors": "pt", "padding": True}
+                if self.model_type == "seamless-4mt":
+                    tokenizer_kwargs["src_lang"] = src_lang
+
+                input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).values()

                tgt_langs = [get_langs(r.task_name)[1] for r in batch]
                assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same"

-                # Use unpacked values directly
-                output_ids = self._model.generate(
-                    input_ids=input_ids, attention_mask=attention_mask, tgt_lang=tgt_langs[0]
-                )
+                generate_kwargs = {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                }
+                if self.model_type == "seamless-4mt":
+                    generate_kwargs["tgt_lang"] = tgt_langs[0]
+
+                output_ids = self._model.generate(**generate_kwargs)
                translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)

                # Create responses for the batch
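
To make the branching concrete, here is a self-contained sketch of the madlad400 path driven by these kwargs; the checkpoint name is an assumption for illustration (the docstring mentions google/madlad400-7b-mt, and the smaller -mt variants follow the same interface):

```python
# Standalone sketch of the madlad400 branch above: no src_lang/tgt_lang
# arguments; the "<2de>" prefix in the input selects the target language.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

name = "google/madlad400-3b-mt"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

batch = tokenizer(text=["<2de> The cat is black."], return_tensors="pt", padding=True)
output_ids = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_new_tokens=64,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```

The seamless-4mt branch makes the same call but additionally passes tgt_lang (a 3-letter code such as "deu"), which SeamlessM4Tv2ForTextToText.generate requires to pick the decoder language.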
@@ -155,9 +212,7 @@ def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
-        return self._tokenizer(
-            text=str_to_encode, return_tensors="pt", padding=True, add_special_tokens=add_special_tokens or False
-        )
+        return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False)

    @property
    def add_special_tokens(self) -> bool:
