Skip to content

Commit

Permalink
Add timeout to all requests
Browse files Browse the repository at this point in the history
  • Loading branch information
U1F984 committed Dec 11, 2024
1 parent 11bf3c0 commit 9c7ae84
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions trakt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@
#: The ID of the application to register with, when using PIN authentication
APPLICATION_ID = None

#: Timeout in seconds for all requests
TIMEOUT = 30

#: Global session to make requests with
session = requests.Session()

Expand Down Expand Up @@ -141,7 +144,7 @@ def pin_auth(pin=None, client_id=None, client_secret=None, store=False):
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET}

response = session.post(''.join([BASE_URL, '/oauth/token']), data=args)
response = session.post(''.join([BASE_URL, '/oauth/token']), data=args, timeout=TIMEOUT)
OAUTH_TOKEN = response.json().get('access_token', None)

if store:
Expand Down Expand Up @@ -231,12 +234,12 @@ def get_device_code(client_id=None, client_secret=None):
data = {"client_id": CLIENT_ID}

device_response = session.post(device_code_url,
json=data, headers=headers).json()
json=data, headers=headers, timeout=TIMEOUT).json()
print('Your user code is: {user_code}, please navigate to '
'{verification_url} to authenticate'.format(
user_code=device_response.get('user_code'),
verification_url=device_response.get('verification_url')
))
user_code=device_response.get('user_code'),
verification_url=device_response.get('verification_url')
))

device_response['requested'] = time.time()
return device_response
Expand Down Expand Up @@ -272,7 +275,7 @@ def get_device_token(device_code, client_id=None, client_secret=None,
}

response = session.post(
urljoin(BASE_URL, '/oauth/device/token'), json=data
urljoin(BASE_URL, '/oauth/device/token'), json=data, timeout=TIMEOUT
)

# We only get json on success.
Expand Down Expand Up @@ -409,13 +412,13 @@ def _refresh_token(s):
s.logger.info("OAuth token has expired, refreshing now...")
url = urljoin(BASE_URL, '/oauth/token')
data = {
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET,
'refresh_token': OAUTH_REFRESH,
'redirect_uri': REDIRECT_URI,
'grant_type': 'refresh_token'
}
response = session.post(url, json=data, headers=HEADERS)
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET,
'refresh_token': OAUTH_REFRESH,
'redirect_uri': REDIRECT_URI,
'grant_type': 'refresh_token'
}
response = session.post(url, json=data, headers=HEADERS, timeout=TIMEOUT)
s.logger.debug('RESPONSE [post] (%s): %s - %s', url, str(response), response.content)
if response.status_code == 200:
data = response.json()
Expand Down Expand Up @@ -542,10 +545,10 @@ def _handle_request(self, method, url, data=None):
self.logger.debug('method, url :: %s, %s', method, url)
if method == 'get': # GETs need to pass data as params, not body
response = session.request(method, url, headers=HEADERS,
params=data)
params=data, timeout=TIMEOUT)
else:
response = session.request(method, url, headers=HEADERS,
data=json.dumps(data))
data=json.dumps(data), timeout=TIMEOUT)
self.logger.debug('RESPONSE [%s] (%s): %s', method, url, str(response))
if response.status_code in self.error_map:
raise self.error_map[response.status_code](response)
Expand All @@ -568,6 +571,7 @@ def get(self, f):
results
:return: The results of the generator co-routine
"""

@wraps(f)
def inner(*args, **kwargs):
self._bootstrap()
Expand All @@ -581,20 +585,23 @@ def inner(*args, **kwargs):
return generator.send(json_data)
except StopIteration:
return None

return inner

def delete(self, f):
"""Perform an HTTP DELETE request using the provided uri
:param f: Function that returns a uri to delete to
"""

@wraps(f)
def inner(*args, **kwargs):
self._bootstrap()
generator = f(*args, **kwargs)
uri = next(generator)
url = BASE_URL + uri
self._handle_request('delete', url)

return inner

def post(self, f):
Expand All @@ -607,6 +614,7 @@ def post(self, f):
results
:return: The results of the generator co-routine
"""

@wraps(f)
def inner(*args, **kwargs):
self._bootstrap()
Expand All @@ -616,6 +624,7 @@ def inner(*args, **kwargs):
return generator.send(json_data)
except StopIteration:
return None

return inner

def put(self, f):
Expand All @@ -628,6 +637,7 @@ def put(self, f):
results
:return: The results of the generator co-routine
"""

@wraps(f)
def inner(*args, **kwargs):
self._bootstrap()
Expand All @@ -637,6 +647,7 @@ def inner(*args, **kwargs):
return generator.send(json_data)
except StopIteration:
return None

return inner


Expand Down

0 comments on commit 9c7ae84

Please sign in to comment.