Skip to content

Commit

Permalink
SK-1731: Updated SkyflowError class (#144)
Browse files Browse the repository at this point in the history
* SK-1731: Updated skyflow error class
  • Loading branch information
saileshwar-skyflow authored Dec 2, 2024
1 parent 41e4ab0 commit 9c43c17
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 17 deletions.
10 changes: 4 additions & 6 deletions skyflow/client/skyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ def remove_vault_config(self, vault_id):
if vault_id in self.__vault_configs.keys():
self.__vault_configs.pop(vault_id)
else:
log_error(SkyflowMessages.Error.INVALID_VAULT_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger = self.__logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_VAULT_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

def update_vault_config(self, config):
validate_update_vault_config(self.__logger, config)
Expand Down Expand Up @@ -141,9 +140,8 @@ def remove_connection_config(self, connection_id):
if connection_id in self.__connection_configs.keys():
self.__connection_configs.pop(connection_id)
else:
log_error(SkyflowMessages.Error.INVALID_CONNECTION_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger = self.__logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_CONNECTION_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

def update_connection_config(self, config):
validate_update_connection_config(self.__logger, config)
Expand Down
3 changes: 2 additions & 1 deletion skyflow/error/_skyflow_error.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from skyflow.utils import SkyflowMessages
from skyflow.utils.logger import log_error

class SkyflowError(Exception):
Expand All @@ -11,7 +12,7 @@ def __init__(self,
self.message = message
self.http_code = http_code
self.grpc_code = grpc_code
self.http_status = http_status
self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value
self.details = details
self.request_id = request_id
log_error(message, http_code, request_id, grpc_code, http_status, details)
Expand Down
2 changes: 1 addition & 1 deletion skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def is_expired(token, logger = None):
except jwt.ExpiredSignatureError:
return True
except Exception:
log_error(SkyflowMessages.Error.JWT_DECODE_ERROR.value, invalid_input_error_code, logger = logger)
log_error_log(SkyflowMessages.Error.JWT_DECODE_ERROR.value, logger)
return True

def generate_bearer_token(credentials_file_path, options = None, logger = None):
Expand Down
2 changes: 2 additions & 0 deletions skyflow/utils/_skyflow_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ class Interface(Enum):
UPDATE = "UPDATE"
DELETE = "DELETE"

class HttpStatus(Enum):
BAD_REQUEST = "Bad Request"

class Warning(Enum):
WARNING_MESSAGE = "WARNING MESSAGE"
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def parse_invoke_connection_response(api_response: requests.Response):


def log_and_reject_error(description, status_code, request_id, http_status=None, grpc_code=None, details=None, logger = None):
log_error(description, status_code, request_id, grpc_code, http_status, details, logger= logger)
raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details)

def handle_exception(error, logger):
request_id = error.headers.get('x-request-id', 'unknown-request-id')
Expand Down
10 changes: 6 additions & 4 deletions tests/client/test_skyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def test_remove_vault_config_valid(self):
def test_remove_vault_config_invalid(self, mock_log_error):
self.builder.add_vault_config(VALID_VAULT_CONFIG)
self.builder.build()
self.builder.remove_vault_config("invalid_id")
mock_log_error.assert_called_once()
with self.assertRaises(SkyflowError) as context:
self.builder.remove_vault_config("invalid_id")
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value)


@patch('skyflow.vault.client.client.VaultClient.update_config')
Expand Down Expand Up @@ -157,8 +158,9 @@ def test_remove_connection_config_valid(self):
def test_remove_connection_config_invalid(self, mock_log_error):
self.builder.add_connection_config(VALID_CONNECTION_CONFIG)
self.builder.build()
self.builder.remove_connection_config("invalid_id")
mock_log_error.assert_called_once()
with self.assertRaises(SkyflowError) as context:
self.builder.remove_connection_config("invalid_id")
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_ID.value)

@patch('skyflow.vault.client.client.VaultClient.update_config')
def test_update_connection_config_valid(self, mock_validate):
Expand Down
5 changes: 1 addition & 4 deletions tests/service_account/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,11 @@ def test_is_expired_expired_token(self):
token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256")
self.assertTrue(is_expired(token))

@patch("skyflow.service_account._utils.log_error")
@patch("skyflow.utils.logger._log_helpers.log_error_log")
@patch("jwt.decode", side_effect=Exception("Some error"))
def test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error):
token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256")
self.assertTrue(is_expired(token))
mock_log_error.assert_called_once_with(
SkyflowMessages.Error.JWT_DECODE_ERROR.value, 400, logger=None
)

@patch("builtins.open", side_effect=FileNotFoundError)
def test_generate_bearer_token_invalid_file_path(self, mock_open):
Expand Down

0 comments on commit 9c43c17

Please sign in to comment.