Skip to content

Commit

Permalink
Merge pull request #50 from holyCowMp3/gemini_pro_and_ranging
Browse files Browse the repository at this point in the history
Added the possibility to use a range for translation checkpoints, along with a new translation model based on the Gemini Pro free API
  • Loading branch information
ErikTromp authored Feb 8, 2024
2 parents 5953512 + 0227cc3 commit 17e55d9
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 10 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ tensorboardx==2.6.2.2
pandas==1.5.3
stanza==1.7.0
tqdm
sacrebleu
sacrebleu
google-generativeai
30 changes: 24 additions & 6 deletions translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from translators.opus import OPUSTranslator
from translators.seamless_m4t_v2 import Seamless_M4T_V2
from translators.towerinstruct import TowerInstructTranslator
from translators.gemini_pro import GeminiProTranslator


# Find the max checkpoint number to continue from
Expand Down Expand Up @@ -68,6 +69,10 @@ def main():
help="Forces usage of CPU. By default GPU is taken if available.")
parser.add_argument('--source_lang', type=str, default=None,
help="Source language to select from OASST based on lang property of dataset")
parser.add_argument('--start_index', type=int, default=None,
help="Set start index for processing in dataset by range")
parser.add_argument('--end_index', type=int, default=None,
help="Set end index for processing in dataset by range")

parser_opus = subparsers.add_parser('opus', help='Translate the dataset using HelsinkiNLP OPUS models.')

Expand All @@ -94,6 +99,10 @@ def main():

parser_towerinstruct = subparsers.add_parser('towerinstruct', help='Translate the dataset using Unbabel\'s Tower Instruct. Make sure your target language is in the 10 languages supported by the model.')

parser_gemini_pro = subparsers.add_parser('gemini_pro', help='Gemini Pro translation model')

parser_gemini_pro.add_argument('--auth_token', type=str, default=None,
help='Gemini Pro retrieved here https://makersuite.google.com/app/apikey')
# Default arguments shared across models
args = parser.parse_args()
model = args.model
Expand All @@ -108,6 +117,8 @@ def main():
batch_size = args.batch_size
force_cpu = args.cpu
selected_source_language = args.source_lang
start_index = args.start_index
end_index = args.end_index

device = torch.device("cuda:0" if torch.cuda.is_available() and not (force_cpu) else "cpu")

Expand Down Expand Up @@ -140,6 +151,8 @@ def main():
translator = Seamless_M4T_V2(device, quant4, quant4_config, quant8, args.max_length, args.model_size)
elif model == 'towerinstruct':
translator = TowerInstructTranslator(device, quant4, quant4_config, quant8, args.max_length)
elif model == 'gemini_pro':
translator = GeminiProTranslator(args.auth_token, args.max_length)
else:
translator = OPUSTranslator(device, quant4, quant4_config, quant8, args.max_length)

Expand All @@ -150,27 +163,32 @@ def main():
if selected_source_language is not None:
records = records_by_lang[selected_source_language]
translate_records(base_dataset_lang_field, base_dataset_text_field, batch_size, checkpoint_location,
checkpoint_n, device, fold, pbar, records, source_lang, target_lang, translator)
checkpoint_n, device, fold, pbar, records, selected_source_language, target_lang, translator,
last_checkpoint=start_index, end_of_range=end_index)
else:
for source_lang, records in records_by_lang.items():
translate_records(base_dataset_lang_field, base_dataset_text_field, batch_size, checkpoint_location,
checkpoint_n, device, fold, pbar, records, source_lang, target_lang, translator)
checkpoint_n, device, fold, pbar, records, source_lang, target_lang, translator,
last_checkpoint=start_index, end_of_range=end_index)
# One source language down, release the memory
gc.collect()
if str(device).startswith('cuda'):
torch.cuda.empty_cache()


