From 9c43c17676ff9629783ca408a37ace38cd3d3ef6 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:31:11 +0530 Subject: [PATCH] SK-1731: Updated SkyflowError class (#144) * SK-1731: Updated skyflow error class --- skyflow/client/skyflow.py | 10 ++++------ skyflow/error/_skyflow_error.py | 3 ++- skyflow/service_account/_utils.py | 2 +- skyflow/utils/_skyflow_messages.py | 2 ++ skyflow/utils/_utils.py | 2 +- tests/client/test_skyflow.py | 10 ++++++---- tests/service_account/test__utils.py | 5 +---- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 6cba302..3dd6bf4 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -99,9 +99,8 @@ def remove_vault_config(self, vault_id): if vault_id in self.__vault_configs.keys(): self.__vault_configs.pop(vault_id) else: - log_error(SkyflowMessages.Error.INVALID_VAULT_ID.value, - SkyflowMessages.ErrorCodes.INVALID_INPUT.value, - logger = self.__logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_VAULT_ID.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value) def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) @@ -141,9 +140,8 @@ def remove_connection_config(self, connection_id): if connection_id in self.__connection_configs.keys(): self.__connection_configs.pop(connection_id) else: - log_error(SkyflowMessages.Error.INVALID_CONNECTION_ID.value, - SkyflowMessages.ErrorCodes.INVALID_INPUT.value, - logger = self.__logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_CONNECTION_ID.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value) def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index 68c97fb..e23c013 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,3 +1,4 @@ +from skyflow.utils import SkyflowMessages from skyflow.utils.logger import log_error class SkyflowError(Exception): @@ -11,7 +12,7 @@ def __init__(self, self.message = message self.http_code = http_code self.grpc_code = grpc_code - self.http_status = http_status + self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value self.details = details self.request_id = request_id log_error(message, http_code, request_id, grpc_code, http_status, details) diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 12ae41e..86da7d2 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -27,7 +27,7 @@ def is_expired(token, logger = None): except jwt.ExpiredSignatureError: return True except Exception: - log_error(SkyflowMessages.Error.JWT_DECODE_ERROR.value, invalid_input_error_code, logger = logger) + log_error_log(SkyflowMessages.Error.JWT_DECODE_ERROR.value, logger) return True def generate_bearer_token(credentials_file_path, options = None, logger = None): diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index afbb07e..b12401f 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -284,6 +284,8 @@ class Interface(Enum): UPDATE = "UPDATE" DELETE = "DELETE" + class HttpStatus(Enum): + BAD_REQUEST = "Bad Request" class Warning(Enum): WARNING_MESSAGE = "WARNING MESSAGE" diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index fa17a70..c88a49b 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -341,7 +341,7 @@ def parse_invoke_connection_response(api_response: requests.Response): def log_and_reject_error(description, status_code, request_id, http_status=None, grpc_code=None, details=None, logger = None): - log_error(description, status_code, request_id, grpc_code, http_status, details, logger= logger) + raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details) def handle_exception(error, logger): request_id = error.headers.get('x-request-id', 'unknown-request-id') diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index a203b19..87ef1a8 100644 --- a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -69,8 +69,9 @@ def test_remove_vault_config_valid(self): def test_remove_vault_config_invalid(self, mock_log_error): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - self.builder.remove_vault_config("invalid_id") - mock_log_error.assert_called_once() + with self.assertRaises(SkyflowError) as context: + self.builder.remove_vault_config("invalid_id") + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value) @patch('skyflow.vault.client.client.VaultClient.update_config') @@ -157,8 +158,9 @@ def test_remove_connection_config_valid(self): def test_remove_connection_config_invalid(self, mock_log_error): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() - self.builder.remove_connection_config("invalid_id") - mock_log_error.assert_called_once() + with self.assertRaises(SkyflowError) as context: + self.builder.remove_connection_config("invalid_id") + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_ID.value) @patch('skyflow.vault.client.client.VaultClient.update_config') def test_update_connection_config_valid(self, mock_validate): diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index a426fdd..7ffb36d 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -47,14 +47,11 @@ def test_is_expired_expired_token(self): token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) - @patch("skyflow.service_account._utils.log_error") + @patch("skyflow.utils.logger._log_helpers.log_error_log") @patch("jwt.decode", side_effect=Exception("Some error")) def test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error): token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) - mock_log_error.assert_called_once_with( - SkyflowMessages.Error.JWT_DECODE_ERROR.value, 400, logger=None - ) @patch("builtins.open", side_effect=FileNotFoundError) def test_generate_bearer_token_invalid_file_path(self, mock_open):