Skip to content

Commit

Permalink
SK-1650: Added unit test cases for python sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Nov 3, 2024
1 parent 1a5eb20 commit 704f48e
Show file tree
Hide file tree
Showing 42 changed files with 2,390 additions and 306 deletions.
131 changes: 74 additions & 57 deletions skyflow/client/skyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,27 @@
class Skyflow:
def __init__(self, builder):
self.__builder = builder
log_info(SkyflowMessages.Info.CLIENT_INITIALIZED.value,
SkyflowMessages.InterfaceName.CLIENT.value,
self.__builder.get_logger())
log_info(SkyflowMessages.Info.CLIENT_INITIALIZED.value, self.__builder.get_logger())

@staticmethod
def builder():
return Skyflow.Builder()

def add_vault_config(self, config):
self.__builder.add_vault_config(config)
self.__builder._Builder__add_vault_config(config)
return self

def remove_vault_config(self, vault_id):
self.__builder.remove_vault_config(vault_id)
return self

def update_vault_config(self,config):
self.__builder.update_vault_config(config)
return self

def get_vault_config(self, vault_id):
return self.__builder.get_vault_config(vault_id)
return self.__builder.get_vault_config(vault_id).get("vault_client").get_config()

def add_connection_config(self, config):
self.__builder.add_connection_config(config)
self.__builder._Builder__add_connection_config(config)
return self

def remove_connection_config(self, connection_id):
Expand All @@ -48,24 +44,21 @@ def update_connection_config(self, config):
return self

def get_connection_config(self, connection_id):
self.__builder.get_connection_config(connection_id)
return self
return self.__builder.get_connection_config(connection_id).get("vault_client").get_config()

def add_skyflow_credentials(self, credentials):
self.__builder.add_skyflow_credentials(credentials)
self.__builder._Builder__add_skyflow_credentials(credentials)
return self

def update_skyflow_credentials(self, credentials):
self.__builder.add_skyflow_credentials(credentials)
return self
self.__builder._Builder__add_skyflow_credentials(credentials)

def set_log_level(self, log_level):
self.__builder.set_log_level(log_level)
self.__builder._Builder__set_log_level(log_level)
return self

def update_log_level(self, log_level):
self.__builder.set_log_level(log_level)
return self
self.__builder._Builder__set_log_level(log_level)

def vault(self, vault_id = None):
vault_config = self.__builder.get_vault_config(vault_id)
Expand All @@ -90,14 +83,13 @@ def add_vault_config(self, config):
if not isinstance(vault_id, str) or not vault_id:
raise SkyflowError(
SkyflowMessages.Error.INVALID_VAULT_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger=self.__logger
SkyflowMessages.ErrorCodes.INVALID_INPUT.value
)
if vault_id in [vault.get("vault_id") for vault in self.__vault_list]:
log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger)
raise SkyflowError(
SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger=self.__logger
SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id),
SkyflowMessages.ErrorCodes.INVALID_INPUT.value
)

self.__vault_list.append(config)
Expand All @@ -121,25 +113,26 @@ def get_vault_config(self, vault_id):
if vault_id is None:
if self.__vault_configs:
return next(iter(self.__vault_configs.values()))
raise SkyflowError(SkyflowMessages.Error.EMPTY_VAULT_CONFIGS.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)
raise SkyflowError(SkyflowMessages.Error.EMPTY_VAULT_CONFIGS.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

if vault_id in self.__vault_configs:
return self.__vault_configs.get(vault_id)
raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)
log_info(SkyflowMessages.Info.VAULT_CONFIG_DOES_NOT_EXIST.value.format(vault_id), self.__logger)
raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value)


