Refactor batching to os.fork / multiprocessing #57

Draft · wants to merge 9 commits into base: batch-handler-asyncio-refactor-queue
12 changes: 7 additions & 5 deletions experimental/caching/multiprocessing.py
@@ -1,5 +1,5 @@
from typing import Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import torch.multiprocessing as mp
from transformers import AutoTokenizer, AutoModel
import torch
@@ -88,15 +88,15 @@ def loop_forever(self):
pass

class TokenizePipeline(BoringPipeline):
def post_init(self, device: str):
def post_init(self):
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
self.device = device


def working_function(self, item):
assert isinstance(item, list) and all(isinstance(i, str) for i in item)
try:
with torch.inference_mode():
return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt")
except Exception as ex:
print(ex)
return None
@@ -109,7 +109,9 @@ def post_init(self, model_device: str):

def working_function(self, item):
with torch.inference_mode():
return self.model(**item).last_hidden_state.shape
item = item.to(self.model.device)
output = self.model(**item).last_hidden_state
return output.detach().cpu().shape

def main():
mp.set_start_method('spawn')
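The hunks above drop the explicit device plumbing: TokenizePipeline.post_init no longer takes a device, the tokenizer returns plain CPU tensors, and the model worker moves each batch onto self.model.device before the forward pass. Below is a minimal, illustrative sketch of how main() in this file might wire such queue-based workers together under the 'spawn' start method; the ModelPipeline class name, the run_pipeline helper, the pipeline constructor signature, and the queue protocol are assumptions for illustration, not code from this PR.

# Illustrative sketch only (not from this PR): wiring two queue-based workers
# with torch.multiprocessing under the 'spawn' start method. Assumes the
# pipeline classes defined in this file accept (in_queue, out_queue) plus
# post_init kwargs, and that loop_forever() reads from in_queue and writes
# working_function() results to out_queue.
import torch.multiprocessing as mp


def run_pipeline(pipeline_cls, in_queue, out_queue, **post_init_kwargs):
    # Hypothetical helper: build the pipeline inside the child process so no
    # CUDA state is inherited from the parent, then block in its work loop.
    pipeline = pipeline_cls(in_queue, out_queue, **post_init_kwargs)
    pipeline.loop_forever()


if __name__ == "__main__":
    mp.set_start_method("spawn")
    texts_q, tokens_q, results_q = mp.Queue(), mp.Queue(), mp.Queue()

    # Tokenization stays on CPU; only the model worker touches the GPU,
    # which is why TokenizePipeline.post_init() no longer needs a device.
    tokenizer_proc = mp.Process(
        target=run_pipeline, args=(TokenizePipeline, texts_q, tokens_q)
    )
    model_proc = mp.Process(
        target=run_pipeline,
        args=(ModelPipeline, tokens_q, results_q),  # ModelPipeline name assumed
        kwargs={"model_device": "cuda"},
    )
    tokenizer_proc.start()
    model_proc.start()

    texts_q.put(["hello world", "a second sentence"])
    print(results_q.get())  # e.g. a torch.Size reported by the model worker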
48 changes: 25 additions & 23 deletions libs/infinity_emb/infinity_emb/engine.py
@@ -4,7 +4,6 @@
from infinity_emb.inference import (
BatchHandler,
Device,
select_model,
)
from infinity_emb.log_handler import logger
from infinity_emb.primitives import EmbeddingReturnType, ModelCapabilites
@@ -51,19 +50,18 @@ def __init__(
self.running = False
self._vector_disk_cache_path = vector_disk_cache_path
self._model_name_or_path = model_name_or_path
self._model_name_or_pathengine = engine
self._model_warmup = model_warmup
self._lengths_via_tokenize = lengths_via_tokenize

if isinstance(engine, str):
engine = InferenceEngine[engine]
self._engine_type = InferenceEngine[engine]
else:
self._engine_type = engine
if isinstance(device, str):
device = Device[device]

self._model, self._min_inference_t = select_model(
model_name_or_path=model_name_or_path,
batch_size=batch_size,
engine=engine,
model_warmup=model_warmup,
device=device,
)
self.device = Device[device]
else:
self.device = device

async def astart(self):
"""startup engine"""
@@ -75,20 +73,23 @@ async def astart(self):
)
self.running = True
self._batch_handler = BatchHandler(
model_name_or_path=self._model_name_or_path,
engine=self._engine_type,
max_batch_size=self.batch_size,
model=self._model,
batch_delay=self._min_inference_t / 2,
model_warmup=self._model_warmup,
vector_disk_cache_path=self._vector_disk_cache_path,
verbose=logger.level <= 10,
lengths_via_tokenize=self._lengths_via_tokenize,
device=self.device,
)
await self._batch_handler.spawn()
await self._batch_handler.astart()

async def astop(self):
"""stop engine"""
self._check_running()
self._assert_running()
self.running = False
await self._batch_handler.shutdown()
await self._batch_handler.astop()
self._batch_handler = None

async def __aenter__(self):
await self.astart()
@@ -97,16 +98,17 @@ async def __aexit__(self, *args):
await self.astop()

def overload_status(self):
self._check_running()
self._assert_running()
return self._batch_handler.overload_status()

def is_overloaded(self) -> bool:
self._check_running()
self._assert_running()
return self._batch_handler.is_overloaded()

@property
def capabilities(self) -> Set[ModelCapabilites]:
return self._model.capabilities
self._assert_running()
return self._batch_handler.capabilities

async def embed(
self, sentences: List[str]
@@ -125,7 +127,7 @@ async def embed(
Usage:
"""

self._check_running()
self._assert_running()
embeddings, usage = await self._batch_handler.embed(sentences)
return embeddings, usage

@@ -139,7 +141,7 @@ async def rerank(
docs (List[str]): docs to be reranked
raw_scores (bool): return raw scores instead of sigmoid
"""
self._check_running()
self._assert_running()
scores, usage = await self._batch_handler.rerank(
query=query, docs=docs, raw_scores=raw_scores
)
@@ -156,12 +158,12 @@ async def classify(
docs (List[str]): docs to be reranked
raw_scores (bool): return raw scores instead of sigmoid
"""
self._check_running()
self._assert_running()
scores, usage = await self._batch_handler.classify(sentences=sentences)

return scores, usage

def _check_running(self):
def _assert_running(self):
if not self.running:
raise ValueError(
"didn't start `AsyncEmbeddingEngine` "
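With these changes, startup and shutdown are astart()/astop() (wrapped by the async context manager), every public call is guarded by _assert_running(), and model loading happens inside BatchHandler rather than in __init__. A usage sketch under those assumptions follows; the top-level import path and constructor defaults are assumed rather than taken from this PR.

# Usage sketch (not code from this PR); import path and defaults assumed.
import asyncio

from infinity_emb import AsyncEmbeddingEngine  # import path assumed


async def main():
    engine = AsyncEmbeddingEngine(model_name_or_path="BAAI/bge-small-en-v1.5")
    async with engine:  # __aenter__ -> astart(), __aexit__ -> astop()
        embeddings, usage = await engine.embed(["hello world"])
        print(len(embeddings), usage)
        print(engine.capabilities)  # now proxied to the running BatchHandler
    # After astop(), calls such as engine.embed(...) raise via _assert_running().


asyncio.run(main())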
2 changes: 0 additions & 2 deletions libs/infinity_emb/infinity_emb/inference/__init__.py
@@ -1,5 +1,4 @@
from infinity_emb.inference.batch_handler import BatchHandler
from infinity_emb.inference.select_model import select_model
from infinity_emb.primitives import (
Device,
DeviceTypeHint,
@@ -15,5 +14,4 @@
"Device",
"DeviceTypeHint",
"BatchHandler",
"select_model",
]
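Since select_model is no longer exported from infinity_emb.inference, code that previously resolved a model up front would now pass the model name, engine, and device straight to BatchHandler, mirroring the astart() call in engine.py above. A hedged sketch follows; the InferenceEngine import path, the enum member names, and the argument values are assumptions, not taken from this PR.

# Sketch of constructing BatchHandler directly (argument names mirror the
# astart() diff above; values, defaults, and import paths are assumptions).
import asyncio

from infinity_emb.inference import BatchHandler, Device
from infinity_emb.primitives import InferenceEngine  # module path assumed


async def main():
    handler = BatchHandler(
        model_name_or_path="BAAI/bge-small-en-v1.5",
        engine=InferenceEngine.torch,  # enum member name assumed
        max_batch_size=64,
        model_warmup=False,
        vector_disk_cache_path=None,  # default value assumed
        verbose=False,
        lengths_via_tokenize=False,
        device=Device.cpu,  # enum member name assumed
    )
    await handler.astart()
    try:
        embeddings, usage = await handler.embed(["hello world"])
        print(len(embeddings), usage)
    finally:
        await handler.astop()


asyncio.run(main())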