Skip to content

Commit

Permalink
fix(api): fix indexing files with api key auth (#852)
Browse files Browse the repository at this point in the history
* add migration for api key auth on storage.objects
* move get_user_id in crud_base to a static method for other classes to use
* fix issues with database calls from index.py
* add conformance testing

---------

Co-authored-by: Jonathan Perry <[email protected]>
  • Loading branch information
gphorvath and YrrepNoj authored Aug 1, 2024
1 parent 33e4efb commit c4d9c3f
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 124 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Policy on storage.objects, applied to every command ("for all") issued by
-- the anon role, that admits a request when it presents a known API key.
create policy "Individuals can CRUD storage.objects via API key."
on storage.objects for all
to anon
using
(
-- Passes when at least one api_keys row has a stored hash equal to crypt()
-- of the x-custom-api-key request header, with that same stored hash passed
-- as crypt()'s second argument.
exists (
select 1
from api_keys
where api_keys.api_key_hash = crypt(current_setting('request.headers')::json->>'x-custom-api-key', api_keys.api_key_hash)
)
-- NOTE(review): the subquery does not reference any storage.objects column,
-- so the result is the same for every row -- confirm that per-owner scoping
-- of objects is enforced elsewhere.
);
137 changes: 22 additions & 115 deletions src/leapfrogai_api/backend/rag/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import logging
import tempfile
import time


from fastapi import HTTPException, UploadFile, status
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
Expand All @@ -28,6 +26,8 @@
FilterVectorStoreFile,
)

from leapfrogai_api.data.crud_vector_content import CRUDVectorContent, Vector