def add_connection_config(self, config):
connection_id = config.get("connection_id")
if not isinstance(connection_id, str) or not connection_id:
raise SkyflowError(
SkyflowMessages.Error.INVALID_CONNECTION_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger = self.__logger
SkyflowMessages.ErrorCodes.INVALID_INPUT.value
)
if connection_id in [connection.get("connection_id") for connection in self.__connection_list]:
log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger)
raise SkyflowError(
SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
logger=self.__logger
SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id),
SkyflowMessages.ErrorCodes.INVALID_INPUT.value
)
self.__connection_list.append(config)
return self
Expand All @@ -162,11 +155,14 @@ def get_connection_config(self, connection_id):
if connection_id is None:
if self.__connection_configs:
return next(iter(self.__connection_configs.values()))
return SkyflowError(SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)

raise SkyflowError(SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

if connection_id in self.__connection_configs:
return self.__connection_configs.get(connection_id)
raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger)
log_info(SkyflowMessages.Info.CONNECTION_CONFIG_DOES_NOT_EXIST.value.format(connection_id), self.__logger)
raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value)


def add_skyflow_credentials(self, credentials):
self.__skyflow_credentials = credentials
Expand All @@ -179,41 +175,62 @@ def set_log_level(self, log_level):
def get_logger(self):
return self.__logger

def __add_vault_config(self, config):
validate_vault_config(self.__logger, config)
vault_id = config.get("vault_id")
vault_client = VaultClient(config)
self.__vault_configs[vault_id] = {
"vault_client": vault_client,
"controller": Vault(vault_client)
}
log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger)

def __add_connection_config(self, config):
validate_connection_config(self.__logger, config)
connection_id = config.get("connection_id")
vault_client = VaultClient(config)
self.__connection_configs[connection_id] = {
"vault_client": vault_client,
"controller": Connection(vault_client)
}
log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger)

def __update_vault_client_logger(self, log_level, logger):
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_logger(log_level,logger)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_logger(log_level,logger)

def __set_log_level(self, log_level):
validate_log_level(self.__logger, log_level)
self.__log_level = log_level
self.__logger.set_log_level(log_level)
self.__update_vault_client_logger(log_level, self.__logger)
log_info(SkyflowMessages.Info.LOGGER_SETUP_DONE.value, self.__logger)
log_info(SkyflowMessages.Info.CURRENT_LOG_LEVEL.value.format(self.__log_level), self.__logger)

def __add_skyflow_credentials(self, credentials):
if credentials is not None:
self.__skyflow_credentials = credentials
validate_credentials(self.__logger, credentials)
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_common_skyflow_credentials(credentials)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials)
def build(self):
log_info(SkyflowMessages.Info.INITIALIZE_CLIENT.value, SkyflowMessages.InterfaceName.CLIENT.value, self.__logger)
validate_log_level(self.__logger, self.__log_level)
self.__logger.set_log_level(self.__log_level)

for config in self.__vault_list:
validate_vault_config(self.__logger, config)
vault_id = config.get("vault_id")
vault_client = VaultClient(config)
self.__vault_configs[vault_id] = {
"vault_client": vault_client,
"controller": Vault(vault_client)
}
self.__add_vault_config(config)

for config in self.__connection_list:
validate_connection_config(self.__logger, config=config)
connection_id = config.get("connection_id")
vault_client = VaultClient(config)
self.__connection_configs[connection_id] = {
"vault_client": vault_client,
"controller": Connection(vault_client)
}
self.__add_connection_config(config)

for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_logger(self.__log_level, self.__logger)
self.__update_vault_client_logger(self.__log_level, self.__logger)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_logger(self.__log_level, self.__logger)

if self.__skyflow_credentials is not None:
validate_credentials(self.__logger, self.__skyflow_credentials)
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials)
self.__add_skyflow_credentials(self.__skyflow_credentials)

return Skyflow(self)
6 changes: 3 additions & 3 deletions skyflow/error/_skyflow_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self,
request_id = None,
grpc_code = None,
http_status = None,
details = None,
logger = None):
log_error(message, http_code, request_id, grpc_code, http_status, details, logger)
details = None):
self.message = message
log_error(message, http_code, request_id, grpc_code, http_status, details)
super().__init__()
2 changes: 1 addition & 1 deletion skyflow/service_account/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired, validate_api_key
from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired
Loading

0 comments on commit 704f48e

Please sign in to comment.