Skip to content

Commit

Permalink
SK-1649: Added validations and logging
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Oct 20, 2024
1 parent a4374a0 commit 4a4909b
Show file tree
Hide file tree
Showing 36 changed files with 1,321 additions and 371 deletions.
120 changes: 69 additions & 51 deletions skyflow/client/skyflow.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from collections import OrderedDict
from skyflow import LogLevel
from skyflow.error import SkyflowError
from skyflow.utils.validations import validate_vault_config, validate_connection_config
from skyflow.utils import Logger, log_info, SkyflowMessages, log_error
from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \
validate_update_connection_config, validate_credentials, validate_log_level
from skyflow.vault.client.client import VaultClient
from skyflow.vault.controller import Vault
from skyflow.vault.controller import Connection

class Skyflow:
def __init__(self, builder):
self.__builder = builder
log_info(SkyflowMessages.Info.CLIENT_INITIALIZED.value,
SkyflowMessages.InterfaceName.CLIENT.value,
self.__builder.get_logger())

@staticmethod
def builder():
Expand Down Expand Up @@ -72,94 +77,107 @@ def connection(self, connection_id = None):
class Builder:
def __init__(self):
self.__vault_configs = OrderedDict()
self.__vault_list = list()
self.__connection_configs = OrderedDict()
self.__connection_list = list()
self.__skyflow_credentials = None
self.__log_level = LogLevel.ERROR
self.__log_level = LogLevel.OFF
self.__logger = Logger(LogLevel.OFF)

def add_vault_config(self, config):
if validate_vault_config(config) and config.get("vault_id") not in self.__vault_configs.keys():
vault_id = config.get("vault_id")
vault_client = VaultClient(config)
self.__vault_configs[vault_id] = {
"vault_client": vault_client,
"controller": Vault(vault_client)
}
return self
else:
raise SkyflowError(f"Vault config with id {config['vault_id']} already exists")
self.__vault_list.append(config)
return self

def remove_vault_config(self, vault_id):
if vault_id in self.__vault_configs.keys():
self.__vault_configs.pop(vault_id)
else:
raise SkyflowError(f"Vault config with id {vault_id} not found")
log_error(SkyflowMessages.Error.INVALID_VAULT_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger = self.__logger)

def update_vault_config(self, config):
validate_update_vault_config(self.__logger, config)
vault_id = config.get("vault_id")
if not vault_id:
raise SkyflowError("vault_id is required and cannot be None")
if vault_id in self.__vault_configs.keys():
vault_config = self.__vault_configs[vault_id]
vault_config.get("vault_client").update_config(config)
else:
raise SkyflowError(f"Vault config with id {vault_id} not found")
vault_config = self.__vault_configs[vault_id]
vault_config.get("vault_client").update_config(config)

def get_vault_config(self, vault_id):
if vault_id in self.__vault_configs.keys():
vault_config = self.__vault_configs.get(vault_id)
return vault_config
raise SkyflowError(f"Vault config with id {vault_id} not found")
raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger, logger_method=log_error)

def add_connection_config(self, config):
if validate_connection_config(config) and config["connection_id"] not in self.__connection_configs.keys():
connection_id = config.get("connection_id")
vault_client = VaultClient(config)
self.__connection_configs[connection_id] = {
"vault_client": vault_client,
"controller": Connection(vault_client)
}
return self
else:
raise SkyflowError(f"Connection config with id {config['connection_id']} already exists")
self.__connection_list.append(config)
return self

def remove_connection_config(self, connection_id):
if connection_id in self.__connection_configs.keys():
self.__connection_configs.pop(connection_id)
else:
raise SkyflowError(f"Connection config with id {connection_id} not found")
log_error(SkyflowMessages.Error.INVALID_CONNECTION_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger = self.__logger)

def update_connection_config(self, config):
validate_update_connection_config(self.__logger, config)
connection_id = config['connection_id']
if not connection_id:
raise SkyflowError("connection_id is required and can not be empty")