# Allows for overwriting type of embeddings that will be instantiated
embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = (
LeapfrogAIEmbeddings
Expand Down Expand Up @@ -56,14 +56,12 @@ async def index_file(self, vector_store_id: str, file_id: str) -> VectorStoreFil
if await crud_vector_store_file.get(
filters=FilterVectorStoreFile(vector_store_id=vector_store_id, id=file_id)
):
print("File already indexed: %s", file_id)
logging.error("File already indexed: %s", file_id)
raise FileAlreadyIndexedError("File already indexed")

if not (
await crud_vector_store.get(filters=FilterVectorStore(id=vector_store_id))
):
print("Vector store doesn't exist: %s", vector_store_id)
logging.error("Vector store doesn't exist: %s", vector_store_id)
raise ValueError("Vector store not found")

Expand Down Expand Up @@ -175,7 +173,7 @@ async def create_new_vector_store(
),
last_active_at=last_active_at, # Set to current time
metadata=request.metadata,
name=request.name,
name=request.name or "",
object="vector_store",
status=VectorStoreStatus.IN_PROGRESS.value,
expires_after=expires_after,
Expand All @@ -200,7 +198,6 @@ async def create_new_vector_store(
object_=new_vector_store,
)
except Exception as exc:
logging.error(exc)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Unable to parse vector store request",
Expand Down Expand Up @@ -291,29 +288,6 @@ async def file_ids_are_valid(self, file_ids: str | list[str]) -> bool:

return True

async def adelete_file(self, vector_store_id: str, file_id: str) -> bool:
    """Remove the rows stored for a file in a vector store.

    Args:
        vector_store_id (str): The ID of the vector store.
        file_id (str): The ID of the file whose rows are deleted.

    Returns:
        bool: True when the payload of the delete response is non-empty,
            False otherwise.
    """
    delete_query = (
        self.db.from_(self.table_name)
        .delete()
        .eq("vector_store_id", vector_store_id)
        .eq("file_id", file_id)
    )
    result = await delete_query.execute()

    # The result unpacks into (data, count); data in turn is a (label, rows) pair.
    payload, _count = result
    _, deleted = payload

    return bool(deleted)

async def aadd_documents(
self,
documents: list[Document],
Expand All @@ -326,7 +300,8 @@ async def aadd_documents(
documents (list[Document]): A list of Langchain Document objects to be added.
vector_store_id (str): The ID of the vector store where the documents will be added.
file_id (str): The ID of the file associated with the documents.
batch_size (int): The size of the batches that will be pushed to the db. This value defaults to 100
batch_size (int): The size of the batches that will
be pushed to the db. This value defaults to 100
as a balance between the memory impact of large files and performance improvements from batching.
Returns:
List[str]: A list of IDs assigned to the added documents.
Expand All @@ -338,22 +313,25 @@ async def aadd_documents(
texts=[document.page_content for document in documents]
)

vectors = []
vectors: list[Vector] = []
for document, embedding in zip(documents, embeddings):
vector = {
"content": document.page_content,
"metadata": document.metadata,
"embedding": embedding,
}
vector = Vector(
id="",
vector_store_id=vector_store_id,
file_id=file_id,
content=document.page_content,
metadata=document.metadata,
embedding=embedding,
)
vectors.append(vector)

crud_vector_content = CRUDVectorContent(db=self.db)

for i in range(0, len(vectors), batch_size):
batch = vectors[i : i + batch_size]
response = await self._aadd_vectors(
vector_store_id=vector_store_id, file_id=file_id, vectors=batch
)
ids.extend([item["id"] for item in response])

response = await crud_vector_content.add_vectors(batch)
ids.extend([item.id for item in response])
return ids

async def asimilarity_search(self, query: str, vector_store_id: str, k: int = 4):
Expand All @@ -370,20 +348,10 @@ async def asimilarity_search(self, query: str, vector_store_id: str, k: int = 4)
"""
vector = await self.embeddings.aembed_query(query)

user_id: str = (await self.db.auth.get_user()).user.id

params = {
"query_embedding": vector,
"match_limit": k,
"vs_id": vector_store_id,
"user_id": user_id,
}

query_builder = self.db.rpc(self.query_name, params=params)

response = await query_builder.execute()

return response
crud_vector_content = CRUDVectorContent(db=self.db)
return await crud_vector_content.similarity_search(
query=vector, vector_store_id=vector_store_id, k=k
)

async def _increment_vector_store_file_status(
self, vector_store: VectorStore, file_response: VectorStoreFile
Expand All @@ -398,64 +366,3 @@ async def _increment_vector_store_file_status(
elif file_response.status == VectorStoreFileStatus.CANCELLED.value:
vector_store.file_counts.cancelled += 1
vector_store.file_counts.total += 1

async def _adelete_vector(
    self,
    vector_store_id: str,
    file_id: str,
) -> dict:
    """Delete the rows matching a vector store and file.

    Args:
        vector_store_id (str): The ID of the vector store.
        file_id (str): The ID of the file associated with the vector.

    Returns:
        dict: The result of executing the delete query, returned as-is.
    """
    # Build the filtered delete step by step, then run it once.
    query = self.db.from_(self.table_name).delete()
    query = query.eq("vector_store_id", vector_store_id)
    query = query.eq("file_id", file_id)
    return await query.execute()

async def _aadd_vectors(
    self, vector_store_id: str, file_id: str, vectors: list[dict[str, any]]
) -> dict:
    """Insert a batch of vectors for one file in a single database call.

    Args:
        vector_store_id (str): The ID of the vector store.
        file_id (str): The ID of the file associated with the vectors.
        vectors (list[dict]): Dictionaries providing "content", "metadata"
            and "embedding" for each vector.

    Returns:
        dict: The rows from the insert response, with any "user_id" key
            removed from each row.
    """
    owner_id: str = (await self.db.auth.get_user()).user.id

    # Every row carries the same owner, vector store and file identifiers.
    payload = [
        {
            "user_id": owner_id,
            "vector_store_id": vector_store_id,
            "file_id": file_id,
            "content": entry["content"],
            "metadata": entry["metadata"],
            "embedding": entry["embedding"],
        }
        for entry in vectors
    ]

    result = await self.db.from_(self.table_name).insert(payload).execute()
    data, _count = result
    _, inserted = data

    # Strip the owner from what is handed back to the caller.
    for record in inserted:
        record.pop("user_id", None)

    return inserted
6 changes: 5 additions & 1 deletion src/leapfrogai_api/backend/rag/leapfrogai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import leapfrogai_sdk as lfai
from leapfrogai_api.utils import get_model_config
from leapfrogai_api.backend.grpc_client import create_embeddings
import logging


# Partially implements the Langchain Core Embeddings interface
Expand Down Expand Up @@ -44,7 +45,9 @@ async def aembed_query(self, text: str) -> list[float]:

return list_of_embeddings[0]

async def _get_model(self, model_name: str = os.getenv("DEFAULT_EMBEDDINGS_MODEL")):
async def _get_model(
self, model_name: str = os.getenv("DEFAULT_EMBEDDINGS_MODEL", "text-embeddings")
):
"""Gets the embeddings model.
Args:
Expand All @@ -58,6 +61,7 @@ async def _get_model(self, model_name: str = os.getenv("DEFAULT_EMBEDDINGS_MODEL
"""

if not (model := get_model_config().get_model_backend(model=model_name)):
logging.error(f"Embeddings model {model_name} not found.")
raise ValueError("Embeddings model not found.")

return model
18 changes: 12 additions & 6 deletions src/leapfrogai_api/data/crud_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,16 @@ async def delete(self, filters: dict | None = None) -> bool:
async def _get_user_id(self) -> str:
"""Get the user_id from the API key."""

if self.db.options.headers.get("x-custom-api-key"):
result = await self.db.table("api_keys").select("user_id").execute()
user_id: str = result.data[0]["user_id"]
else:
user_id = (await self.db.auth.get_user()).user.id
return await get_user_id(self.db)

return user_id

async def get_user_id(db: AsyncClient) -> str:
    """Return the ID of the user the given client is acting for.

    When the client's headers carry a non-empty ``x-custom-api-key``, the ID
    is taken from the first row returned by selecting ``user_id`` from the
    ``api_keys`` table. Otherwise it is the ID of the user reported by the
    client's auth interface.

    Args:
        db: Client whose headers and auth interface are consulted.

    Returns:
        The user ID.
    """
    api_key = db.options.headers.get("x-custom-api-key")

    if not api_key:
        auth_response = await db.auth.get_user()
        return auth_response.user.id

    key_rows = await db.table("api_keys").select("user_id").execute()
    return key_rows.data[0]["user_id"]
102 changes: 102 additions & 0 deletions src/leapfrogai_api/data/crud_vector_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""CRUD Operations for VectorStore."""

from pydantic import BaseModel
from supabase import AClient as AsyncClient
from leapfrogai_api.data.crud_base import get_user_id
import ast


class Vector(BaseModel):
    """A piece of content and its embedding, tied to a vector store and a file."""

    # Row identifier; defaults to an empty string when the caller has none yet.
    id: str = ""
    # Identifier of the vector store this vector belongs to.
    vector_store_id: str
    # Identifier of the file this vector belongs to.
    file_id: str
    # Text stored alongside the embedding (presumably the text that was
    # embedded -- confirm against callers).
    content: str
    # Free-form metadata stored with the vector.
    metadata: dict
    # Embedding values.
    embedding: list[float]


class CRUDVectorContent:
    """CRUD operations for rows of the ``vector_content`` table."""

    def __init__(self, db: AsyncClient):
        """Keep the client and the name of the table every query targets.

        Args:
            db: Async client used for all database calls.
        """
        self.db = db
        self.table_name = "vector_content"

    async def add_vectors(self, object_: list[Vector]) -> list[Vector]:
        """Insert a batch of vectors and return them as reported by the insert.

        Args:
            object_: Vectors to insert. Each vector's ``id`` is dropped before
                the insert; the returned vectors carry the ``id`` found in the
                insert response instead.

        Returns:
            One ``Vector`` per row of the insert response. An empty input
            yields an empty list without touching the database.
        """
        if not object_:
            # Nothing to insert: skip the user lookup and the round trip.
            return []

        user_id = await get_user_id(self.db)

        rows = []
        for vector in object_:
            row = vector.model_dump()
            row["user_id"] = user_id
            row.pop("id", None)
            rows.append(row)

        # Insert the whole batch. Passing a single row here would silently
        # drop every other vector in the batch.
        data, _count = await self.db.table(self.table_name).insert(rows).execute()

        _, response = data

        inserted = []
        for item in response:
            embedding = item["embedding"]
            # The response may carry the embedding as a string rather than a
            # list; parse it so the model always holds floats.
            if isinstance(embedding, str):
                embedding = self.string_to_float_list(embedding)

            inserted.append(
                Vector(
                    id=item["id"],
                    vector_store_id=item["vector_store_id"],
                    file_id=item["file_id"],
                    content=item["content"],
                    metadata=item["metadata"],
                    embedding=embedding,
                )
            )

        return inserted

    async def delete_vectors(self, vector_store_id: str, file_id: str) -> bool:
        """Delete the vectors stored for a file in a vector store.

        Args:
            vector_store_id: ID of the vector store.
            file_id: ID of the file whose vectors are deleted.

        Returns:
            True when the payload of the delete response is non-empty,
            False otherwise.
        """
        data, _count = (
            await self.db.table(self.table_name)
            .delete()
            .eq("vector_store_id", vector_store_id)
            .eq("file_id", file_id)
            .execute()
        )

        _, response = data

        return bool(response)

    async def similarity_search(self, query: list[float], vector_store_id: str, k: int):
        """Run the ``match_vectors`` database function for a query embedding.

        Args:
            query: Embedding to match against.
            vector_store_id: ID of the vector store to search.
            k: Passed to the function as ``match_limit``.

        Returns:
            The result of executing the ``match_vectors`` call, returned as-is.
        """
        user_id = await get_user_id(self.db)

        params = {
            "query_embedding": query,
            "match_limit": k,
            "vs_id": vector_store_id,
            "user_id": user_id,
        }

        return await self.db.rpc("match_vectors", params).execute()

    @staticmethod
    def string_to_float_list(s: str) -> list[float]:
        """Parse the text form of a list literal into a list of floats.

        Args:
            s: Text such as ``"[0.1, 0.2]"``; surrounding whitespace is ignored.

        Returns:
            The parsed elements, each converted with ``float``.

        Raises:
            ValueError: If the text is not a valid literal or an element
                cannot be converted to ``float``.
            SyntaxError: If the text is not valid Python literal syntax.
        """
        # literal_eval only evaluates literals, so arbitrary code in the
        # string is rejected rather than executed.
        python_list = ast.literal_eval(s.strip())

        return [float(x) for x in python_list]
5 changes: 3 additions & 2 deletions src/leapfrogai_api/routers/openai/vector_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ListVectorStoresResponse,
ModifyVectorStoreRequest,
)
from leapfrogai_api.data.crud_vector_content import CRUDVectorContent
from leapfrogai_api.data.crud_vector_store import CRUDVectorStore, FilterVectorStore
from leapfrogai_api.data.crud_vector_store_file import (
CRUDVectorStoreFile,
Expand Down Expand Up @@ -180,8 +181,8 @@ async def delete_vector_store_file(
) -> VectorStoreFileDeleted:
"""Delete a file in a vector store."""

vector_store = IndexingService(db=session)
vectors_deleted = await vector_store.adelete_file(
vector_content = CRUDVectorContent(db=session)
vectors_deleted = await vector_content.delete_vectors(
vector_store_id=vector_store_id, file_id=file_id
)

Expand Down
Loading

0 comments on commit c4d9c3f

Please sign in to comment.