Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Twitter provider #452

Merged
merged 6 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

- Add Twitter provider for thirdparty login
- Add `Cache-Control` header for jwks endpoint `/jwt/jwks.json`
- Add `validity_in_secs` to the return value of overridable `get_jwks` recipe function.
- This can be used to control the `Cache-Control` header mentioned above.
Expand Down
4 changes: 3 additions & 1 deletion supertokens_python/recipe/thirdparty/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ async def sign_in_up_post(
user_context=user_context,
)

if user_info.email is None and not provider.config.require_email:
if user_info.email is None and provider.config.require_email is False:
# We don't expect to get an email from this provider.
# So we generate a fake one
if provider.config.generate_fake_email is not None:
user_info.email = UserInfoEmail(
email=await provider.config.generate_fake_email(
Expand Down
10 changes: 5 additions & 5 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
client_secret: Optional[str] = None,
client_type: Optional[str] = None,
scope: Optional[List[str]] = None,
force_pkce: bool = False,
force_pkce: Optional[bool] = None,
additional_config: Optional[Dict[str, Any]] = None,
):
self.client_id = client_id
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
require_email: Optional[bool] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(
client_secret: Optional[str] = None,
client_type: Optional[str] = None,
scope: Optional[List[str]] = None,
force_pkce: bool = False,
force_pkce: Optional[bool] = None,
additional_config: Optional[Dict[str, Any]] = None,
# CommonProviderConfig:
third_party_id: str = "temp",
Expand All @@ -240,7 +240,7 @@ def __init__(
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
require_email: Optional[bool] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Expand Down Expand Up @@ -303,7 +303,7 @@ def __init__(
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
require_email: Optional[bool] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .google_workspaces import GoogleWorkspaces
from .google import Google
from .linkedin import Linkedin
from .twitter import Twitter
from .okta import Okta
from .custom import NewProvider
from .utils import do_get_request
Expand Down Expand Up @@ -82,11 +83,7 @@ def merge_config(
if config_from_core.oidc_discovery_endpoint is None
else config_from_core.oidc_discovery_endpoint
),
require_email=(
config_from_static.require_email
if config_from_core.require_email is None
else config_from_core.require_email
),
require_email=config_from_static.require_email,
user_info_map=config_from_static.user_info_map,
generate_fake_email=config_from_static.generate_fake_email,
validate_id_token_payload=config_from_static.validate_id_token_payload,
Expand Down Expand Up @@ -206,6 +203,8 @@ def create_provider(provider_input: ProviderInput) -> Provider:
return Okta(provider_input)
if provider_input.config.third_party_id.startswith("linkedin"):
return Linkedin(provider_input)
if provider_input.config.third_party_id.startswith("twitter"):
return Twitter(provider_input)
if provider_input.config.third_party_id.startswith("boxy-saml"):
return BoxySAML(provider_input)

Expand Down
7 changes: 6 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
do_get_request,
do_post_request,
get_actual_client_id_from_development_client_id,
is_using_oauth_development_client_id,
is_using_oauth_development_client_id, DEV_KEY_IDENTIFIER, DEV_OAUTH_CLIENT_IDS,
)

from ..types import RawUserInfoFromProvider, UserInfo, UserInfoEmail
Expand Down Expand Up @@ -180,6 +180,11 @@ def merge_into_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]
return res


def is_using_development_client_id(client_id):
return client_id.startswith(DEV_KEY_IDENTIFIER) or client_id in DEV_OAUTH_CLIENT_IDS



class GenericProvider(Provider):
def __init__(self, provider_config: ProviderConfig):
self.input_config = input_config = self._normalize_input(provider_config)
Expand Down
113 changes: 113 additions & 0 deletions supertokens_python/recipe/thirdparty/providers/twitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from base64 import b64encode
from typing import Any, Dict, Optional
from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo
from supertokens_python.recipe.thirdparty.providers.utils import do_post_request, DEV_OAUTH_REDIRECT_URL, \
get_actual_client_id_from_development_client_id
from ..provider import (
Provider,
ProviderConfigForClient,
ProviderInput,
UserFields,
UserInfoMap,
)

from .custom import (
GenericProvider,
NewProvider, is_using_development_client_id,
)


class TwitterImpl(GenericProvider):
async def get_config_for_client_type(
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClient:
config = await super().get_config_for_client_type(client_type, user_context)

if config.scope is None:
config.scope = ["users.read", "tweet.read"]

if config.force_pkce is None:
config.force_pkce = True

return config

async def exchange_auth_code_for_oauth_tokens(
self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]
) -> Dict[str, Any]:

client_id = self.config.client_id
redirect_uri = redirect_uri_info.redirect_uri_on_provider_dashboard

# We need to do this because we don't call the original implementation
# Transformation needed for dev keys BEGIN
if is_using_development_client_id(self.config.client_id):
client_id = get_actual_client_id_from_development_client_id(self.config.client_id)
redirect_uri = DEV_OAUTH_REDIRECT_URL
# Transformation needed for dev keys END

credentials = client_id + ":" + (self.config.client_secret or "")
auth_token = b64encode(credentials.encode()).decode()

twitter_oauth_tokens_params: Dict[str, Any] = {
"grant_type": "authorization_code",
"client_id": client_id,
"code_verifier": redirect_uri_info.pkce_code_verifier,
"redirect_uri": redirect_uri,
"code": redirect_uri_info.redirect_uri_query_params["code"],
}

twitter_oauth_tokens_params = {
**twitter_oauth_tokens_params,
**(self.config.token_endpoint_body_params or {}),
}

assert self.config.token_endpoint is not None

return await do_post_request(
self.config.token_endpoint,
body_params=twitter_oauth_tokens_params,
headers={"Authorization": f"Basic {auth_token}"},
)


def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Twitter"

if input.config.authorization_endpoint is None:
input.config.authorization_endpoint = "https://twitter.com/i/oauth2/authorize"

if input.config.token_endpoint is None:
input.config.token_endpoint = "https://api.twitter.com/2/oauth2/token"

if input.config.user_info_endpoint is None:
input.config.user_info_endpoint = "https://api.twitter.com/2/users/me"

if input.config.require_email is None:
input.config.require_email = False

if input.config.user_info_map is None:
input.config.user_info_map = UserInfoMap(UserFields(), UserFields())

if input.config.user_info_map.from_user_info_api is None:
input.config.user_info_map.from_user_info_api = UserFields()

if input.config.user_info_map.from_user_info_api.user_id is None:
input.config.user_info_map.from_user_info_api.user_id = "data.id"

return NewProvider(input, TwitterImpl)
Loading