diff --git a/server/retriever/embeddings.py b/server/retriever/embeddings.py new file mode 100644 index 0000000..3fb5471 --- /dev/null +++ b/server/retriever/embeddings.py @@ -0,0 +1,80 @@ +import mlx.core as mx +import mlx.nn as nn + +import torch +import torch.nn.functional as F +from torch import Tensor + +from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer +from abc import ABC, abstractmethod +from typing import Any, List + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs.""" + + @abstractmethod + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" + + +class E5Embeddings(Embeddings): + + model: Any = None + tokenizer: PreTrainedTokenizer = None + + def __init__(self, model_name: str = 'intfloat/multilingual-e5-small'): + self.model = AutoModel.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + last_hidden = last_hidden_states.masked_fill( + ~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + def embed_documents(self, texts: List[str], batch_size: int = 1) -> List[List[float]]: + embeddings = [] + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i+batch_size] + batch_embeddings = self.embed_query(batch_texts, batch=True) + embeddings.extend(batch_embeddings) + return embeddings + + @torch.no_grad() + def embed_query(self, texts: Any, batch: bool = False) -> List[Any]: + batch_dict = self.tokenizer(texts, max_length=512, padding=True, + truncation=True, return_tensors='pt', return_attention_mask=True) + outputs = self.model(**batch_dict) + embeddings = self._average_pool( + outputs.last_hidden_state, batch_dict['attention_mask']) + embeddings = F.normalize(embeddings, p=2, dim=1) + + if batch: + return embeddings.tolist() # -> List[List[float]] + + return embeddings[0].tolist() # -> List[float] + + +class ChatEmbeddings(Embeddings): + + model: nn.Module = None + tokenizer: PreTrainedTokenizer = None + + def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer): + self.model = model + self.tokenizer = tokenizer + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + h = self.model.embed_tokens(mx.array( + self.tokenizer.encode(text, add_special_tokens=False))) + # normalized to have unit length + h = mx.mean(h, axis=0) + h = h / mx.linalg.norm(h) + return h.tolist() diff --git a/server/retriever/vectorstore.py b/server/retriever/vectorstore.py index c372027..1cad356 100644 --- a/server/retriever/vectorstore.py +++ b/server/retriever/vectorstore.py @@ -12,7 +12,6 @@ Callable, Iterable, Optional, - Literal, Tuple, Type, ) @@ -21,6 +20,8 @@ import chromadb.config from chromadb.api.types import ID, OneOrMany, Where, WhereDocument +from .embeddings import Embeddings + Chroma = TypeVar('Chroma', bound='Chroma') @@ -118,26 +119,6 @@ def maximal_marginal_relevance( return idxs -class Embeddings(): - - type: Literal["Embeddings"] = "Embeddings" - - def __init__(self, model, tokenizer): - self.model = model - self.tokenizer = tokenizer - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - return [self.embed_query(text) for text in texts] - - def embed_query(self, text: str) -> List[float]: - h = self.model.embed_tokens(mx.array( - self.tokenizer.encode(text, add_special_tokens=False))) - # normalized to have unit length - h = mx.mean(h, axis=0) - h = h / mx.linalg.norm(h) - return h.tolist() - - class Chroma(): """ similarity_search diff --git a/server/server.py b/server/server.py index 9b3d49e..67b3889 100644 --- a/server/server.py +++ b/server/server.py @@ -14,7 +14,8 @@ from .retriever.loader import directory_loader from .retriever.splitter import RecursiveCharacterTextSplitter -from .retriever.vectorstore import Chroma, Embeddings +from .retriever.vectorstore import Chroma +from .retriever.embeddings import ChatEmbeddings, E5Embeddings _model: Optional[nn.Module] = None _tokenizer: Optional[PreTrainedTokenizer] = None @@ -27,16 +28,20 @@ def load_model(model_path: str, adapter_file: Optional[str] = None): _model, _tokenizer = load(model_path, adapter_file=adapter_file) -def load_database(directory: str): +def load_database(directory: str, use_embedding: bool = True): global _database # TODO: handle error from directory_loader on invalid raw_docs = directory_loader(directory) text_splitter = RecursiveCharacterTextSplitter( - chunk_size=4000, chunk_overlap=200, add_start_index=True + chunk_size=512, chunk_overlap=32, add_start_index=True ) + embedding = E5Embeddings() if use_embedding else ChatEmbeddings( + model=_model.model, tokenizer=_tokenizer) splits = text_splitter.split_documents(raw_docs) _database = Chroma.from_documents( - documents=splits, embedding=Embeddings(_model.model, _tokenizer)) + documents=splits, + embedding=embedding + ) def create_response(chat_id, prompt, tokens, text): @@ -66,24 +71,23 @@ def create_response(chat_id, prompt, tokens, text): return response -def format_messages(messages, condition): +def format_messages(messages, context): failedString = "ERROR" - if condition: + if context: messages[-1]['content'] = f""" -Only using the documents in the index, answer the following, Respond with just the answer, no "The answer is" or "Answer: " or anything like that. - -Question: +Only using the documents in the index, answer the following, respond with just the answer without "The answer is:" or "Answer:" or anything like that. + {messages[-1]['content']} + -Index: - -{condition} + +{context} + Remember, if you do not know the answer, just say "{failedString}", Try to give as much detail as possible, but only from what is provided within the index. If steps are given, you MUST ALWAYS use bullet points to list each of them them and you MUST use markdown when applicable. -You MUST markdown when applicable. Only use information you can find in the index, do not make up knowledge. Remember, use bullet points or numbered steps to better organize your answer if applicable. NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index. @@ -122,9 +126,11 @@ def handle_post_request(self, post_data): chat_id = f'chatcmpl-{uuid.uuid4()}' load_database(body.get('directory', None)) - # emperically better than similarity_search + # emperically better than `similarity_search` docs = _database.max_marginal_relevance_search( - body['messages'][-1]['content']) + body['messages'][-1]['content'], + k=4 # number of documents to return + ) context = '\n'.join([doc.page_content for doc in docs]) print(body, flush=True) print(('\n'+'--'*10+'\n').join([