def translate_records(base_dataset_lang_field, base_dataset_text_field, batch_size, checkpoint_location, checkpoint_n,
device, fold, pbar, records, source_lang, target_lang, translator):
device, fold, pbar, records, source_lang, target_lang, translator, last_checkpoint = None,
end_of_range = None):
lang_checkpoint_location = os.path.join(checkpoint_location, fold, f'from_{source_lang}')
os.makedirs(lang_checkpoint_location, exist_ok=True)
last_checkpoint_n = find_largest_checkpoint(lang_checkpoint_location)
last_checkpoint_n = last_checkpoint if last_checkpoint is not None else find_largest_checkpoint(lang_checkpoint_location)
translated_texts = []
records_length = len(records) if end_of_range is None else end_of_range
print(
f'[---- LLaMa2Lang ----] Got {len(records)} records for source language {source_lang}, skipping {last_checkpoint_n}')
f'[---- LLaMa2Lang ----] Got {len(records)} records for source language {source_lang}, skipping {last_checkpoint_n}, will process till {records_length}')
pbar.total = records_length
pbar.update(last_checkpoint_n)
for cnt in range(last_checkpoint_n, len(records), batch_size):
for cnt in range(last_checkpoint_n, records_length, batch_size):
# Translate a full batch
batch = records[cnt:cnt + batch_size]
texts_to_translate = [record[base_dataset_text_field] for record in batch]
Expand Down
121 changes: 121 additions & 0 deletions translators/gemini_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from google.api_core.exceptions import InternalServerError
from translators.base import BaseTranslator

import google.generativeai as genai
import asyncio
import codecs


class GeminiProTranslator(BaseTranslator):
    """Translator backed by Google's Gemini Pro generative model.

    Sends translation prompts through the google-generativeai SDK
    asynchronously. The free API is rate limited to 60 requests per
    minute, so batches are capped at 60 texts and individual requests
    are spaced one second apart.
    """

    # Human-readable language names keyed by ISO-ish codes; the name is
    # interpolated into the prompt ("Translate text below to X language").
    # Based on https://ai.google.dev/available_regions#available_languages
    # Make sure that you have access to a Gemini region.
    language_mapping = {
        "en": "English",
        "pt": "Portuguese",
        "pt-BR": "Portuguese",
        "es": "Spanish",
        "fr": "French",
        "de": "German",
        "nl": "Dutch",
        "it": "Italian",
        "ko": "Korean",
        "zh": "Chinese",
        "uk": "Ukrainian",
        "uk-UA": "Ukrainian",
        "ja": "Japanese",  # was "Japan" — wrong language name in the prompt
        "pl": "Polish",
        "ar": "Arabic",
        "bn": "Bengali",
        "bg": "Bulgarian",
        "hr": "Croatian",
        "cs": "Czech",
        "da": "Danish",
        "et": "Estonian",
        "fi": "Finnish",
        "el": "Greek",
        "iw": "Hebrew",
        "hi": "Hindi",
        "hu": "Hungarian",
        "id": "Indonesian",
        "lv": "Latvian",
        "lt": "Lithuanian",
        "no": "Norwegian",
        "ro": "Romanian",
        "ru": "Russian",
        "sr": "Serbian",
        "sk": "Slovak",
        "sl": "Slovenian",
        "sw": "Swahili",
        "sv": "Swedish",
        "th": "Thai",
        "tr": "Turkish",
        "vi": "Vietnamese"
    }

    # Upper bound on retries after an InternalServerError. The previous
    # implementation recursed unconditionally, which could recurse forever
    # (and eventually overflow the stack) on a persistent server failure.
    MAX_RETRIES = 5

    def __init__(self, access_token, max_length):
        """Configure the Gemini Pro client.

        :param access_token: API key from https://makersuite.google.com/app/apikey
        :param max_length: maximum output length, forwarded to BaseTranslator
        :raises Exception: if no access token is provided
        """
        if access_token is None:
            raise Exception("Access token is required!")
        super().__init__(None, None, None, None, max_length)
        genai.configure(api_key=access_token)
        # Tracks source languages we already warned about, to avoid log spam.
        self.printed_error_langs = {}
        self.model = genai.GenerativeModel('gemini-pro')

    async def translate_text(self, text, prompt):
        """Translate a single text, retrying on transient server errors.

        Returns the raw SDK response object (decode with decode_result).
        Raises InternalServerError if the server keeps failing after
        MAX_RETRIES attempts.
        """
        # Safety filters are disabled so that input in any language is
        # translated rather than blocked.
        safety_settings = {'HARASSMENT': 'block_none',
                           'HARM_CATEGORY_SEXUALLY_EXPLICIT': 'block_none',
                           'harm_category_dangerous_content': 'block_none',
                           'harm_category_hate_speech': 'block_none',
                           'harm_category_harassment': 'block_none'
                           }
        for attempt in range(self.MAX_RETRIES + 1):
            try:
                return await self.model.generate_content_async(
                    f"{prompt}\n{text}", safety_settings=safety_settings)
            except InternalServerError:
                if attempt == self.MAX_RETRIES:
                    raise
                # Exponential backoff before retrying a transient failure.
                await asyncio.sleep(2 ** attempt)

    def decode_result(self, response):
        """Extract the translated text from an SDK response.

        Tries response.text first, then falls back to joining the parts of
        the response (or of its first candidate) and unescaping them.
        Returns None when the response has no candidates at all.
        """
        try:
            return response.text
        except Exception:
            # response.text raises when the response was blocked or has
            # multiple parts; fall back to assembling the parts manually.
            try:
                result = "".join(map(lambda part: part.text, response.parts))
                decoded_result = codecs.escape_decode(result)[0].decode("utf8")
                return decoded_result
            except Exception:
                if len(response.candidates) == 0:
                    return
                result = "".join(map(lambda part: part.text, response.candidates[0].content.parts))
                decoded_result = codecs.escape_decode(result)[0].decode("utf8")
                return decoded_result

    async def translate_texts(self, texts, prompt):
        """Translate a batch of texts concurrently, one request per second.

        On a decoding failure for an individual text, the original
        (untranslated) text is returned in its place.
        """
        tasks = []
        for text in texts:
            # Schedule the request immediately (ensure_future) so the
            # 1-second sleeps actually space the API calls out and respect
            # the 60 RPM limit. Bare coroutines would not start running
            # until gather(), firing all requests at once.
            tasks.append(asyncio.ensure_future(self.translate_text(text, prompt)))
            await asyncio.sleep(1)
        results = await asyncio.gather(*tasks)
        decoded_results = []
        for i in range(0, len(results)):
            try:
                decoded_results.append(self.decode_result(results[i]))
            except Exception:
                print("Error during translation, returning source language")
                decoded_results.append(texts[i])

        return decoded_results

    def translate(self, texts, source_lang, target_lang):
        """Translate texts from source_lang to target_lang.

        Returns a list of translated strings, or None when either language
        is unsupported (a warning is printed once per source language).
        Raises Exception when the batch exceeds the 60 RPM rate limit.
        """
        if len(texts) > 60:
            raise Exception("Batch size cannot be more than 60 for this translator due ratelimit in 60 RPM!")
        if source_lang in self.language_mapping and target_lang in self.language_mapping:
            trgt_lang = self.language_mapping[target_lang]
            prompt = (f"Translate text below to {trgt_lang} language and preserve formatting and special characters. "
                      f"Respond with translated text ONLY. Here is text to translate:\n")
            # NOTE(review): asyncio.get_event_loop() is deprecated outside a
            # running loop on Python 3.10+; kept for compatibility with the
            # existing synchronous call sites.
            loop = asyncio.get_event_loop()
            result = loop.run_until_complete(self.translate_texts(texts, prompt))
            return result
        else:
            if not (source_lang in self.printed_error_langs):
                print(
                    f"[---- LLaMa2Lang ----] Gemini Pro cannot translate from source language {source_lang} or to your target language {target_lang}, returning originals")
                self.printed_error_langs[source_lang] = True
            return None
