From cdf4cb4a3d372af26a37c58647aed2cf01cce8bd Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 11:06:35 +0200 Subject: [PATCH] Switch back to old structure --- backend/chainlit/__init__.py | 199 +---- backend/chainlit/auth.py | 2 +- backend/chainlit/callbacks.py | 38 +- backend/chainlit/chat_context.py | 2 - backend/chainlit/config.py | 2 +- backend/chainlit/data/acl.py | 1 - backend/chainlit/data/sql_alchemy.py | 17 +- backend/chainlit/oauth/__init__.py | 0 .../chainlit/oauth/auth0_oauth_provider.py | 69 -- .../oauth/aws_cognito_oauth_provider.py | 72 -- .../oauth/azure_ad_hubrid_oauth_provider.py | 92 --- .../chainlit/oauth/azure_ad_oauth_provider.py | 89 --- .../chainlit/oauth/descope_oauth_provider.py | 61 -- backend/chainlit/oauth/github.py | 63 -- .../chainlit/oauth/gitlab_oauth_provider.py | 67 -- backend/chainlit/oauth/google.py | 58 -- backend/chainlit/oauth/oauth_provider.py | 22 - backend/chainlit/oauth/okta_oauth_provider.py | 78 -- backend/chainlit/oauth/providers.py | 42 -- backend/chainlit/oauth_providers.py | 683 +++++++++++++++++- backend/chainlit/server.py | 10 - 21 files changed, 696 insertions(+), 971 deletions(-) delete mode 100644 backend/chainlit/oauth/__init__.py delete mode 100644 backend/chainlit/oauth/auth0_oauth_provider.py delete mode 100644 backend/chainlit/oauth/aws_cognito_oauth_provider.py delete mode 100644 backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py delete mode 100644 backend/chainlit/oauth/azure_ad_oauth_provider.py delete mode 100644 backend/chainlit/oauth/descope_oauth_provider.py delete mode 100644 backend/chainlit/oauth/github.py delete mode 100644 backend/chainlit/oauth/gitlab_oauth_provider.py delete mode 100644 backend/chainlit/oauth/google.py delete mode 100644 backend/chainlit/oauth/oauth_provider.py delete mode 100644 backend/chainlit/oauth/okta_oauth_provider.py delete mode 100644 backend/chainlit/oauth/providers.py diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 71071d7e52..0506ef38f3 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -12,7 +12,7 @@ logger.info("Loaded .env file") import asyncio -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict import chainlit.input_widget as input_widget from chainlit.action import Action @@ -41,7 +41,6 @@ ErrorMessage, Message, ) -from chainlit.oauth.providers import get_configured_oauth_providers from chainlit.step import Step, step from chainlit.sync import make_async, run_sync from chainlit.types import AudioChunk, ChatProfile, Starter @@ -76,202 +75,6 @@ from chainlit.langchain.callbacks import ( AsyncLangchainCallbackHandler, LangchainCallbackHandler, - config.code.oauth_callback = wrap_user_function(func) - return func - - -@trace -def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable: - """ - A decorator to authenticate the user via custom token validation. - - Args: - func (Callable[[str, str, Dict[str, str], User], Optional[User]]): The authentication callback to execute. - - Returns: - Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback. - """ - - if len(get_configured_oauth_providers()) == 0: - raise ValueError( - "You must set the environment variable for at least one oauth provider to use oauth authentication." - ) - - config.code.custom_authenticate_user = wrap_user_function(func) - return func - - -@trace -def custom_oauth_provider(func: Callable[[str], Awaitable[User]]) -> Callable: - """ - A decorator to integrate custom OAuth provider logic for user authentication. - - Args: - func (Callable[[str, str, Dict[str, str], User], Optional[User]]): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider. - - Returns: - Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated callback function that handles authentication via the custom OAuth provider. - """ - - if len(get_configured_oauth_providers()) == 0: - raise ValueError( - "You must set the environment variable for at least one oauth provider to use oauth authentication." - ) - - config.code.custom_oauth_provider = wrap_user_function(func) - return func - - -@trace -def on_logout(func: Callable[[Request, Response], Any]) -> Callable: - """ - Function called when the user logs out. - Takes the FastAPI request and response as parameters. - """ - - config.code.on_logout = wrap_user_function(func) - return func - - -@trace -def on_message(func: Callable) -> Callable: - """ - Framework agnostic decorator to react to messages coming from the UI. - The decorated function is called every time a new message is received. - - Args: - func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message. - - Returns: - Callable[[str], Any]: The decorated on_message function. - """ - - async def with_parent_id(message: Message): - async with Step(name="on_message", type="run", parent_id=message.id) as s: - s.input = message.content - if len(inspect.signature(func).parameters) > 0: - await func(message) - else: - await func() - - config.code.on_message = wrap_user_function(with_parent_id) - return func - - -@trace -def on_chat_start(func: Callable) -> Callable: - """ - Hook to react to the user websocket connection event. - - Args: - func (Callable[], Any]): The connection hook to execute. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_chat_start = wrap_user_function( - step(func, name="on_chat_start", type="run"), with_task=True - ) - return func - - -@trace -def on_chat_resume(func: Callable[[ThreadDict], Any]) -> Callable: - """ - Hook to react to resume websocket connection event. - - Args: - func (Callable[], Any]): The connection hook to execute. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_chat_resume = wrap_user_function(func, with_task=True) - return func - - -@trace -def set_chat_profiles( - func: Callable[[Optional["User"]], List["ChatProfile"]] -) -> Callable: - """ - Programmatic declaration of the available chat profiles (can depend on the User from the session if authentication is setup). - - Args: - func (Callable[[Optional["User"]], List["ChatProfile"]]): The function declaring the chat profiles. - - Returns: - Callable[[Optional["User"]], List["ChatProfile"]]: The decorated function. - """ - - config.code.set_chat_profiles = wrap_user_function(func) - return func - - -@trace -def set_starters(func: Callable[[Optional["User"]], List["Starter"]]) -> Callable: - """ - Programmatic declaration of the available starter (can depend on the User from the session if authentication is setup). - - Args: - func (Callable[[Optional["User"]], List["Starter"]]): The function declaring the starters. - - Returns: - Callable[[Optional["User"]], List["Starter"]]: The decorated function. - """ - - config.code.set_starters = wrap_user_function(func) - return func - - -@trace -def on_chat_end(func: Callable) -> Callable: - """ - Hook to react to the user websocket disconnect event. - - Args: - func (Callable[], Any]): The disconnect hook to execute. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_chat_end = wrap_user_function(func, with_task=True) - return func - - -@trace -def on_audio_chunk(func: Callable) -> Callable: - """ - Hook to react to the audio chunks being sent. - - Args: - chunk (AudioChunk): The audio chunk being sent. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_audio_chunk = wrap_user_function(func, with_task=False) - return func - - -@trace -def on_audio_end(func: Callable) -> Callable: - """ - Hook to react to the audio stream ending. This is called after the last audio chunk is sent. - - Args: - elements ([List[Element]): The files that were uploaded before starting the audio stream (if any). - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_audio_end = wrap_user_function( - step(func, name="on_audio_end", type="run"), with_task=True ) from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler from chainlit.mistralai import instrument_mistralai diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index a4b1b326a7..9bd2073b55 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -5,7 +5,7 @@ import jwt from chainlit.config import config from chainlit.data import get_data_layer -from chainlit.oauth.providers import get_configured_oauth_providers +from chainlit.oauth_providers import get_configured_oauth_providers from chainlit.user import User from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer diff --git a/backend/chainlit/callbacks.py b/backend/chainlit/callbacks.py index b559049d7b..c9ac5fe43c 100644 --- a/backend/chainlit/callbacks.py +++ b/backend/chainlit/callbacks.py @@ -4,7 +4,11 @@ from chainlit.action import Action from chainlit.config import config from chainlit.message import Message -from chainlit.oauth_providers import get_configured_oauth_providers +from chainlit.oauth_providers import ( + OAuthProvider, + get_configured_oauth_providers, + providers, +) from chainlit.step import Step, step from chainlit.telemetry import trace from chainlit.types import ChatProfile, Starter, ThreadDict @@ -87,6 +91,38 @@ async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, return func +@trace +def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable: + """ + A decorator to authenticate the user via custom token validation. + + Args: + func (Callable[[str], Awaitable[User]]): The authentication callback to execute. + + Returns: + Callable[[str], Awaitable[User]]: The decorated authentication callback. + """ + + if len(get_configured_oauth_providers()) == 0: + raise ValueError( + "You must set the environment variable for at least one oauth provider to use oauth authentication." + ) + + config.code.custom_authenticate_user = wrap_user_function(func) + return func + + +def custom_oauth_provider(func: Callable[[], OAuthProvider]) -> None: + """ + A decorator to integrate custom OAuth provider logic for user authentication. + + Args: + func (Callable[[], OAuthProvider): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider. + """ + + providers.append(func()) + + @trace def on_logout(func: Callable[[Request, Response], Any]) -> Callable: """ diff --git a/backend/chainlit/chat_context.py b/backend/chainlit/chat_context.py index 5f7215ba56..0362b4fd15 100644 --- a/backend/chainlit/chat_context.py +++ b/backend/chainlit/chat_context.py @@ -25,10 +25,8 @@ def add(self, message: "Message"): if context.session.id not in chat_contexts: chat_contexts[context.session.id] = [] - if message not in chat_contexts[context.session.id]: chat_contexts[context.session.id].append(message) - return message def remove(self, message: "Message") -> bool: diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 22fc0d6ec9..af54857d99 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -19,7 +19,7 @@ import tomli from chainlit.logger import logger -from chainlit.oauth.oauth_provider import OAuthProvider +from chainlit.oauth_providers import OAuthProvider from chainlit.translations import lint_translation_json from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin diff --git a/backend/chainlit/data/acl.py b/backend/chainlit/data/acl.py index 65c040170a..81580109be 100644 --- a/backend/chainlit/data/acl.py +++ b/backend/chainlit/data/acl.py @@ -8,7 +8,6 @@ async def is_thread_author(username: str, thread_id: str): raise HTTPException(status_code=400, detail="Data layer not initialized") thread_author = await data_layer.get_thread_author(thread_id) - if not thread_author: raise HTTPException(status_code=404, detail="Thread not found") diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index 9a4f65b411..cbc0a01155 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -373,12 +373,18 @@ async def delete_feedback(self, feedback_id: str) -> bool: return True ###### Elements ###### - async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: + async def get_element( + self, thread_id: str, element_id: str + ) -> Optional["ElementDict"]: if self.show_logger: - logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}") + logger.info( + f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}" + ) query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id""" parameters = {"thread_id": thread_id, "element_id": element_id} - element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters) + element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql( + query=query, parameters=parameters + ) if isinstance(element, list) and element: element_dict: Dict[str, Any] = element[0] return ElementDict( @@ -396,7 +402,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen autoPlay=element_dict.get("autoPlay"), playerConfig=element_dict.get("playerConfig"), forId=element_dict.get("forId"), - mime=element_dict.get("mime") + mime=element_dict.get("mime"), ) else: return None @@ -607,7 +613,8 @@ async def get_all_user_threads( tags=step_feedback.get("step_tags"), input=( step_feedback.get("step_input", "") - if step_feedback.get("step_showinput") not in [None, "false"] + if step_feedback.get("step_showinput") + not in [None, "false"] else None ), output=step_feedback.get("step_output", ""), diff --git a/backend/chainlit/oauth/__init__.py b/backend/chainlit/oauth/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/backend/chainlit/oauth/auth0_oauth_provider.py b/backend/chainlit/oauth/auth0_oauth_provider.py deleted file mode 100644 index 91e70fb376..0000000000 --- a/backend/chainlit/oauth/auth0_oauth_provider.py +++ /dev/null @@ -1,69 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class Auth0OAuthProvider(OAuthProvider): - id = "auth0" - env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" - self.original_domain = ( - f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" - if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") - else self.domain - ) - - self.authorize_url = f"{self.domain}/authorize" - - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.original_domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.original_domain}/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - auth0_user = response.json() - user = User( - identifier=auth0_user.get("email"), - metadata={ - "image": auth0_user.get("picture", ""), - "provider": "auth0", - }, - ) - return (auth0_user, user) diff --git a/backend/chainlit/oauth/aws_cognito_oauth_provider.py b/backend/chainlit/oauth/aws_cognito_oauth_provider.py deleted file mode 100644 index d5d286d134..0000000000 --- a/backend/chainlit/oauth/aws_cognito_oauth_provider.py +++ /dev/null @@ -1,72 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class AWSCognitoOAuthProvider(OAuthProvider): - id = "aws-cognito" - env = [ - "OAUTH_COGNITO_CLIENT_ID", - "OAUTH_COGNITO_CLIENT_SECRET", - "OAUTH_COGNITO_DOMAIN", - ] - authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login" - token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "client_id": self.client_id, - "scope": "openid profile email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - user_info_url = ( - f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo" - ) - async with httpx.AsyncClient() as client: - response = await client.get( - user_info_url, - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - cognito_user = response.json() - - # Customize user metadata as needed - user = User( - identifier=cognito_user["email"], - metadata={ - "image": cognito_user.get("picture", ""), - "provider": "aws-cognito", - }, - ) - return (cognito_user, user) diff --git a/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py b/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py deleted file mode 100644 index da935d1543..0000000000 --- a/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py +++ /dev/null @@ -1,92 +0,0 @@ -import base64 -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.secret import random_secret -from chainlit.user import User -from fastapi import HTTPException - - -class AzureADHybridOAuthProvider(OAuthProvider): - id = "azure-ad-hybrid" - env = [ - "OAUTH_AZURE_AD_HYBRID_CLIENT_ID", - "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET", - "OAUTH_AZURE_AD_HYBRID_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET") - nonce = random_secret(16) - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), - "response_type": "code id_token", - "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", - "response_mode": "form_post", - "nonce": nonce, - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - azure_user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - azure_user["image"] = ( - f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - ) - except Exception as e: - # Ignore errors getting the photo - pass - - user = User( - identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, - ) - return (azure_user, user) diff --git a/backend/chainlit/oauth/azure_ad_oauth_provider.py b/backend/chainlit/oauth/azure_ad_oauth_provider.py deleted file mode 100644 index 2ed32bbe28..0000000000 --- a/backend/chainlit/oauth/azure_ad_oauth_provider.py +++ /dev/null @@ -1,89 +0,0 @@ -import base64 -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class AzureADOAuthProvider(OAuthProvider): - id = "azure-ad" - env = [ - "OAUTH_AZURE_AD_CLIENT_ID", - "OAUTH_AZURE_AD_CLIENT_SECRET", - "OAUTH_AZURE_AD_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), - "response_type": "code", - "scope": "https://graph.microsoft.com/User.Read", - "response_mode": "query", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - azure_user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - azure_user["image"] = ( - f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - ) - except Exception as e: - # Ignore errors getting the photo - pass - - user = User( - identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, - ) - return (azure_user, user) diff --git a/backend/chainlit/oauth/descope_oauth_provider.py b/backend/chainlit/oauth/descope_oauth_provider.py deleted file mode 100644 index 08c96da0d5..0000000000 --- a/backend/chainlit/oauth/descope_oauth_provider.py +++ /dev/null @@ -1,61 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User - - -class DescopeOAuthProvider(OAuthProvider): - id = "descope" - env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] - # Ensure that the domain does not have a trailing slash - domain = f"https://api.descope.com/oauth2/v1" - - authorize_url = f"{domain}/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} - ) - response.raise_for_status() # This will raise an exception for 4xx/5xx responses - descope_user = response.json() - - user = User( - identifier=descope_user.get("email"), - metadata={"image": "", "provider": "descope"}, - ) - return (descope_user, user) diff --git a/backend/chainlit/oauth/github.py b/backend/chainlit/oauth/github.py deleted file mode 100644 index 5ab0d72059..0000000000 --- a/backend/chainlit/oauth/github.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -import urllib.parse - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class GithubOAuthProvider(OAuthProvider): - id = "github" - env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] - authorize_url = "https://github.com/login/oauth/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") - self.authorize_params = { - "scope": "user:email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://github.com/login/oauth/access_token", - data=payload, - ) - response.raise_for_status() - content = urllib.parse.parse_qs(response.text) - token = content.get("access_token", [""])[0] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - user_response = await client.get( - "https://api.github.com/user", - headers={"Authorization": f"token {token}"}, - ) - user_response.raise_for_status() - github_user = user_response.json() - - emails_response = await client.get( - "https://api.github.com/user/emails", - headers={"Authorization": f"token {token}"}, - ) - emails_response.raise_for_status() - emails = emails_response.json() - - github_user.update({"emails": emails}) - user = User( - identifier=github_user["login"], - metadata={"image": github_user["avatar_url"], "provider": "github"}, - ) - return (github_user, user) diff --git a/backend/chainlit/oauth/gitlab_oauth_provider.py b/backend/chainlit/oauth/gitlab_oauth_provider.py deleted file mode 100644 index 22e993a77b..0000000000 --- a/backend/chainlit/oauth/gitlab_oauth_provider.py +++ /dev/null @@ -1,67 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class GitlabOAuthProvider(OAuthProvider): - id = "gitlab" - env = [ - "OAUTH_GITLAB_CLIENT_ID", - "OAUTH_GITLAB_CLIENT_SECRET", - "OAUTH_GITLAB_DOMAIN", - ] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}" - - self.authorize_url = f"{self.domain}/oauth/authorize" - - self.authorize_params = { - "scope": "openid profile email", - "response_type": "code", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - gitlab_user = response.json() - user = User( - identifier=gitlab_user.get("email"), - metadata={ - "image": gitlab_user.get("picture", ""), - "provider": "gitlab", - }, - ) - return (gitlab_user, user) diff --git a/backend/chainlit/oauth/google.py b/backend/chainlit/oauth/google.py deleted file mode 100644 index 0d4cc1cfea..0000000000 --- a/backend/chainlit/oauth/google.py +++ /dev/null @@ -1,58 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User - - -class GoogleOAuthProvider(OAuthProvider): - id = "google" - env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] - authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") - self.authorize_params = { - "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", - "response_type": "code", - "access_type": "offline", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://oauth2.googleapis.com/token", - data=payload, - ) - response.raise_for_status() - json = response.json() - token = json.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://www.googleapis.com/userinfo/v2/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - google_user = response.json() - user = User( - identifier=google_user["email"], - metadata={"image": google_user["picture"], "provider": "google"}, - ) - return (google_user, user) diff --git a/backend/chainlit/oauth/oauth_provider.py b/backend/chainlit/oauth/oauth_provider.py deleted file mode 100644 index 4db4d0d413..0000000000 --- a/backend/chainlit/oauth/oauth_provider.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -from typing import Dict, List, Tuple - -from chainlit.user import User - - -class OAuthProvider: - id: str - env: List[str] - client_id: str - client_secret: str - authorize_url: str - authorize_params: Dict[str, str] - - def is_configured(self): - return all([os.environ.get(env) for env in self.env]) - - async def get_token(self, code: str, url: str) -> str: - raise NotImplementedError() - - async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: - raise NotImplementedError() diff --git a/backend/chainlit/oauth/okta_oauth_provider.py b/backend/chainlit/oauth/okta_oauth_provider.py deleted file mode 100644 index a531ddadb0..0000000000 --- a/backend/chainlit/oauth/okta_oauth_provider.py +++ /dev/null @@ -1,78 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User - - -class OktaOAuthProvider(OAuthProvider): - id = "okta" - env = [ - "OAUTH_OKTA_CLIENT_ID", - "OAUTH_OKTA_CLIENT_SECRET", - "OAUTH_OKTA_DOMAIN", - ] - # Avoid trailing slash in domain if supplied - domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") - self.authorization_server_id = os.environ.get( - "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" - ) - self.authorize_url = ( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" - ) - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "response_mode": "query", - } - - def get_authorization_server_path(self): - if not self.authorization_server_id: - return "/default" - if self.authorization_server_id == "false": - return "" - return f"/{self.authorization_server_id}" - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", - data=payload, - ) - response.raise_for_status() - json_data = response.json() - - token = json_data.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - okta_user = response.json() - - user = User( - identifier=okta_user.get("email"), - metadata={"image": "", "provider": "okta"}, - ) - return (okta_user, user) diff --git a/backend/chainlit/oauth/providers.py b/backend/chainlit/oauth/providers.py deleted file mode 100644 index a9510e67a7..0000000000 --- a/backend/chainlit/oauth/providers.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Optional - -from chainlit.config import config -from chainlit.oauth.auth0_oauth_provider import Auth0OAuthProvider -from chainlit.oauth.aws_cognito_oauth_provider import AWSCognitoOAuthProvider -from chainlit.oauth.azure_ad_hubrid_oauth_provider import AzureADHybridOAuthProvider -from chainlit.oauth.azure_ad_oauth_provider import AzureADOAuthProvider -from chainlit.oauth.descope_oauth_provider import DescopeOAuthProvider -from chainlit.oauth.github import GithubOAuthProvider -from chainlit.oauth.gitlab_oauth_provider import GitlabOAuthProvider -from chainlit.oauth.google import GoogleOAuthProvider -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.oauth.okta_oauth_provider import OktaOAuthProvider - -custom_oauth = config.code.custom_oauth_provider -providers = ( - [ - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - AzureADHybridOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), - AWSCognitoOAuthProvider(), - GitlabOAuthProvider(), - ] - + [custom_oauth()] - if custom_oauth - else [] -) - - -def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: - for p in providers: - if p.id == provider: - return p - return None - - -def get_configured_oauth_providers(): - return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index c8d45bcc4c..fe019859b1 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -1,40 +1,645 @@ -import warnings - -warnings.warn( - "The 'oauth_providers' module is deprecated and will be removed in a future version. " - "Please use 'oauth' instead.", - DeprecationWarning, - stacklevel=2, -) - -from chainlit.oauth.auth0_oauth_provider import Auth0OAuthProvider -from chainlit.oauth.aws_cognito_oauth_provider import AWSCognitoOAuthProvider -from chainlit.oauth.azure_ad_hubrid_oauth_provider import AzureADHybridOAuthProvider -from chainlit.oauth.azure_ad_oauth_provider import AzureADOAuthProvider -from chainlit.oauth.descope_oauth_provider import DescopeOAuthProvider -from chainlit.oauth.github import GithubOAuthProvider -from chainlit.oauth.gitlab_oauth_provider import GitlabOAuthProvider -from chainlit.oauth.google import GoogleOAuthProvider -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.oauth.okta_oauth_provider import OktaOAuthProvider -from chainlit.oauth.providers import ( - get_configured_oauth_providers, - get_oauth_provider, - providers, -) - -__all__ = [ - "providers", - "get_oauth_provider", - "get_configured_oauth_providers", - "OAuthProvider", - "GithubOAuthProvider", - "GoogleOAuthProvider", - "AzureADOAuthProvider", - "AzureADHybridOAuthProvider", - "OktaOAuthProvider", - "Auth0OAuthProvider", - "DescopeOAuthProvider", - "AWSCognitoOAuthProvider", - "GitlabOAuthProvider", +import base64 +import os +import urllib.parse +from typing import Dict, List, Optional, Tuple + +import httpx +from chainlit.secret import random_secret +from chainlit.user import User +from fastapi import HTTPException + + +class OAuthProvider: + id: str + env: List[str] + client_id: str + client_secret: str + authorize_url: str + authorize_params: Dict[str, str] + + def is_configured(self): + return all([os.environ.get(env) for env in self.env]) + + async def get_token(self, code: str, url: str) -> str: + raise NotImplementedError() + + async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: + raise NotImplementedError() + + +class GithubOAuthProvider(OAuthProvider): + id = "github" + env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] + authorize_url = "https://github.com/login/oauth/authorize" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") + self.authorize_params = { + "scope": "user:email", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + } + async with httpx.AsyncClient() as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + data=payload, + ) + response.raise_for_status() + content = urllib.parse.parse_qs(response.text) + token = content.get("access_token", [""])[0] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + user_response = await client.get( + "https://api.github.com/user", + headers={"Authorization": f"token {token}"}, + ) + user_response.raise_for_status() + github_user = user_response.json() + + emails_response = await client.get( + "https://api.github.com/user/emails", + headers={"Authorization": f"token {token}"}, + ) + emails_response.raise_for_status() + emails = emails_response.json() + + github_user.update({"emails": emails}) + user = User( + identifier=github_user["login"], + metadata={"image": github_user["avatar_url"], "provider": "github"}, + ) + return (github_user, user) + + +class GoogleOAuthProvider(OAuthProvider): + id = "google" + env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] + authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") + self.authorize_params = { + "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", + "response_type": "code", + "access_type": "offline", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + "https://oauth2.googleapis.com/token", + data=payload, + ) + response.raise_for_status() + json = response.json() + token = json.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://www.googleapis.com/userinfo/v2/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + google_user = response.json() + user = User( + identifier=google_user["email"], + metadata={"image": google_user["picture"], "provider": "google"}, + ) + return (google_user, user) + + +class AzureADOAuthProvider(OAuthProvider): + id = "azure-ad" + env = [ + "OAUTH_AZURE_AD_CLIENT_ID", + "OAUTH_AZURE_AD_CLIENT_SECRET", + "OAUTH_AZURE_AD_TENANT_ID", + ] + authorize_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" + if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + token_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" + if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ) + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") + self.authorize_params = { + "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), + "response_type": "code", + "scope": "https://graph.microsoft.com/User.Read", + "response_mode": "query", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + azure_user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, + ) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + azure_user["image"] = ( + f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + ) + except Exception as e: + # Ignore errors getting the photo + pass + + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + ) + return (azure_user, user) + + +class AzureADHybridOAuthProvider(OAuthProvider): + id = "azure-ad-hybrid" + env = [ + "OAUTH_AZURE_AD_HYBRID_CLIENT_ID", + "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET", + "OAUTH_AZURE_AD_HYBRID_TENANT_ID", + ] + authorize_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize" + if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + token_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token" + if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ) + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET") + nonce = random_secret(16) + self.authorize_params = { + "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), + "response_type": "code id_token", + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", + "response_mode": "form_post", + "nonce": nonce, + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + azure_user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, + ) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + azure_user["image"] = ( + f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + ) + except Exception as e: + # Ignore errors getting the photo + pass + + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + ) + return (azure_user, user) + + +class OktaOAuthProvider(OAuthProvider): + id = "okta" + env = [ + "OAUTH_OKTA_CLIENT_ID", + "OAUTH_OKTA_CLIENT_SECRET", + "OAUTH_OKTA_DOMAIN", + ] + # Avoid trailing slash in domain if supplied + domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") + self.authorization_server_id = os.environ.get( + "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" + ) + self.authorize_url = ( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" + ) + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "response_mode": "query", + } + + def get_authorization_server_path(self): + if not self.authorization_server_id: + return "/default" + if self.authorization_server_id == "false": + return "" + return f"/{self.authorization_server_id}" + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", + data=payload, + ) + response.raise_for_status() + json_data = response.json() + + token = json_data.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + okta_user = response.json() + + user = User( + identifier=okta_user.get("email"), + metadata={"image": "", "provider": "okta"}, + ) + return (okta_user, user) + + +class Auth0OAuthProvider(OAuthProvider): + id = "auth0" + env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") + # Ensure that the domain does not have a trailing slash + self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" + self.original_domain = ( + f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" + if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") + else self.domain + ) + + self.authorize_url = f"{self.domain}/authorize" + + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "audience": f"{self.original_domain}/userinfo", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.original_domain}/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + auth0_user = response.json() + user = User( + identifier=auth0_user.get("email"), + metadata={ + "image": auth0_user.get("picture", ""), + "provider": "auth0", + }, + ) + return (auth0_user, user) + + +class DescopeOAuthProvider(OAuthProvider): + id = "descope" + env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] + # Ensure that the domain does not have a trailing slash + domain = f"https://api.descope.com/oauth2/v1" + + authorize_url = f"{domain}/authorize" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "audience": f"{self.domain}/userinfo", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} + ) + response.raise_for_status() # This will raise an exception for 4xx/5xx responses + descope_user = response.json() + + user = User( + identifier=descope_user.get("email"), + metadata={"image": "", "provider": "descope"}, + ) + return (descope_user, user) + + +class AWSCognitoOAuthProvider(OAuthProvider): + id = "aws-cognito" + env = [ + "OAUTH_COGNITO_CLIENT_ID", + "OAUTH_COGNITO_CLIENT_SECRET", + "OAUTH_COGNITO_DOMAIN", + ] + authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login" + token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET") + self.authorize_params = { + "response_type": "code", + "client_id": self.client_id, + "scope": "openid profile email", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + user_info_url = ( + f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo" + ) + async with httpx.AsyncClient() as client: + response = await client.get( + user_info_url, + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + cognito_user = response.json() + + # Customize user metadata as needed + user = User( + identifier=cognito_user["email"], + metadata={ + "image": cognito_user.get("picture", ""), + "provider": "aws-cognito", + }, + ) + return (cognito_user, user) + + +class GitlabOAuthProvider(OAuthProvider): + id = "gitlab" + env = [ + "OAUTH_GITLAB_CLIENT_ID", + "OAUTH_GITLAB_CLIENT_SECRET", + "OAUTH_GITLAB_DOMAIN", + ] + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET") + # Ensure that the domain does not have a trailing slash + self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}" + + self.authorize_url = f"{self.domain}/oauth/authorize" + + self.authorize_params = { + "scope": "openid profile email", + "response_type": "code", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/oauth/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + gitlab_user = response.json() + user = User( + identifier=gitlab_user.get("email"), + metadata={ + "image": gitlab_user.get("picture", ""), + "provider": "gitlab", + }, + ) + return (gitlab_user, user) + + +providers = [ + GithubOAuthProvider(), + GoogleOAuthProvider(), + AzureADOAuthProvider(), + AzureADHybridOAuthProvider(), + OktaOAuthProvider(), + Auth0OAuthProvider(), + DescopeOAuthProvider(), + AWSCognitoOAuthProvider(), + GitlabOAuthProvider(), ] + + +def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: + for p in providers: + if p.id == provider: + return p + return None + + +def get_configured_oauth_providers(): + return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index f88ac68057..597830ee43 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -6,16 +6,6 @@ import re import shutil import urllib.parse -from typing import Any, Optional, Union - -from chainlit.oauth.providers import get_oauth_provider -from chainlit.secret import random_secret - -mimetypes.add_type("application/javascript", ".js") -mimetypes.add_type("text/css", ".css") - -import asyncio -import os import webbrowser from contextlib import asynccontextmanager from pathlib import Path