diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1618ddd..f0d30cf 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -241,6 +241,7 @@ jobs: GITHUB_PASSWORD: ${{ secrets.GITHUB_TOKEN }} # DEVELOPER_TOKEN: ${{ secrets.DEVELOPER_TOKEN }} DOMAIN: ${{ github.ref_name == 'main' && vars.PROD_DOMAIN || vars.STAGING_DOMAIN }} + REDIRECT_DOMAIN: ${{ github.ref_name == 'main' && vars.PROD_REDIRECT_DOMAIN || vars.STAGING_REDIRECT_DOMAIN }} SSH_KEY: ${{ github.ref_name == 'main' && secrets.PROD_SSH_KEY || secrets.STAGING_SSH_KEY }} steps: - uses: actions/checkout@v3 # Don't change it to cheackout@v4. V4 is not working with container image. diff --git a/google-sheets-docker-compose.yaml b/google-sheets-docker-compose.yaml index 5aa106e..eeac56d 100644 --- a/google-sheets-docker-compose.yaml +++ b/google-sheets-docker-compose.yaml @@ -9,6 +9,7 @@ services: - "8001:8000" environment: - DOMAIN=${DOMAIN} + - REDIRECT_DOMAIN=${REDIRECT_DOMAIN} - DATABASE_URL=${DATABASE_URL} - CLIENT_SECRET=${CLIENT_SECRET} networks: diff --git a/google_sheets/app.py b/google_sheets/app.py index 3bcd57c..7eb5f15 100644 --- a/google_sheets/app.py +++ b/google_sheets/app.py @@ -1,11 +1,14 @@ import json import logging +import urllib.parse from os import environ from pathlib import Path -from typing import Annotated, Any, Dict, List, Union +from typing import Annotated, Any, Dict, List, Tuple, Union +import httpx from asyncify import asyncify -from fastapi import FastAPI, HTTPException, Query +from fastapi import FastAPI, HTTPException, Query, Request +from fastapi.responses import RedirectResponse from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from prisma.errors import RecordNotFoundError @@ -44,8 +47,118 @@ } +async def get_user_id_chat_uuid_from_chat_id( + chat_id: Union[int, str], +) -> Tuple[int, str]: + wasp_db_url = get_wasp_db_url() + async with get_db_connection(db_url=wasp_db_url) as db: + chat = await db.query_first( + f'SELECT * from "Chat" where id={chat_id}' # nosec: [B608] + ) + if not chat: + raise HTTPException(status_code=404, detail=f"chat {chat} not found") + user_id = chat["userId"] + chat_uuid = chat["uuid"] + return user_id, chat_uuid + + +async def is_authenticated_for_ads(user_id: int) -> bool: + await get_user(user_id=user_id) + async with get_db_connection() as db: + data = await db.gauth.find_unique(where={"user_id": user_id}) + + if not data: + return False + return True + + +# Route 1: Redirect to Google OAuth +@app.get("/login") +async def get_login_url( + request: Request, + user_id: int = Query(title="User ID"), + conv_id: int = Query(title="Conversation ID"), + force_new_login: bool = Query(title="Force new login", default=False), +) -> Dict[str, str]: + if not force_new_login: + is_authenticated = await is_authenticated_for_ads(user_id=user_id) + if is_authenticated: + return {"login_url": "User is already authenticated"} + + google_oauth_url = ( + f"{oauth2_settings['auth_uri']}?client_id={oauth2_settings['clientId']}" + f"&redirect_uri={oauth2_settings['redirectUri']}&response_type=code" + f"&scope={urllib.parse.quote_plus('email https://www.googleapis.com/auth/spreadsheets https://www.googleapis.com/auth/drive.metadata.readonly')}" + f"&access_type=offline&prompt=consent&state={conv_id}" + ) + markdown_url = f"To navigate Google Ads waters, I require access to your account. Please [click here]({google_oauth_url}) to grant permission." + return {"login_url": markdown_url} + + +@app.get("/login/success") +async def get_login_success() -> Dict[str, str]: + return {"login_success": "You have successfully logged in"} + + +# Route 2: Save user credentials/token to a JSON file +@app.get("/login/callback") +async def login_callback( + code: str = Query(title="Authorization Code"), state: str = Query(title="State") +) -> RedirectResponse: + chat_id = state + user_id, chat_uuid = await get_user_id_chat_uuid_from_chat_id(chat_id) + user = await get_user(user_id=user_id) + + token_request_data = { + "code": code, + "client_id": oauth2_settings["clientId"], + "client_secret": oauth2_settings["clientSecret"], + "redirect_uri": oauth2_settings["redirectUri"], + "grant_type": "authorization_code", + } + + async with httpx.AsyncClient() as client: + response = await client.post( + oauth2_settings["tokenUrl"], data=token_request_data + ) + + if response.status_code == 200: + token_data = response.json() + + async with httpx.AsyncClient() as client: + userinfo_response = await client.get( + "https://www.googleapis.com/oauth2/v2/userinfo", + headers={"Authorization": f"Bearer {token_data['access_token']}"}, + ) + + if userinfo_response.status_code == 200: + user_info = userinfo_response.json() + async with get_db_connection() as db: + await db.gauth.upsert( + where={"user_id": user["id"]}, + data={ + "create": { + "user_id": user["id"], + "creds": json.dumps(token_data), + "info": json.dumps(user_info), + }, + "update": { + "creds": json.dumps(token_data), + "info": json.dumps(user_info), + }, + }, + ) + + # redirect_domain = environ.get("REDIRECT_DOMAIN", "https://captn.ai") + # logged_in_message = "I have successfully logged in" + # redirect_uri = f"{redirect_domain}/chat/{chat_uuid}?msg={logged_in_message}" + # return RedirectResponse(redirect_uri) + # redirect to success page + return RedirectResponse(url=f"{base_url}/login/success") + + async def get_user(user_id: Union[int, str]) -> Any: - wasp_db_url = await get_wasp_db_url() + wasp_db_url = get_wasp_db_url() async with get_db_connection(db_url=wasp_db_url) as db: user = await db.query_first( f'SELECT * from "User" where id={user_id}' # nosec: [B608] diff --git a/google_sheets/db_helpers.py b/google_sheets/db_helpers.py index 928e4b7..d99f06a 100644 --- a/google_sheets/db_helpers.py +++ b/google_sheets/db_helpers.py @@ -25,7 +25,7 @@ async def get_db_connection( await db.disconnect() -async def get_wasp_db_url() -> str: +def get_wasp_db_url() -> str: curr_db_url = environ.get("DATABASE_URL") wasp_db_name = environ.get("WASP_DB_NAME", "waspdb") wasp_db_url = curr_db_url.replace(curr_db_url.split("/")[-1], wasp_db_name) # type: ignore[union-attr] diff --git a/migrations/20240620124134_initial/migration.sql b/migrations/20240620124134_initial/migration.sql new file mode 100644 index 0000000..46c7cfd --- /dev/null +++ b/migrations/20240620124134_initial/migration.sql @@ -0,0 +1,14 @@ +-- CreateTable +CREATE TABLE "GAuth" ( + "id" TEXT NOT NULL, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP(3) NOT NULL, + "user_id" INTEGER NOT NULL, + "creds" JSONB NOT NULL, + "info" JSONB NOT NULL, + + CONSTRAINT "GAuth_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "GAuth_user_id_key" ON "GAuth"("user_id"); diff --git a/migrations/migration_lock.toml b/migrations/migration_lock.toml new file mode 100644 index 0000000..99e4f20 --- /dev/null +++ b/migrations/migration_lock.toml @@ -0,0 +1,3 @@ +# Please do not edit this file manually +# It should be added in your version-control system (i.e. Git) +provider = "postgresql" diff --git a/scripts/deploy.sh b/scripts/deploy.sh index 3680fdf..8ba07e0 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -12,6 +12,7 @@ check_variable "TAG" check_variable "GITHUB_USERNAME" check_variable "GITHUB_PASSWORD" check_variable "DOMAIN" +check_variable "REDIRECT_DOMAIN" check_variable "CLIENT_SECRET" check_variable "DATABASE_URL" @@ -51,5 +52,5 @@ $ssh_command "docker system prune -f || echo 'No images to delete'" echo "INFO: starting docker containers" $ssh_command "export GITHUB_REPOSITORY='$GITHUB_REPOSITORY' TAG='$TAG' container_name='$container_name' \ - DATABASE_URL='$DATABASE_URL' CLIENT_SECRET='$CLIENT_SECRET' DOMAIN='$DOMAIN' \ + DATABASE_URL='$DATABASE_URL' CLIENT_SECRET='$CLIENT_SECRET' DOMAIN='$DOMAIN' REDIRECT_DOMAIN='$REDIRECT_DOMAIN' \ && docker compose -f google-sheets-docker-compose.yaml up -d" diff --git a/scripts/run_server.sh b/scripts/run_server.sh index 2e66e1a..f6eb568 100755 --- a/scripts/run_server.sh +++ b/scripts/run_server.sh @@ -2,8 +2,7 @@ cat <<< "$CLIENT_SECRET" > client_secret.json -# ToDo: Uncomment the below line once we have project specific migrations -# prisma migrate deploy +prisma migrate deploy prisma generate uvicorn google_sheets.app:app --workers 2 --host 0.0.0.0 --proxy-headers diff --git a/tests/app/test_app.py b/tests/app/test_app.py index af8e8ba..2b831e6 100644 --- a/tests/app/test_app.py +++ b/tests/app/test_app.py @@ -63,6 +63,122 @@ def test_openapi(self) -> None: } ], "paths": { + "/login": { + "get": { + "summary": "Get Login Url", + "operationId": "get_login_url_login_get", + "parameters": [ + { + "name": "user_id", + "in": "query", + "required": True, + "schema": {"type": "integer", "title": "User ID"}, + }, + { + "name": "conv_id", + "in": "query", + "required": True, + "schema": { + "type": "integer", + "title": "Conversation ID", + }, + }, + { + "name": "force_new_login", + "in": "query", + "required": False, + "schema": { + "type": "boolean", + "title": "Force new login", + "default": False, + }, + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": {"type": "string"}, + "title": "Response Get Login Url Login Get", + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + }, + "/login/success": { + "get": { + "summary": "Get Login Success", + "operationId": "get_login_success_login_success_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": {"type": "string"}, + "type": "object", + "title": "Response Get Login Success Login Success Get", + } + } + }, + } + }, + } + }, + "/login/callback": { + "get": { + "summary": "Login Callback", + "operationId": "login_callback_login_callback_get", + "parameters": [ + { + "name": "code", + "in": "query", + "required": True, + "schema": { + "type": "string", + "title": "Authorization Code", + }, + }, + { + "name": "state", + "in": "query", + "required": True, + "schema": {"type": "string", "title": "State"}, + }, + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + }, "/sheet": { "get": { "summary": "Get Sheet", diff --git a/tests/app/test_db_helpers.py b/tests/app/test_db_helpers.py new file mode 100644 index 0000000..2d56522 --- /dev/null +++ b/tests/app/test_db_helpers.py @@ -0,0 +1,17 @@ +import os +from unittest.mock import patch + +from google_sheets.db_helpers import get_wasp_db_url + + +def test_get_wasp_db_url() -> None: + root_db_url = "db://user:pass@localhost:5432" # pragma: allowlist secret + env_vars = { + "DATABASE_URL": f"{root_db_url}/dbname", + "WASP_DB_NAME": "waspdb", + } + with patch.dict(os.environ, env_vars, clear=True): + wasp_db_url = get_wasp_db_url() + excepted = f"{root_db_url}/waspdb?connect_timeout=60" + + assert wasp_db_url == excepted