Skip to content

Commit

Permalink
SK-1649: Added logging and updated error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Oct 22, 2024
1 parent 4a4909b commit 1a5eb20
Show file tree
Hide file tree
Showing 24 changed files with 452 additions and 285 deletions.
58 changes: 47 additions & 11 deletions skyflow/client/skyflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import OrderedDict
from skyflow import LogLevel
from skyflow.error import SkyflowError
from skyflow.utils import Logger, log_info, SkyflowMessages, log_error
from skyflow.utils import SkyflowMessages
from skyflow.utils.logger import log_info, Logger, log_error
from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \
validate_update_connection_config, validate_credentials, validate_log_level
from skyflow.vault.client.client import VaultClient
Expand Down Expand Up @@ -81,10 +82,24 @@ def __init__(self):
self.__connection_configs = OrderedDict()
self.__connection_list = list()
self.__skyflow_credentials = None
self.__log_level = LogLevel.OFF
self.__logger = Logger(LogLevel.OFF)
self.__log_level = LogLevel.ERROR
self.__logger = Logger(LogLevel.ERROR)

def add_vault_config(self, config):
vault_id = config.get("vault_id")
if not isinstance(vault_id, str) or not vault_id:
raise SkyflowError(
SkyflowMessages.Error.INVALID_VAULT_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger=self.__logger
)
if vault_id in [vault.get("vault_id") for vault in self.__vault_list]:
raise SkyflowError(
SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger=self.__logger
)

self.__vault_list.append(config)
return self

Expand All @@ -103,12 +118,29 @@ def update_vault_config(self, config):
vault_config.get("vault_client").update_config(config)

def get_vault_config(self, vault_id):
if vault_id in self.__vault_configs.keys():
vault_config = self.__vault_configs.get(vault_id)
return vault_config
raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger, logger_method=log_error)
if vault_id is None:
if self.__vault_configs:
return next(iter(self.__vault_configs.values()))
raise SkyflowError(SkyflowMessages.Error.EMPTY_VAULT_CONFIGS.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)

if vault_id in self.__vault_configs:
return self.__vault_configs.get(vault_id)
raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)

def add_connection_config(self, config):
connection_id = config.get("connection_id")
if not isinstance(connection_id, str) or not connection_id:
raise SkyflowError(
SkyflowMessages.Error.INVALID_CONNECTION_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger = self.__logger
)
if connection_id in [connection.get("connection_id") for connection in self.__connection_list]:
raise SkyflowError(
SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger=self.__logger
)
self.__connection_list.append(config)
return self

Expand All @@ -127,10 +159,14 @@ def update_connection_config(self, config):
connection_config.get("vault_client").update_config(config)

def get_connection_config(self, connection_id):
if connection_id in self.__connection_configs.keys():
connection_config = self.__connection_configs[connection_id]
return connection_config
raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger, logger_method=log_error)
if connection_id is None:
if self.__connection_configs:
return next(iter(self.__connection_configs.values()))
return SkyflowError(SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)

if connection_id in self.__connection_configs:
return self.__connection_configs.get(connection_id)
raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)

def add_skyflow_credentials(self, credentials):
self.__skyflow_credentials = credentials
Expand Down
8 changes: 4 additions & 4 deletions skyflow/error/_skyflow_error.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from skyflow.utils.logger import log_error

class SkyflowError(Exception):
def __init__(self,
message,
Expand All @@ -6,8 +8,6 @@ def __init__(self,
grpc_code = None,
http_status = None,
details = None,
logger = None,
logger_method = None):

logger_method(message, http_code, request_id, grpc_code, http_status, details, logger)
logger = None):
log_error(message, http_code, request_id, grpc_code, http_status, details, logger)
super().__init__()
2 changes: 1 addition & 1 deletion skyflow/service_account/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired
from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired, validate_api_key
41 changes: 27 additions & 14 deletions skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from skyflow.error import SkyflowError
from skyflow.generated.rest.models import V1GetAuthTokenRequest
from skyflow.service_account.client.auth_client import AuthClient
from skyflow.utils import get_base_url, format_scope, SkyflowMessages, log_error
from skyflow.utils.logger import log_error
from skyflow.utils import get_base_url, format_scope, SkyflowMessages


invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

Expand All @@ -25,55 +27,66 @@ def is_expired(token, logger = None):
return True
pass

import re

def validate_api_key(api_key: str) -> bool:
if len(api_key) != 42:
return False
api_key_pattern = re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$')

return bool(api_key_pattern.match(api_key))


def generate_bearer_token(credentials_file_path, options = None, logger = None):
try:
credentials_file =open(credentials_file_path, 'r')
except Exception:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger)

try:
credentials = json.load(credentials_file)
except Exception:
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code, logger = logger)

finally:
credentials_file.close()
result = get_service_account_token(credentials, options, logger)
return result

def generate_bearer_token_from_creds(credentials, options = None, logger = None):
credentials = credentials.strip()
try:
json_credentials = json.loads(credentials.replace('\n', '\\n'))
except Exception as e:
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value, invalid_input_error_code, logger = logger, logger_method=log_error)
except Exception:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code, logger = logger)
result = get_service_account_token(json_credentials, options, logger)
return result

def get_service_account_token(credentials, options, logger):
try:
private_key = credentials["privateKey"]
except:
raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code, logger = logger)
try:
client_id = credentials["clientID"]
except:
raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code, logger = logger)
try:
key_id = credentials["keyID"]
except:
raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code, logger = logger)
try:
token_uri = credentials["tokenURI"]
except:
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code, logger = logger)

signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger)
base_url = get_base_url(token_uri)
auth_client = AuthClient(base_url)
auth_api = auth_client.get_auth_api()

formatted_scope = None
if "role_ids" in options:
if options and "role_ids" in options:
formatted_scope = format_scope(options.get("role_ids"))

request = V1GetAuthTokenRequest(assertion = signed_token,
Expand All @@ -90,12 +103,12 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
"sub": client_id,
"exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60)
}
if "ctx" in options:
if options and "ctx" in options:
payload["ctx"] = options.get("ctx")
try:
return jwt.encode(payload=payload, key=private_key, algorithm="RS256")
except Exception as e:
raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code, logger = logger, logger_method=log_error)
except Exception:
raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code, logger = logger)



Expand Down Expand Up @@ -139,7 +152,7 @@ def generate_signed_data_tokens(credentials_file_path, options, logger = None):
try:
credentials_file =open(credentials_file_path, 'r')
except Exception:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger, logger_method=log_error)
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger)

return get_signed_tokens(credentials_file, options)

Expand Down
6 changes: 3 additions & 3 deletions skyflow/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..utils.enums import LogLevel, Env
from ._skyflow_messages import SkyflowMessages
from ._version import SDK_VERSION
from ._logger import Logger
from ._log_helpers import log_error, log_info
from ._utils import get_credentials, get_vault_url, get_client_configuration, get_base_url, format_scope, get_redaction_type, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response
from ._helpers import get_base_url, format_scope
from ._utils import get_credentials, get_vault_url, get_client_configuration, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response

11 changes: 11 additions & 0 deletions skyflow/utils/_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from urllib.parse import urlparse

def get_base_url(url):
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
return base_url

def format_scope(scopes):
if not scopes:
return None
return " ".join([f"role:{scope}" for scope in scopes])
Loading

0 comments on commit 1a5eb20

Please sign in to comment.