Skip to content

Commit

Permalink
Add login
Browse files Browse the repository at this point in the history
  • Loading branch information
rjambrecic committed Jun 20, 2024
1 parent 9abf4cc commit e351e15
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 7 deletions.
111 changes: 109 additions & 2 deletions google_sheets/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
import logging
import urllib.parse
from os import environ
from pathlib import Path
from typing import Annotated, Any, Dict, List, Union
from typing import Annotated, Any, Dict, List, Tuple, Union

import httpx
from asyncify import asyncify
from fastapi import FastAPI, HTTPException, Query
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from prisma.errors import RecordNotFoundError
Expand Down Expand Up @@ -44,6 +47,110 @@
}


async def get_user_id_chat_uuid_from_chat_id(
chat_id: Union[int, str],
) -> Tuple[int, str]:
wasp_db_url = await get_db_url(db_name="waspdb")
async with get_db_connection(db_url=wasp_db_url) as db:
chat = await db.query_first(
f'SELECT * from "Chat" where id={chat_id}' # nosec: [B608]
)
if not chat:
raise HTTPException(status_code=404, detail=f"chat {chat} not found")
user_id = chat["userId"]
chat_uuid = chat["uuid"]
return user_id, chat_uuid


async def is_authenticated_for_ads(user_id: int) -> bool:
await get_user(user_id=user_id)
async with get_db_connection() as db:
data = await db.gauth.find_unique(where={"user_id": user_id})

if not data:
return False
return True


# Route 1: Redirect to Google OAuth
@app.get("/login")
async def get_login_url(
request: Request,
user_id: int = Query(title="User ID"),
conv_id: int = Query(title="Conversation ID"),
force_new_login: bool = Query(title="Force new login", default=False),
) -> Dict[str, str]:
if not force_new_login:
is_authenticated = await is_authenticated_for_ads(user_id=user_id)
if is_authenticated:
return {"login_url": "User is already authenticated"}

google_oauth_url = (
f"{oauth2_settings['auth_uri']}?client_id={oauth2_settings['clientId']}"
f"&redirect_uri={oauth2_settings['redirectUri']}&response_type=code"
f"&scope={urllib.parse.quote_plus('email https://www.googleapis.com/auth/spreadsheets https://www.googleapis.com/auth/drive.metadata.readonly')}"
f"&access_type=offline&prompt=consent&state={conv_id}"
)
markdown_url = f"To navigate Google Ads waters, I require access to your account. Please [click here]({google_oauth_url}) to grant permission."
return {"login_url": markdown_url}


# Route 2: Save user credentials/token to a JSON file
@app.get("/login/callback")
async def login_callback(
code: str = Query(title="Authorization Code"), state: str = Query(title="State")
) -> RedirectResponse:
chat_id = state
user_id, chat_uuid = await get_user_id_chat_uuid_from_chat_id(chat_id)
# user_id, chat_id = await get_user_id_chat_id_from_conversation(conv_id)
user = await get_user(user_id=user_id)

token_request_data = {
"code": code,
"client_id": oauth2_settings["clientId"],
"client_secret": oauth2_settings["clientSecret"],
"redirect_uri": oauth2_settings["redirectUri"],
"grant_type": "authorization_code",
}

async with httpx.AsyncClient() as client:
response = await client.post(
oauth2_settings["tokenUrl"], data=token_request_data
)

if response.status_code == 200:
token_data = response.json()

async with httpx.AsyncClient() as client:
userinfo_response = await client.get(
"https://www.googleapis.com/oauth2/v2/userinfo",
headers={"Authorization": f"Bearer {token_data['access_token']}"},
)

if userinfo_response.status_code == 200:
user_info = userinfo_response.json()
async with get_db_connection() as db:
await db.gauth.upsert(
where={"user_id": user["id"]},
data={
"create": {
"user_id": user["id"],
"creds": json.dumps(token_data),
"info": json.dumps(user_info),
},
"update": {
"creds": json.dumps(token_data),
"info": json.dumps(user_info),
},
},
)

redirect_domain = environ.get("REDIRECT_DOMAIN", "https://captn.ai")
logged_in_message = "I have successfully logged in"
redirect_uri = f"{redirect_domain}/chat/{chat_uuid}?msg={logged_in_message}"
return RedirectResponse(redirect_uri)


async def get_user(user_id: Union[int, str]) -> Any:
wasp_db_url = await get_db_url(db_name="waspdb")
async with get_db_connection(db_url=wasp_db_url) as db:
Expand Down
12 changes: 7 additions & 5 deletions google_sheets/db_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ async def get_db_connection(

async def get_db_url(db_name: str) -> str:
curr_db_url = environ.get("DATABASE_URL")
wasp_db_name = environ.get("WASP_DB_NAME", db_name)
wasp_db_url = curr_db_url.replace(curr_db_url.split("/")[-1], wasp_db_name) # type: ignore[union-attr]
if "connect_timeout" not in wasp_db_url:
wasp_db_url += "?connect_timeout=60"
return wasp_db_url
if db_name == "waspdb":
db_name = environ.get("WASP_DB_NAME", db_name)
db_url = curr_db_url.replace(curr_db_url.split("/")[-1], db_name) # type: ignore[union-attr]
if "connect_timeout" not in db_url:
db_url += "?connect_timeout=60"

return db_url

0 comments on commit e351e15

Please sign in to comment.