if connection_id in self.__connection_configs.keys():
connection_config = self.__connection_configs[connection_id]
connection_config.get("vault_client").update_config(config)
else:
raise SkyflowError(f"Connection config with id {connection_id} not found")
connection_config = self.__connection_configs[connection_id]
connection_config.get("vault_client").update_config(config)

def get_connection_config(self, connection_id):
if connection_id in self.__connection_configs.keys():
connection_config = self.__connection_configs[connection_id]
return connection_config
raise SkyflowError(f"Connection config with id {connection_id} not found")
raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger, logger_method=log_error)

def add_skyflow_credentials(self, credentials):
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_common_skyflow_credentials(credentials)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_common_skyflow_credentials(credentials)
self.__skyflow_credentials = credentials
return self

def set_log_level(self, log_level):
self.__log_level = log_level
return self

def get_logger(self):
return self.__logger

def build(self):
log_info(SkyflowMessages.Info.INITIALIZE_CLIENT.value, SkyflowMessages.InterfaceName.CLIENT.value, self.__logger)
validate_log_level(self.__logger, self.__log_level)
self.__logger.set_log_level(self.__log_level)

for config in self.__vault_list:
validate_vault_config(self.__logger, config)
vault_id = config.get("vault_id")
vault_client = VaultClient(config)
self.__vault_configs[vault_id] = {
"vault_client": vault_client,
"controller": Vault(vault_client)
}

for config in self.__connection_list:
validate_connection_config(self.__logger, config=config)
connection_id = config.get("connection_id")
vault_client = VaultClient(config)
self.__connection_configs[connection_id] = {
"vault_client": vault_client,
"controller": Connection(vault_client)
}

for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_log_level(log_level)
vault_config.get("vault_client").set_logger(self.__log_level, self.__logger)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_log_level(log_level)
return self
connection_config.get("vault_client").set_logger(self.__log_level, self.__logger)

def build(self):
return Skyflow(self)
if self.__skyflow_credentials is not None:
validate_credentials(self.__logger, self.__skyflow_credentials)
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials)

return Skyflow(self)
2 changes: 1 addition & 1 deletion skyflow/error/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from skyflow.error._skyflow_error import SkyflowError
from ._skyflow_error import SkyflowError
15 changes: 12 additions & 3 deletions skyflow/error/_skyflow_error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
class SkyflowError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message
def __init__(self,
message,
http_code,
request_id = None,
grpc_code = None,
http_status = None,
details = None,
logger = None,
logger_method = None):

logger_method(message, http_code, request_id, grpc_code, http_status, details, logger)
super().__init__()
2 changes: 1 addition & 1 deletion skyflow/service_account/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._utils import generate_bearer_token, generate_bearer_token_from_creds
from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired
50 changes: 26 additions & 24 deletions skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from skyflow.error import SkyflowError
from skyflow.generated.rest.models import V1GetAuthTokenRequest
from skyflow.service_account.client.auth_client import AuthClient
from skyflow.utils import get_base_url, format_scope
from skyflow.utils import get_base_url, format_scope, SkyflowMessages, log_error

def is_expired(token):
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

def is_expired(token, logger = None):
if len(token) == 0:
return True

Expand All @@ -18,54 +20,54 @@ def is_expired(token):
return False
except jwt.ExpiredSignatureError:
return True
except Exception as e:
SkyflowError("Invalid token")
except Exception:
log_error(SkyflowMessages.Error.JWT_DECODE_ERROR.value, invalid_input_error_code, logger = logger)
return True
pass

def generate_bearer_token(credentials_file_path, options = None):
def generate_bearer_token(credentials_file_path, options = None, logger = None):
try:
credentials_file =open(credentials_file_path, 'r')
except Exception:
raise SkyflowError("Invalid file path")
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger, logger_method=log_error)

try:
credentials = json.load(credentials_file)
except Exception:
raise SkyflowError("Error in json parsing")
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code, logger = logger, logger_method=log_error)

finally:
credentials_file.close()
result = get_service_account_token(credentials, options)
result = get_service_account_token(credentials, options, logger)
return result

