From e351e15a3f34914949be635ac93384ea2aa781ca Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 20 Jun 2024 15:47:59 +0200 Subject: [PATCH] Add login --- google_sheets/app.py | 111 +++++++++++++++++++++++++++++++++++- google_sheets/db_helpers.py | 12 ++-- 2 files changed, 116 insertions(+), 7 deletions(-) diff --git a/google_sheets/app.py b/google_sheets/app.py index 0ddd975..57838be 100644 --- a/google_sheets/app.py +++ b/google_sheets/app.py @@ -1,11 +1,14 @@ import json import logging +import urllib.parse from os import environ from pathlib import Path -from typing import Annotated, Any, Dict, List, Union +from typing import Annotated, Any, Dict, List, Tuple, Union +import httpx from asyncify import asyncify -from fastapi import FastAPI, HTTPException, Query +from fastapi import FastAPI, HTTPException, Query, Request +from fastapi.responses import RedirectResponse from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from prisma.errors import RecordNotFoundError @@ -44,6 +47,110 @@ } +async def get_user_id_chat_uuid_from_chat_id( + chat_id: Union[int, str], +) -> Tuple[int, str]: + wasp_db_url = await get_db_url(db_name="waspdb") + async with get_db_connection(db_url=wasp_db_url) as db: + chat = await db.query_first( + f'SELECT * from "Chat" where id={chat_id}' # nosec: [B608] + ) + if not chat: + raise HTTPException(status_code=404, detail=f"chat {chat} not found") + user_id = chat["userId"] + chat_uuid = chat["uuid"] + return user_id, chat_uuid + + +async def is_authenticated_for_ads(user_id: int) -> bool: + await get_user(user_id=user_id) + async with get_db_connection() as db: + data = await db.gauth.find_unique(where={"user_id": user_id}) + + if not data: + return False + return True + + +# Route 1: Redirect to Google OAuth +@app.get("/login") +async def get_login_url( + request: Request, + user_id: int = Query(title="User ID"), + conv_id: int = Query(title="Conversation ID"), + force_new_login: bool = Query(title="Force new login", default=False), +) -> Dict[str, str]: + if not force_new_login: + is_authenticated = await is_authenticated_for_ads(user_id=user_id) + if is_authenticated: + return {"login_url": "User is already authenticated"} + + google_oauth_url = ( + f"{oauth2_settings['auth_uri']}?client_id={oauth2_settings['clientId']}" + f"&redirect_uri={oauth2_settings['redirectUri']}&response_type=code" + f"&scope={urllib.parse.quote_plus('email https://www.googleapis.com/auth/spreadsheets https://www.googleapis.com/auth/drive.metadata.readonly')}" + f"&access_type=offline&prompt=consent&state={conv_id}" + ) + markdown_url = f"To navigate Google Ads waters, I require access to your account. Please [click here]({google_oauth_url}) to grant permission." + return {"login_url": markdown_url} + + +# Route 2: Save user credentials/token to a JSON file +@app.get("/login/callback") +async def login_callback( + code: str = Query(title="Authorization Code"), state: str = Query(title="State") +) -> RedirectResponse: + chat_id = state + user_id, chat_uuid = await get_user_id_chat_uuid_from_chat_id(chat_id) + # user_id, chat_id = await get_user_id_chat_id_from_conversation(conv_id) + user = await get_user(user_id=user_id) + + token_request_data = { + "code": code, + "client_id": oauth2_settings["clientId"], + "client_secret": oauth2_settings["clientSecret"], + "redirect_uri": oauth2_settings["redirectUri"], + "grant_type": "authorization_code", + } + + async with httpx.AsyncClient() as client: + response = await client.post( + oauth2_settings["tokenUrl"], data=token_request_data + ) + + if response.status_code == 200: + token_data = response.json() + + async with httpx.AsyncClient() as client: + userinfo_response = await client.get( + "https://www.googleapis.com/oauth2/v2/userinfo", + headers={"Authorization": f"Bearer {token_data['access_token']}"}, + ) + + if userinfo_response.status_code == 200: + user_info = userinfo_response.json() + async with get_db_connection() as db: + await db.gauth.upsert( + where={"user_id": user["id"]}, + data={ + "create": { + "user_id": user["id"], + "creds": json.dumps(token_data), + "info": json.dumps(user_info), + }, + "update": { + "creds": json.dumps(token_data), + "info": json.dumps(user_info), + }, + }, + ) + + redirect_domain = environ.get("REDIRECT_DOMAIN", "https://captn.ai") + logged_in_message = "I have successfully logged in" + redirect_uri = f"{redirect_domain}/chat/{chat_uuid}?msg={logged_in_message}" + return RedirectResponse(redirect_uri) + + async def get_user(user_id: Union[int, str]) -> Any: wasp_db_url = await get_db_url(db_name="waspdb") async with get_db_connection(db_url=wasp_db_url) as db: diff --git a/google_sheets/db_helpers.py b/google_sheets/db_helpers.py index 434ffc8..b2a37ba 100644 --- a/google_sheets/db_helpers.py +++ b/google_sheets/db_helpers.py @@ -27,8 +27,10 @@ async def get_db_connection( async def get_db_url(db_name: str) -> str: curr_db_url = environ.get("DATABASE_URL") - wasp_db_name = environ.get("WASP_DB_NAME", db_name) - wasp_db_url = curr_db_url.replace(curr_db_url.split("/")[-1], wasp_db_name) # type: ignore[union-attr] - if "connect_timeout" not in wasp_db_url: - wasp_db_url += "?connect_timeout=60" - return wasp_db_url + if db_name == "waspdb": + db_name = environ.get("WASP_DB_NAME", db_name) + db_url = curr_db_url.replace(curr_db_url.split("/")[-1], db_name) # type: ignore[union-attr] + if "connect_timeout" not in db_url: + db_url += "?connect_timeout=60" + + return db_url