diff --git a/ytmusicapi/auth/oauth/__init__.py b/ytmusicapi/auth/oauth/__init__.py index f84e63fe..fa1f6624 100644 --- a/ytmusicapi/auth/oauth/__init__.py +++ b/ytmusicapi/auth/oauth/__init__.py @@ -1,5 +1,4 @@ -from .base import OAuthToken from .credentials import OAuthCredentials -from .refreshing import RefreshingToken +from .token import OAuthToken, RefreshingToken -__all__ = ["OAuthCredentials", "RefreshingToken", "OAuthToken"] +__all__ = ["OAuthCredentials", "OAuthToken", "RefreshingToken"] diff --git a/ytmusicapi/auth/oauth/base.py b/ytmusicapi/auth/oauth/base.py deleted file mode 100644 index dbbe4c6d..00000000 --- a/ytmusicapi/auth/oauth/base.py +++ /dev/null @@ -1,148 +0,0 @@ -import json -import time -from abc import ABC -from typing import Mapping, Optional - -from requests.structures import CaseInsensitiveDict - -from .models import BaseTokenDict, Bearer, DefaultScope, RefreshableTokenDict - - -class Credentials: - """Base class representation of YouTubeMusicAPI OAuth Credentials""" - - client_id: str - client_secret: str - - def get_code(self) -> Mapping: - raise NotImplementedError() - - def token_from_code(self, device_code: str) -> RefreshableTokenDict: - raise NotImplementedError() - - def refresh_token(self, refresh_token: str) -> BaseTokenDict: - raise NotImplementedError() - - -class Token(ABC): - """Base class representation of the YouTubeMusicAPI OAuth token.""" - - _access_token: str - _refresh_token: str - _expires_in: int - _expires_at: int - _is_expiring: bool - - _scope: DefaultScope - _token_type: Bearer - - def __repr__(self) -> str: - """Readable version.""" - return f"{self.__class__.__name__}: {self.as_dict()}" - - def as_dict(self) -> RefreshableTokenDict: - """Returns dictionary containing underlying token values.""" - return { - "access_token": self.access_token, - "refresh_token": self.refresh_token, - "scope": self.scope, - "expires_at": self.expires_at, - "expires_in": self.expires_in, - "token_type": self.token_type, - } - - def as_json(self) -> str: - return json.dumps(self.as_dict()) - - def as_auth(self) -> str: - """Returns Authorization header ready str of token_type and access_token.""" - return f"{self.token_type} {self.access_token}" - - @property - def access_token(self) -> str: - return self._access_token - - @property - def refresh_token(self) -> str: - return self._refresh_token - - @property - def token_type(self) -> Bearer: - return self._token_type - - @property - def scope(self) -> DefaultScope: - return self._scope - - @property - def expires_at(self) -> int: - return self._expires_at - - @property - def expires_in(self) -> int: - return self._expires_in - - @property - def is_expiring(self) -> bool: - return self.expires_in < 60 - - -class OAuthToken(Token): - """Wrapper for an OAuth token implementing expiration methods.""" - - def __init__( - self, - access_token: str, - refresh_token: str, - scope: str, - token_type: str, - expires_at: Optional[int] = None, - expires_in: int = 0, - ): - """ - - :param access_token: active oauth key - :param refresh_token: access_token's matching oauth refresh string - :param scope: most likely 'https://www.googleapis.com/auth/youtube' - :param token_type: commonly 'Bearer' - :param expires_at: Optional. Unix epoch (seconds) of access token expiration. - :param expires_in: Optional. Seconds till expiration, assumes/calculates epoch of init. - - """ - # match baseclass attribute/property format - self._access_token = access_token - self._refresh_token = refresh_token - self._scope = scope - self._token_type = token_type - - # set/calculate token expiration using current epoch - self._expires_at: int = expires_at if expires_at else int(time.time()) + expires_in - self._expires_in: int = expires_in - - @staticmethod - def is_oauth(headers: CaseInsensitiveDict) -> bool: - oauth_structure = { - "access_token", - "expires_at", - "expires_in", - "token_type", - "refresh_token", - } - return all(key in headers for key in oauth_structure) - - def update(self, fresh_access: BaseTokenDict): - """ - Update access_token and expiration attributes with a BaseTokenDict inplace. - expires_at attribute set using current epoch, avoid expiration desync - by passing only recently requested tokens dicts or updating values to compensate. - """ - self._access_token = fresh_access["access_token"] - self._expires_at = int(time.time() + fresh_access["expires_in"]) - - @property - def expires_in(self) -> int: - return int(self.expires_at - time.time()) - - @property - def is_expiring(self) -> bool: - return self.expires_in < 60 diff --git a/ytmusicapi/auth/oauth/credentials.py b/ytmusicapi/auth/oauth/credentials.py index bdf12f76..80104fff 100644 --- a/ytmusicapi/auth/oauth/credentials.py +++ b/ytmusicapi/auth/oauth/credentials.py @@ -1,5 +1,6 @@ -import webbrowser -from typing import Dict, Optional +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Mapping, Optional import requests @@ -12,10 +13,29 @@ OAUTH_USER_AGENT, ) -from .base import Credentials, OAuthToken from .exceptions import BadOAuthClient, UnauthorizedOAuthClient from .models import AuthCodeDict, BaseTokenDict, RefreshableTokenDict -from .refreshing import RefreshingToken + + +@dataclass +class Credentials(ABC): + """Base class representation of YouTubeMusicAPI OAuth Credentials""" + + client_id: str + client_secret: str + + @abstractmethod + def get_code(self) -> Mapping: + """Method for obtaining a new user auth code. First step of token creation.""" + + @abstractmethod + def token_from_code(self, device_code: str) -> RefreshableTokenDict: + """Method for verifying user auth code and conversion into a FullTokenDict.""" + + @abstractmethod + def refresh_token(self, refresh_token: str) -> BaseTokenDict: + """Method for requesting a new access token for a given refresh_token. + Token must have been created by the same OAuth client.""" class OAuthCredentials(Credentials): @@ -90,25 +110,6 @@ def token_from_code(self, device_code: str) -> RefreshableTokenDict: ) return response.json() - def prompt_for_token(self, open_browser: bool = False, to_file: Optional[str] = None) -> RefreshingToken: - """ - Method for CLI token creation via user inputs. - - :param open_browser: Optional. Open browser to OAuth consent url automatically. (Default = False). - :param to_file: Optional. Path to store/sync json version of resulting token. (Default = None). - """ - - code = self.get_code() - url = f"{code['verification_url']}?user_code={code['user_code']}" - if open_browser: - webbrowser.open(url) - input(f"Go to {url}, finish the login flow and press Enter when done, Ctrl-C to abort") - raw_token = self.token_from_code(code["device_code"]) - ref_token = RefreshingToken(OAuthToken(**raw_token), credentials=self) - if to_file: - ref_token.local_cache = to_file - return ref_token - def refresh_token(self, refresh_token: str) -> BaseTokenDict: """ Method for requesting a new access token for a given refresh_token. diff --git a/ytmusicapi/auth/oauth/refreshing.py b/ytmusicapi/auth/oauth/refreshing.py deleted file mode 100644 index e6243b52..00000000 --- a/ytmusicapi/auth/oauth/refreshing.py +++ /dev/null @@ -1,87 +0,0 @@ -import json -import os -from typing import Optional - -from .base import Credentials, OAuthToken, Token -from .models import Bearer, RefreshableTokenDict - - -class RefreshingToken(Token): - """ - Compositional implementation of Token that automatically refreshes - an underlying OAuthToken when required (credential expiration <= 1 min) - upon access_token attribute access. - """ - - @classmethod - def from_file(cls, file_path: str, credentials: Credentials, sync=True): - """ - Initialize a refreshing token and underlying OAuthToken directly from a file. - - :param file_path: path to json containing token values - :param credentials: credentials used with token in file. - :param sync: Optional. Whether to pass the filepath into instance enabling file - contents to be updated upon refresh. (Default=True). - :return: RefreshingToken instance - :rtype: RefreshingToken - """ - - if os.path.isfile(file_path): - with open(file_path) as json_file: - file_pack = json.load(json_file) - - return cls(OAuthToken(**file_pack), credentials, file_path if sync else None) - - def __init__(self, token: OAuthToken, credentials: Credentials, local_cache: Optional[str] = None): - """ - :param token: Underlying Token being maintained. - :param credentials: OAuth client being used for refreshing. - :param local_cache: Optional. Path to json file where token values are stored. - When provided, file contents is updated upon token refresh. - """ - - self.token: OAuthToken = token #: internal token being used / refreshed / maintained - self.credentials = credentials #: credentials used for access_token refreshing - - #: protected/property attribute enables auto writing token - # values to new file location via setter - self._local_cache = local_cache - - @property - def token_type(self) -> Bearer: - return self.token.token_type - - @property - def local_cache(self) -> str | None: - return self._local_cache - - @local_cache.setter - def local_cache(self, path: str): - """Update attribute and dump token to new path.""" - self._local_cache = path - self.store_token() - - @property - def access_token(self) -> str: - if self.token.is_expiring: - fresh = self.credentials.refresh_token(self.token.refresh_token) - self.token.update(fresh) - self.store_token() - - return self.token.access_token - - def store_token(self, path: Optional[str] = None) -> None: - """ - Write token values to json file at specified path, defaulting to self.local_cache. - Operation does not update instance local_cache attribute. - Automatically called when local_cache is set post init. - """ - file_path = path if path else self.local_cache - - if file_path: - with open(file_path, encoding="utf8", mode="w") as file: - json.dump(self.token.as_dict(), file, indent=True) - - def as_dict(self) -> RefreshableTokenDict: - # override base class method with call to underlying token's method - return self.token.as_dict() diff --git a/ytmusicapi/auth/oauth/token.py b/ytmusicapi/auth/oauth/token.py new file mode 100644 index 00000000..563c3fff --- /dev/null +++ b/ytmusicapi/auth/oauth/token.py @@ -0,0 +1,155 @@ +import json +import os +import time +import webbrowser +from dataclasses import dataclass +from typing import Optional + +from requests.structures import CaseInsensitiveDict + +from ytmusicapi.auth.oauth.credentials import Credentials +from ytmusicapi.auth.oauth.models import BaseTokenDict, Bearer, DefaultScope, RefreshableTokenDict + + +@dataclass +class Token: + """Base class representation of the YouTubeMusicAPI OAuth token.""" + + scope: DefaultScope + token_type: Bearer + + access_token: str + refresh_token: str + expires_at: int + expires_in: int = 0 + + def __repr__(self) -> str: + """Readable version.""" + return f"{self.__class__.__name__}: {self.as_dict()}" + + def as_dict(self) -> RefreshableTokenDict: + """Returns dictionary containing underlying token values.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "scope": self.scope, + "expires_at": self.expires_at, + "expires_in": self.expires_in, + "token_type": self.token_type, + } + + def as_json(self) -> str: + return json.dumps(self.as_dict()) + + def as_auth(self) -> str: + """Returns Authorization header ready str of token_type and access_token.""" + return f"{self.token_type} {self.access_token}" + + @property + def is_expiring(self) -> bool: + return self.expires_in < 60 + + +class OAuthToken(Token): + """Wrapper for an OAuth token implementing expiration methods.""" + + @staticmethod + def is_oauth(headers: CaseInsensitiveDict) -> bool: + oauth_structure = { + "access_token", + "expires_at", + "expires_in", + "token_type", + "refresh_token", + } + return all(key in headers for key in oauth_structure) + + def update(self, fresh_access: BaseTokenDict): + """ + Update access_token and expiration attributes with a BaseTokenDict inplace. + expires_at attribute set using current epoch, avoid expiration desync + by passing only recently requested tokens dicts or updating values to compensate. + """ + self.access_token = fresh_access["access_token"] + self.expires_at = int(time.time()) + fresh_access["expires_in"] + + @property + def is_expiring(self) -> bool: + return self.expires_at - int(time.time()) < 60 + + @classmethod + def from_json(cls, file_path: str) -> "OAuthToken": + if os.path.isfile(file_path): + with open(file_path) as json_file: + file_pack = json.load(json_file) + + return cls(**file_pack) + + +@dataclass +class RefreshingToken(OAuthToken): + """ + Compositional implementation of Token that automatically refreshes + an underlying OAuthToken when required (credential expiration <= 1 min) + upon access_token attribute access. + """ + + #: credentials used for access_token refreshing + credentials: Optional[Credentials] = None + + #: protected/property attribute enables auto writing token values to new file location via setter + _local_cache: Optional[str] = None + + def __getattr__(self, item): + """access token setter to auto-refresh if it is expiring""" + if item == "access_token" and self.is_expiring: + fresh = self.credentials.refresh_token(self.refresh_token) + self.update(fresh) + self.store_token() + + return super().__getattribute__(item) + + @property + def local_cache(self) -> Optional[str]: + return self._local_cache + + @local_cache.setter + def local_cache(self, path: str): + """Update attribute and dump token to new path.""" + self._local_cache = path + self.store_token() + + @classmethod + def prompt_for_token( + cls, credentials: Credentials, open_browser: bool = False, to_file: Optional[str] = None + ) -> "RefreshingToken": + """ + Method for CLI token creation via user inputs. + + :param credentials: Client credentials + :param open_browser: Optional. Open browser to OAuth consent url automatically. (Default = False). + :param to_file: Optional. Path to store/sync json version of resulting token. (Default = None). + """ + + code = credentials.get_code() + url = f"{code['verification_url']}?user_code={code['user_code']}" + if open_browser: + webbrowser.open(url) + input(f"Go to {url}, finish the login flow and press Enter when done, Ctrl-C to abort") + raw_token = credentials.token_from_code(code["device_code"]) + ref_token = cls(credentials=credentials, **raw_token) + if to_file: + ref_token.local_cache = to_file + return ref_token + + def store_token(self, path: Optional[str] = None) -> None: + """ + Write token values to json file at specified path, defaulting to self.local_cache. + Operation does not update instance local_cache attribute. + Automatically called when local_cache is set post init. + """ + file_path = path if path else self.local_cache + + if file_path: + with open(file_path, encoding="utf8", mode="w") as file: + json.dump(self.as_dict(), file, indent=True) diff --git a/ytmusicapi/mixins/browsing.py b/ytmusicapi/mixins/browsing.py index f5b6e3b5..75b49bbc 100644 --- a/ytmusicapi/mixins/browsing.py +++ b/ytmusicapi/mixins/browsing.py @@ -345,7 +345,7 @@ def get_user_playlists(self, channelId: str, params: str) -> List[Dict]: return user_playlists - def get_album_browse_id(self, audioPlaylistId: str) -> str | None: + def get_album_browse_id(self, audioPlaylistId: str) -> Optional[str]: """ Get an album's browseId based on its audioPlaylistId diff --git a/ytmusicapi/setup.py b/ytmusicapi/setup.py index 47d68bc5..a76c5dfd 100644 --- a/ytmusicapi/setup.py +++ b/ytmusicapi/setup.py @@ -53,7 +53,7 @@ def setup_oauth( else: oauth_credentials = OAuthCredentials(session=session, proxies=proxies) - return oauth_credentials.prompt_for_token(open_browser, filepath) + return RefreshingToken.prompt_for_token(oauth_credentials, open_browser, filepath) def parse_args(args): diff --git a/ytmusicapi/ytmusic.py b/ytmusicapi/ytmusic.py index 81892718..147870ea 100644 --- a/ytmusicapi/ytmusic.py +++ b/ytmusicapi/ytmusic.py @@ -34,7 +34,7 @@ from ytmusicapi.parsers.i18n import Parser from .auth.oauth import OAuthCredentials, OAuthToken, RefreshingToken -from .auth.oauth.base import Token +from .auth.oauth.token import Token from .auth.types import AuthType @@ -142,8 +142,9 @@ def __init__( self._input_dict = CaseInsensitiveDict(self.auth) if OAuthToken.is_oauth(self._input_dict): - base_token = OAuthToken(**self._input_dict) - self._token = RefreshingToken(base_token, self.oauth_credentials, auth_filepath) + self._token = RefreshingToken( + credentials=self.oauth_credentials, _local_cache=auth_filepath, **self._input_dict + ) self.auth_type = AuthType.OAUTH_CUSTOM_CLIENT if oauth_credentials else AuthType.OAUTH_DEFAULT # prepare context