def generate_bearer_token_from_creds(credentials, options = None):
def generate_bearer_token_from_creds(credentials, options = None, logger = None):
try:
json_credentials = json.loads(credentials.replace('\n', '\\n'))
except Exception as e:
raise SkyflowError(e)
result = get_service_account_token(json_credentials, options)
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value, invalid_input_error_code, logger = logger, logger_method=log_error)
result = get_service_account_token(json_credentials, options, logger)
return result

def get_service_account_token(credentials, options):
def get_service_account_token(credentials, options, logger):
try:
private_key = credentials["privateKey"]
except:
raise SkyflowError("privateKey not found")
raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code, logger = logger, logger_method=log_error)
try:
client_id = credentials["clientID"]
except:
raise SkyflowError("clientID not found")
raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error)
try:
key_id = credentials["keyID"]
except:
raise SkyflowError("keyID not found")
raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error)
try:
token_uri = credentials["tokenURI"]
except:
raise SkyflowError("tokenURI not found")
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code, logger = logger, logger_method=log_error)

signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key)
signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger)
base_url = get_base_url(token_uri)
auth_client = AuthClient(base_url)
auth_api = auth_client.get_auth_api()
Expand All @@ -80,7 +82,7 @@ def get_service_account_token(credentials, options):
response = auth_api.authentication_service_get_auth_token(request)
return response.access_token, response.token_type

def get_signed_jwt(options, client_id, key_id, token_uri, private_key):
def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
payload = {
"iss": client_id,
"key": key_id,
Expand All @@ -93,7 +95,7 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key):
try:
return jwt.encode(payload=payload, key=private_key, algorithm="RS256")
except Exception as e:
raise SkyflowError("")
raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code, logger = logger, logger_method=log_error)



Expand All @@ -102,7 +104,7 @@ def get_signed_tokens(credentials, options):
try:
credentials_obj = json.loads(credentials)
except:
raise SkyflowError("Invalid JSON")
raise ValueError("Invalid JSON")

expiry_time = time.time() + options.get("time_to_live", 60)
prefix = "signed_token_"
Expand Down Expand Up @@ -130,16 +132,16 @@ def get_signed_tokens(credentials, options):
return response_array

except Exception as e:
raise SkyflowError(str(e))
raise ValueError(str(e))


def generate_signed_data_tokens(credentials_file_path, options):
def generate_signed_data_tokens(credentials_file_path, options, logger = None):
try:
credentials_file =open(credentials_file_path, 'r')
except Exception:
raise SkyflowError("Invalid file path")
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger, logger_method=log_error)

return get_signed_tokens(credentials_file_path, options)
return get_signed_tokens(credentials_file, options)

def generate_signed_data_tokens_from_creds(credentials, options):
return get_signed_tokens(credentials, options)
Expand Down
6 changes: 5 additions & 1 deletion skyflow/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from ..utils.enums import LogLevel, Env
from ._utils import get_credentials, get_vault_url, get_client_configuration, get_base_url, format_scope, get_redaction_type, construct_invoke_connection_request, build_field_records
from ._skyflow_messages import SkyflowMessages
from ._version import SDK_VERSION
from ._logger import Logger
from ._log_helpers import log_error, log_info
from ._utils import get_credentials, get_vault_url, get_client_configuration, get_base_url, format_scope, get_redaction_type, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response
27 changes: 27 additions & 0 deletions skyflow/utils/_log_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .enums import LogLevel
from . import Logger


def log_info(message, interface, logger = None):
formatted_message = '{} {}'.format(interface, message)
logger.info(formatted_message)

def log_error(message, http_code, request_id=None, grpc_code=None, http_status=None, details=None, logger=None):
if not logger:
logger = Logger(LogLevel.ERROR)

log_data = {
'http_code': http_code,
'message': message
}

if grpc_code is not None:
log_data['grpc_code'] = grpc_code
if http_status is not None:
log_data['http_status'] = http_status
if request_id is not None:
log_data['request_id'] = request_id
if details is not None:
log_data['details'] = details

logger.error(log_data)
Loading

0 comments on commit 4a4909b

Please sign in to comment.