7 changes: 4 additions & 3 deletions translators/towerinstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class TowerInstructTranslator(BaseTranslator):
'it': 'Italian',
'ko': 'Korean',
'zh': 'Chinese',
'ru': 'Russian'
'ru': 'Russian',
'uk': 'Ukrainian'
}
def __init__(self, device, quant4, quant4_config, quant8, max_length):
super().__init__(device, quant4, quant4_config, quant8, max_length)
Expand All @@ -29,7 +30,7 @@ def __init__(self, device, quant4, quant4_config, quant8, max_length):
model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

self.nlp_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=self.device)
self.nlp_pipeline = pipeline("text-generation", model=model, device_map=self.device, tokenizer=tokenizer)
self.printed_error_langs = {}

def translate(self, texts, source_lang, target_lang):
Expand All @@ -39,7 +40,7 @@ def translate(self, texts, source_lang, target_lang):

with torch.no_grad():
texts = [{'role':'user','content': f'Translate the following text from {src_lang} into {trgt_lang}.\n{src_lang}: {t}\n{trgt_lang}:'} for t in texts]
prompts = [self.nlp_pipeline.tokenizer.apply_chat_template([text], tokenize=False, add_generation_prompt=True).to(self.device) for text in texts]
prompts = [self.nlp_pipeline.tokenizer.apply_chat_template([text], tokenize=False, add_generation_prompt=True) for text in texts]
if self.max_length is None:
outputs = [self.nlp_pipeline(prompt, do_sample=False) for prompt in prompts]
else:
Expand Down

0 comments on commit 17e55d9

Please sign in to comment.