Skip to content

Commit

Permalink
Merge pull request #452 from supertokens/feat/twitter-provider
Browse files Browse the repository at this point in the history
feat: Add Twitter provider
  • Loading branch information
rishabhpoddar authored Sep 28, 2023
2 parents 2bdd8b1 + 75615ef commit 4eab8c8
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

- Add Twitter provider for thirdparty login
- Add `Cache-Control` header for jwks endpoint `/jwt/jwks.json`
- Add `validity_in_secs` to the return value of overridable `get_jwks` recipe function.
- This can be used to control the `Cache-Control` header mentioned above.
Expand Down
4 changes: 3 additions & 1 deletion supertokens_python/recipe/thirdparty/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ async def sign_in_up_post(
user_context=user_context,
)

if user_info.email is None and not provider.config.require_email:
if user_info.email is None and provider.config.require_email is False:
# We don't expect to get an email from this provider.
# So we generate a fake one
if provider.config.generate_fake_email is not None:
user_info.email = UserInfoEmail(
email=await provider.config.generate_fake_email(
Expand Down
10 changes: 5 additions & 5 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
client_secret: Optional[str] = None,
client_type: Optional[str] = None,
scope: Optional[List[str]] = None,
force_pkce: bool = False,
force_pkce: Optional[bool] = None,
additional_config: Optional[Dict[str, Any]] = None,
):
self.client_id = client_id
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
require_email: Optional[bool] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(
client_secret: Optional[str] = None,
client_type: Optional[str] = None,
scope: Optional[List[str]] = None,
force_pkce: bool = False,
force_pkce: Optional[bool] = None,
additional_config: Optional[Dict[str, Any]] = None,
# CommonProviderConfig:
third_party_id: str = "temp",
Expand All @@ -240,7 +240,7 @@ def __init__(
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
require_email: Optional[bool] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Expand Down Expand Up @@ -303,7 +303,7 @@ def __init__(
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
require_email: Optional[bool] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .google_workspaces import GoogleWorkspaces
from .google import Google
from .linkedin import Linkedin
from .twitter import Twitter
from .okta import Okta
from .custom import NewProvider
from .utils import do_get_request
Expand Down Expand Up @@ -82,11 +83,7 @@ def merge_config(
if config_from_core.oidc_discovery_endpoint is None
else config_from_core.oidc_discovery_endpoint
),
require_email=(
config_from_static.require_email
if config_from_core.require_email is None
else config_from_core.require_email
),
require_email=config_from_static.require_email,
user_info_map=config_from_static.user_info_map,
generate_fake_email=config_from_static.generate_fake_email,
validate_id_token_payload=config_from_static.validate_id_token_payload,
Expand Down Expand Up @@ -206,6 +203,8 @@ def create_provider(provider_input: ProviderInput) -> Provider:
return Okta(provider_input)
if provider_input.config.third_party_id.startswith("linkedin"):
return Linkedin(provider_input)
if provider_input.config.third_party_id.startswith("twitter"):
return Twitter(provider_input)
if provider_input.config.third_party_id.startswith("boxy-saml"):
return BoxySAML(provider_input)

Expand Down
7 changes: 6 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
do_get_request,
do_post_request,
get_actual_client_id_from_development_client_id,
is_using_oauth_development_client_id,
is_using_oauth_development_client_id, DEV_KEY_IDENTIFIER, DEV_OAUTH_CLIENT_IDS,
)

from ..types import RawUserInfoFromProvider, UserInfo, UserInfoEmail
Expand Down Expand Up @@ -180,6 +180,11 @@ def merge_into_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]
return res


def is_using_development_client_id(client_id):
return client_id.startswith(DEV_KEY_IDENTIFIER) or client_id in DEV_OAUTH_CLIENT_IDS



class GenericProvider(Provider):
def __init__(self, provider_config: ProviderConfig):
self.input_config = input_config = self._normalize_input(provider_config)
Expand Down
113 changes: 113 additions & 0 deletions supertokens_python/recipe/thirdparty/providers/twitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from base64 import b64encode
from typing import Any, Dict, Optional
from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo
from supertokens_python.recipe.thirdparty.providers.utils import do_post_request, DEV_OAUTH_REDIRECT_URL, \
get_actual_client_id_from_development_client_id
from ..provider import (
Provider,
ProviderConfigForClient,
ProviderInput,
UserFields,
UserInfoMap,
)

from .custom import (
GenericProvider,
NewProvider, is_using_development_client_id,
)


class TwitterImpl(GenericProvider):
async def get_config_for_client_type(
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClient:
config = await super().get_config_for_client_type(client_type, user_context)

if config.scope is None:
config.scope = ["users.read", "tweet.read"]

if config.force_pkce is None:
config.force_pkce = True

return config

async def exchange_auth_code_for_oauth_tokens(
self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]
) -> Dict[str, Any]:

client_id = self.config.client_id
redirect_uri = redirect_uri_info.redirect_uri_on_provider_dashboard

# We need to do this because we don't call the original implementation
# Transformation needed for dev keys BEGIN
if is_using_development_client_id(self.config.client_id):
client_id = get_actual_client_id_from_development_client_id(self.config.client_id)
redirect_uri = DEV_OAUTH_REDIRECT_URL
# Transformation needed for dev keys END

credentials = client_id + ":" + (self.config.client_secret or "")
auth_token = b64encode(credentials.encode()).decode()

twitter_oauth_tokens_params: Dict[str, Any] = {
"grant_type": "authorization_code",
"client_id": client_id,
"code_verifier": redirect_uri_info.pkce_code_verifier,
"redirect_uri": redirect_uri,
"code": redirect_uri_info.redirect_uri_query_params["code"],
}

twitter_oauth_tokens_params = {
**twitter_oauth_tokens_params,
**(self.config.token_endpoint_body_params or {}),
}

assert self.config.token_endpoint is not None

return await do_post_request(
self.config.token_endpoint,
body_params=twitter_oauth_tokens_params,
headers={"Authorization": f"Basic {auth_token}"},
)


def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Twitter"

if input.config.authorization_endpoint is None:
input.config.authorization_endpoint = "https://twitter.com/i/oauth2/authorize"

if input.config.token_endpoint is None:
input.config.token_endpoint = "https://api.twitter.com/2/oauth2/token"

if input.config.user_info_endpoint is None:
input.config.user_info_endpoint = "https://api.twitter.com/2/users/me"

if input.config.require_email is None:
input.config.require_email = False

if input.config.user_info_map is None:
input.config.user_info_map = UserInfoMap(UserFields(), UserFields())

if input.config.user_info_map.from_user_info_api is None:
input.config.user_info_map.from_user_info_api = UserFields()

if input.config.user_info_map.from_user_info_api.user_id is None:
input.config.user_info_map.from_user_info_api.user_id = "data.id"

return NewProvider(input, TwitterImpl)

0 comments on commit 4eab8c8

Please sign in to comment.