From e18f3131569113f74da1411373a002aacd4aa4d4 Mon Sep 17 00:00:00 2001 From: Jean Lucas Date: Fri, 5 Aug 2022 17:32:42 +0200 Subject: [PATCH] return raw response (#340) --- nucleus/__init__.py | 17 ++++------------- nucleus/connection.py | 10 +++++++++- nucleus/model.py | 14 ++++++++------ 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/nucleus/__init__.py b/nucleus/__init__.py index aefd5601..f695ba18 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -1021,6 +1021,7 @@ def make_request( payload: Optional[dict], route: str, requests_command=requests.post, + return_raw_response: bool = False, ) -> dict: """Makes a request to a Nucleus API endpoint. @@ -1030,9 +1031,10 @@ def make_request( payload: Given request payload. route: Route for the request. Requests command: ``requests.post``, ``requests.get``, or ``requests.delete``. + return_raw_response: return the request's response object entirely Returns: - Response payload as JSON dict. + Response payload as JSON dict or request object. """ if payload is None: payload = {} @@ -1042,18 +1044,7 @@ def make_request( "Received defined payload with GET request! Will ignore payload" ) payload = None - return self._connection.make_request(payload, route, requests_command) # type: ignore - - def handle_bad_response( - self, - endpoint, - requests_command, - requests_response=None, - aiohttp_response=None, - ): - self._connection.handle_bad_response( - endpoint, requests_command, requests_response, aiohttp_response - ) + return self._connection.make_request(payload, route, requests_command, return_raw_response) # type: ignore def _set_api_key(self, api_key): """Fetch API key from environment variable NUCLEUS_API_KEY if not set""" diff --git a/nucleus/connection.py b/nucleus/connection.py index 11d07ba4..162380dd 100644 --- a/nucleus/connection.py +++ b/nucleus/connection.py @@ -40,7 +40,11 @@ def put(self, payload: dict, route: str): return self.make_request(payload, route, requests_command=requests.put) def make_request( - self, payload: dict, route: str, requests_command=requests.post + self, + payload: dict, + route: str, + requests_command=requests.post, + return_raw_response: bool = False, ) -> dict: """ Makes a request to Nucleus endpoint and logs a warning if not @@ -49,6 +53,7 @@ def make_request( :param payload: given payload :param route: route for the request :param requests_command: requests.post, requests.get, requests.delete + :param return_raw_response: return the request's response object entirely :return: response JSON """ endpoint = f"{self.endpoint}/{route}" @@ -73,6 +78,9 @@ def make_request( if not response.ok: self.handle_bad_response(endpoint, requests_command, response) + if return_raw_response: + return response + return response.json() def handle_bad_response( diff --git a/nucleus/model.py b/nucleus/model.py index a3e18c74..67a2e0ce 100644 --- a/nucleus/model.py +++ b/nucleus/model.py @@ -234,16 +234,17 @@ def add_tags(self, tags: List[str]): Args: tags: list of tag names """ - response = self._client.make_request( + response: requests.Response = self._client.make_request( {MODEL_TAGS_KEY: tags}, f"model/{self.id}/tag", requests_command=requests.post, + return_raw_response=True, ) - if response.get("msg", False): + if response.ok: self.tags.extend(tags) - return response + return response.json() def remove_tags(self, tags: List[str]): """Remove tag(s) from the model. :: @@ -257,13 +258,14 @@ def remove_tags(self, tags: List[str]): Args: tags: list of tag names to remove """ - response = self._client.make_request( + response: requests.Response = self._client.make_request( {MODEL_TAGS_KEY: tags}, f"model/{self.id}/tag", requests_command=requests.delete, + return_raw_response=True, ) - if response.get("msg", False): + if response.ok: self.tags = list(filter(lambda t: t not in tags, self.tags)) - return response + return response.json()