diff --git a/src/leapfrogai_api/data/crud_base.py b/src/leapfrogai_api/data/crud_base.py index 7f6b5db6d..8532df611 100644 --- a/src/leapfrogai_api/data/crud_base.py +++ b/src/leapfrogai_api/data/crud_base.py @@ -118,10 +118,16 @@ async def delete(self, filters: dict | None = None) -> bool: async def _get_user_id(self) -> str: """Get the user_id from the API key.""" - if self.db.options.headers.get("x-custom-api-key"): - result = await self.db.table("api_keys").select("user_id").execute() - user_id: str = result.data[0]["user_id"] - else: - user_id = (await self.db.auth.get_user()).user.id + return await get_user_id(self.db) - return user_id + +async def get_user_id(db: AsyncClient) -> str: + """Get the user_id from the API key.""" + + if db.options.headers.get("x-custom-api-key"): + result = await db.table("api_keys").select("user_id").execute() + user_id: str = result.data[0]["user_id"] + else: + user_id = (await db.auth.get_user()).user.id + + return user_id