# -*- coding: utf-8 -*-
"""Multi-GPU inference example for AnglE sentence embeddings.

Recovered from the patch "add multigpu inference example"
(commit 30b537c30b2227e0aa8b1212e9aecc56a8eba5ea, Sean Lee,
Fri, 26 Jul 2024 21:40:07 +0800), which creates
``examples/multigpu_infer.py``.

The script loads the ``test`` split of ``mteb/stsbenchmark-sts``, joins each
``sentence1``/``sentence2`` pair into a single document, and encodes the
documents through ``Dataset.map`` with ``n_gpus`` worker processes. Each
worker moves the model to the CUDA device matching its rank and stores the
resulting embeddings in a new ``emb`` column.
"""

from angle_emb import AnglE
from datasets import load_dataset
from multiprocess import set_start_method


# configuration
n_gpus = 4  # number of `map` worker processes; worker `rank` uses `cuda:<rank>`
workers = 8  # NOTE(review): not referenced anywhere in this script -- confirm whether it is still needed
batch_size = 16  # number of rows passed to each `encode` call

# init angle
angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda()

# load dataset
ds = load_dataset('mteb/stsbenchmark-sts', split='test')
ds = ds.select_columns(('sentence1', 'sentence2'))


def encode(examples, rank):
    """Embed one batch of sentence pairs on the GPU assigned to this worker.

    Args:
        examples: Batched rows from ``ds``; the ``sentence1`` and
            ``sentence2`` entries are read as parallel sequences of strings.
        rank: Worker index supplied by ``Dataset.map(with_rank=True)``.
            ``None`` is treated as ``0`` -- the Hugging Face multi-GPU recipe
            guards the same way because ``map`` passes no rank when it runs
            in a single process.

    Returns:
        ``examples`` with an added ``emb`` entry holding one embedding
        (converted with ``.tolist()``) per input row.
    """
    # `rank or 0` avoids building the invalid device string "cuda:None".
    device = f"cuda:{rank or 0}"
    angle.to(device)

    # Each sentence pair becomes one space-joined document.
    docs = [f'{s1} {s2}' for s1, s2 in zip(examples['sentence1'], examples['sentence2'])]
    examples['emb'] = angle.encode(docs).tolist()
    return examples


if __name__ == '__main__':

    # The inference code must live under the `__main__` guard: with the
    # 'spawn' start method worker processes import this module again, and
    # unguarded code would start another `map` from inside each worker.

    set_start_method('spawn')

    # map and encode
    ds = ds.map(encode, with_rank=True, num_proc=n_gpus, batched=True, batch_size=batch_size)

    print(ds)
    # ds.push_to_hub(xxx)