diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index ce1be52..7e54191 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -1,7 +1,8 @@ from collections import OrderedDict from skyflow import LogLevel from skyflow.error import SkyflowError -from skyflow.utils import Logger, log_info, SkyflowMessages, log_error +from skyflow.utils import SkyflowMessages +from skyflow.utils.logger import log_info, Logger, log_error from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level from skyflow.vault.client.client import VaultClient @@ -81,10 +82,24 @@ def __init__(self): self.__connection_configs = OrderedDict() self.__connection_list = list() self.__skyflow_credentials = None - self.__log_level = LogLevel.OFF - self.__logger = Logger(LogLevel.OFF) + self.__log_level = LogLevel.ERROR + self.__logger = Logger(LogLevel.ERROR) def add_vault_config(self, config): + vault_id = config.get("vault_id") + if not isinstance(vault_id, str) or not vault_id: + raise SkyflowError( + SkyflowMessages.Error.INVALID_VAULT_ID.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value, + logger=self.__logger + ) + if vault_id in [vault.get("vault_id") for vault in self.__vault_list]: + raise SkyflowError( + SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value, + logger=self.__logger + ) + self.__vault_list.append(config) return self @@ -103,12 +118,29 @@ def update_vault_config(self, config): vault_config.get("vault_client").update_config(config) def get_vault_config(self, vault_id): - if vault_id in self.__vault_configs.keys(): - vault_config = self.__vault_configs.get(vault_id) - return vault_config - raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger, logger_method=log_error) + if vault_id is None: + if self.__vault_configs: + return next(iter(self.__vault_configs.values())) + raise SkyflowError(SkyflowMessages.Error.EMPTY_VAULT_CONFIGS.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger) + + if vault_id in self.__vault_configs: + return self.__vault_configs.get(vault_id) + raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger) def add_connection_config(self, config): + connection_id = config.get("connection_id") + if not isinstance(connection_id, str) or not connection_id: + raise SkyflowError( + SkyflowMessages.Error.INVALID_CONNECTION_ID.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value, + logger = self.__logger + ) + if connection_id in [connection.get("connection_id") for connection in self.__connection_list]: + raise SkyflowError( + SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value, + logger=self.__logger + ) self.__connection_list.append(config) return self @@ -127,10 +159,14 @@ def update_connection_config(self, config): connection_config.get("vault_client").update_config(config) def get_connection_config(self, connection_id): - if connection_id in self.__connection_configs.keys(): - connection_config = self.__connection_configs[connection_id] - return connection_config - raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger, logger_method=log_error) + if connection_id is None: + if self.__connection_configs: + return next(iter(self.__connection_configs.values())) + return SkyflowError(SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger) + + if connection_id in self.__connection_configs: + return self.__connection_configs.get(connection_id) + raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger = self.__logger) def add_skyflow_credentials(self, credentials): self.__skyflow_credentials = credentials diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index 5c52f26..2a40d91 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,3 +1,5 @@ +from skyflow.utils.logger import log_error + class SkyflowError(Exception): def __init__(self, message, @@ -6,8 +8,6 @@ def __init__(self, grpc_code = None, http_status = None, details = None, - logger = None, - logger_method = None): - - logger_method(message, http_code, request_id, grpc_code, http_status, details, logger) + logger = None): + log_error(message, http_code, request_id, grpc_code, http_status, details, logger) super().__init__() \ No newline at end of file diff --git a/skyflow/service_account/__init__.py b/skyflow/service_account/__init__.py index b5c1919..111f235 100644 --- a/skyflow/service_account/__init__.py +++ b/skyflow/service_account/__init__.py @@ -1 +1 @@ -from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired \ No newline at end of file +from ._utils import generate_bearer_token, generate_bearer_token_from_creds, is_expired, validate_api_key \ No newline at end of file diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 71141ea..c099860 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -5,7 +5,9 @@ from skyflow.error import SkyflowError from skyflow.generated.rest.models import V1GetAuthTokenRequest from skyflow.service_account.client.auth_client import AuthClient -from skyflow.utils import get_base_url, format_scope, SkyflowMessages, log_error +from skyflow.utils.logger import log_error +from skyflow.utils import get_base_url, format_scope, SkyflowMessages + invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value @@ -25,16 +27,26 @@ def is_expired(token, logger = None): return True pass +import re + +def validate_api_key(api_key: str) -> bool: + if len(api_key) != 42: + return False + api_key_pattern = re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$') + + return bool(api_key_pattern.match(api_key)) + + def generate_bearer_token(credentials_file_path, options = None, logger = None): try: credentials_file =open(credentials_file_path, 'r') except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger) try: credentials = json.load(credentials_file) except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code, logger = logger) finally: credentials_file.close() @@ -42,10 +54,11 @@ def generate_bearer_token(credentials_file_path, options = None, logger = None): return result def generate_bearer_token_from_creds(credentials, options = None, logger = None): + credentials = credentials.strip() try: json_credentials = json.loads(credentials.replace('\n', '\\n')) - except Exception as e: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value, invalid_input_error_code, logger = logger, logger_method=log_error) + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code, logger = logger) result = get_service_account_token(json_credentials, options, logger) return result @@ -53,19 +66,19 @@ def get_service_account_token(credentials, options, logger): try: private_key = credentials["privateKey"] except: - raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code, logger = logger) try: client_id = credentials["clientID"] except: - raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code, logger = logger) try: key_id = credentials["keyID"] except: - raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code, logger = logger) try: token_uri = credentials["tokenURI"] except: - raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code, logger = logger) signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) @@ -73,7 +86,7 @@ def get_service_account_token(credentials, options, logger): auth_api = auth_client.get_auth_api() formatted_scope = None - if "role_ids" in options: + if options and "role_ids" in options: formatted_scope = format_scope(options.get("role_ids")) request = V1GetAuthTokenRequest(assertion = signed_token, @@ -90,12 +103,12 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): "sub": client_id, "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if "ctx" in options: + if options and "ctx" in options: payload["ctx"] = options.get("ctx") try: return jwt.encode(payload=payload, key=private_key, algorithm="RS256") - except Exception as e: - raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + except Exception: + raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code, logger = logger) @@ -139,7 +152,7 @@ def generate_signed_data_tokens(credentials_file_path, options, logger = None): try: credentials_file =open(credentials_file_path, 'r') except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code, logger = logger) return get_signed_tokens(credentials_file, options) diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index 3c8cfbf..e2ccbbe 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -1,6 +1,6 @@ from ..utils.enums import LogLevel, Env from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION -from ._logger import Logger -from ._log_helpers import log_error, log_info -from ._utils import get_credentials, get_vault_url, get_client_configuration, get_base_url, format_scope, get_redaction_type, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response \ No newline at end of file +from ._helpers import get_base_url, format_scope +from ._utils import get_credentials, get_vault_url, get_client_configuration, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response + diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py new file mode 100644 index 0000000..97eecab --- /dev/null +++ b/skyflow/utils/_helpers.py @@ -0,0 +1,11 @@ +from urllib.parse import urlparse + +def get_base_url(url): + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + return base_url + +def format_scope(scopes): + if not scopes: + return None + return " ".join([f"role:{scope}" for scope in scopes]) \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index efbdd8a..70cc791 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -16,36 +16,56 @@ class ErrorCodes(Enum): class Error(Enum): EMPTY_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id." INVALID_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id as a string." - EMPTY_CLUSTER_ID = f"{error_prefix} Initialization failed. Invalid cluster Id. Specify a valid cluster Id." - INVALID_CLUSTER_ID = f"{error_prefix} Initialization failed. Invalid cluster Id. Specify cluster Id as a string." - INVALID_ENV = f"{error_prefix} Initialization failed. Invalid env. Specify a valid env." + EMPTY_CLUSTER_ID = f"{error_prefix} Initialization failed. Invalid cluster Id for vault with id {{}}. Specify a valid cluster Id." + INVALID_CLUSTER_ID = f"{error_prefix} Initialization failed. Invalid cluster Id for vault with id {{}}. Specify cluster Id as a string." + INVALID_ENV = f"{error_prefix} Initialization failed. Invalid env for vault with id {{}}. Specify a valid env." INVALID_KEY = f"{error_prefix} Initialization failed. Invalid {{}}. Specify a valid key" - VAULT_ID_NOT_IN_CONFIG_LIST = f"{error_prefix} Validation error. {{}} is missing from the config. Specify the vaultId's from config." - - EMPTY_CREDENTIALS = f"{error_prefix} Validation error. Invalid credentials. Credentials must not be empty." + VAULT_ID_NOT_IN_CONFIG_LIST = f"{error_prefix} Validation error. Vault id {{}} is missing from the config. Specify the vault_id's from config." + EMPTY_VAULT_CONFIGS = f"{error_prefix} Validation error. Specify at least one vault config." + EMPTY_CONNECTION_CONFIGS = f"{error_prefix} Validation error. Specify at least one connection config." + VAULT_ID_ALREADY_EXISTS =f"{error_prefix} Initialization failed. vault with id {{}} already exists." + CONNECTION_ID_ALREADY_EXISTS = f"{error_prefix} Initialization failed. connection with id {{}} already exists." + + EMPTY_CREDENTIALS = f"{error_prefix} Validation error. Invalid credentials for {{}} with id {{}}. Credentials must not be empty." + INVALID_CREDENTIALS_IN_CONFIG = f"{error_prefix} Validation error. Invalid credentials for {{}} with id {{}}. Specify a valid credentials." INVALID_CREDENTIALS = f"{error_prefix} Validation error. Invalid credentials. Specify a valid credentials." + MULTIPLE_CREDENTIALS_PASSED_IN_CONFIG = f"{error_prefix} Validation error. Multiple credentials provided for {{}} with id {{}}. Please specify only one valid credential." MULTIPLE_CREDENTIALS_PASSED = f"{error_prefix} Validation error. Multiple credentials provided. Please specify only one valid credential." + EMPTY_CREDENTIALS_STRING_IN_CONFIG = f"{error_prefix} Validation error. Invalid credentials for {{}} with id {{}}. Specify valid credentials." EMPTY_CREDENTIALS_STRING = f"{error_prefix} Validation error. Invalid credentials. Specify valid credentials." + INVALID_CREDENTIALS_STRING_IN_CONFIG = f"{error_prefix} Validation error. Invalid credentials for {{}} with id {{}}. Specify credentials as a string." INVALID_CREDENTIALS_STRING = f"{error_prefix} Validation error. Invalid credentials. Specify credentials as a string." + EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Specify a valid file path." EMPTY_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Specify a valid file path." + INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Expected file path to be a string." INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a string." - EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Specify a valid credentials token." + EMPTY_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid token for {{}} with id {{}}.Specify a valid credentials token." + EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." + INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." - EMPTY_API_KEY = f"{error_prefix} Initialization failed. Specify a valid api key." + EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." + EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." + INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string." INVALID_API_KEY = f"{error_prefix} Initialization failed. Invalid api key. Expected api key to be a string." + INVALID_ROLES_KEY_TYPE_IN_CONFIG = f"{error_prefix} Validation error. Invalid roles for {{}} with id {{}}. Specify roles as an array." INVALID_ROLES_KEY_TYPE = f"{error_prefix} Validation error. Invalid roles. Specify roles as an array." + EMPTY_ROLES_IN_CONFIG = f"{error_prefix} Validation error. Invalid roles for {{}} with id {{}}. Specify at least one role." EMPTY_ROLES = f"{error_prefix} Validation error. Invalid roles. Specify at least one role." + EMPTY_CONTEXT_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid context provided for {{}} with id {{}}. Specify context as type Context." EMPTY_CONTEXT = f"{error_prefix} Initialization failed. Invalid context provided. Specify context as type Context." + INVALID_CONTEXT_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid context for {{}} with id {{}}. Specify a valid context." INVALID_CONTEXT = f"{error_prefix} Initialization failed. Invalid context. Specify a valid context." INVALID_LOG_LEVEL = f"{error_prefix} Initialization failed. Invalid log level. Specify a valid log level." EMPTY_LOG_LEVEL = f"{error_prefix} Initialization failed. Specify a valid log level." EMPTY_CONNECTION_ID = f"{error_prefix} Initialization failed. Invalid connection Id. Specify a valid connection Id." INVALID_CONNECTION_ID = f"{error_prefix} Initialization failed. Invalid connection Id. Specify connection Id as a string." - EMPTY_CONNECTION_URL = f"{error_prefix} Initialization failed. Invalid connection Url. Specify a valid connection Url." - INVALID_CONNECTION_URL = f"{error_prefix} Initialization failed. Invalid connection Url. Specify connection Url as a string." + EMPTY_CONNECTION_URL = f"{error_prefix} Initialization failed. Invalid connection Url for connection with id {{}}. Specify a valid connection Url." + INVALID_CONNECTION_URL = f"{error_prefix} Initialization failed. Invalid connection Url for connection with id {{}}. Specify connection Url as a string." CONNECTION_ID_NOT_IN_CONFIG_LIST = f"{error_prefix} Validation error. {{}} is missing from the config. Specify the connectionIds from config." + RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON." + API_ERROR = f"{error_prefix} Server returned status code {{}}" MISSING_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Table name cannot be empty in insert request. Specify a table name." INVALID_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Invalid table name in insert request. Specify a valid table name." @@ -117,6 +137,7 @@ class Error(Enum): JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials." JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials." FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents." + INVALID_JSON_FORMAT = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable." class Info(Enum): INITIALIZE_CLIENT = "Initializing skyflow client" diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 94d6bf9..5e214ac 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -1,63 +1,52 @@ import os import json -from urllib.parse import urlparse import urllib.parse +from requests.sessions import PreparedRequest +from requests.models import HTTPError import requests import platform import sys -from requests import PreparedRequest from skyflow.error import SkyflowError -from skyflow.generated.rest import RedactionEnumREDACTION, V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \ +from skyflow.generated.rest import V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \ V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse -from . import SkyflowMessages, SDK_VERSION, log_error -from .enums import Env, ContentType, Redaction -import skyflow.generated.rest as vault_client +from skyflow.utils.logger import log_error + +from . import SkyflowMessages, SDK_VERSION +from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params +from ..vault.connection import InvokeConnectionResponse from ..vault.tokens import DetokenizeResponse, TokenizeResponse invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value -def get_credentials(config_level_creds = None, common_skyflow_creds = None): +def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") if config_level_creds: return config_level_creds if common_skyflow_creds: return common_skyflow_creds if env_skyflow_credentials: - return env_skyflow_credentials + try: + env_creds = json.loads(env_skyflow_credentials) + return env_creds + except json.JSONDecodeError: + raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT.value, invalid_input_error_code, logger = logger) else: - raise Exception("Invalid Credentials") - pass - + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code, logger = logger) -def get_vault_url(cluster_id, env): - if env == Env.PROD: - return f"http://{cluster_id}.vault.skyflowapis.com" - elif env == Env.SANDBOX: - return f"https://{cluster_id}.vault.skyflowapis-preview.com" - elif env == Env.DEV: - return f"https://{cluster_id}.vault.skyflowapis.dev" - else: - return f"https://{cluster_id}.vault.skyflowapis.com" -def get_client_configuration(vault_url, bearer_token): - return vault_client.Configuration( - host=vault_url, - api_key_prefix="Bearer", - api_key=bearer_token - ) +def get_vault_url(cluster_id, env,vault_id, logger = None): + if not cluster_id or not isinstance(cluster_id, str) or not cluster_id.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id), invalid_input_error_code, logger = logger) -def get_base_url(url): - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - return base_url + if env not in Env: + raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code, logger = logger) -def format_scope(scopes): - if not scopes: - return None - return " ".join([f"role:{scope}" for scope in scopes]) + base_url = EnvUrls[env.name].value + protocol = "https" if env != Env.PROD else "http" + return f"{protocol}://{cluster_id}.{base_url}" def parse_path_params(url, path_params): result = url @@ -67,9 +56,9 @@ def parse_path_params(url, path_params): return result def to_lowercase_keys(dict): - ''' + """ convert keys of dictionary to lowercase - ''' + """ result = {} for key, value in dict.items(): result[key.lower()] = value @@ -84,12 +73,12 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep header = to_lowercase_keys(json.loads( json.dumps(request.request_headers))) else: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code, logger = logger) except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code, logger = logger) if not 'Content-Type'.lower() in header: - header['content-type'] = ContentType.JSON + header['content-type'] = ContentType.JSON.value try: if isinstance(request.body, dict): @@ -97,37 +86,37 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep request.body, header["content-type"] ) else: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code, logger = logger) except Exception as e: - raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code, logger = logger) validate_invoke_connection_params(logger, request.query_params, request.path_params) try: return requests.Request( - method = request.method, + method = request.method.value, url = url, data = json_data, headers = header, - params = request.params, + params = request.query_params, files = files ).prepare() except requests.exceptions.InvalidURL: - raise SkyflowError(SkyflowMessages.Error.INVALID_URL.value.format(connection_url), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_URL.value.format(connection_url), invalid_input_error_code, logger = logger) def http_build_query(data): - ''' + """ Creates a form urlencoded string from python dictionary urllib.urlencode() doesn't encode it in a php-esque way, this function helps in that - ''' + """ return urllib.parse.urlencode(r_urlencode(list(), dict(), data)) def r_urlencode(parents, pairs, data): - ''' + """ convert the python dict recursively into a php style associative dictionary - ''' + """ if isinstance(data, list) or isinstance(data, tuple): for i in range(len(data)): parents.append(i) @@ -144,9 +133,9 @@ def r_urlencode(parents, pairs, data): return pairs def render_key(parents): - ''' + """ renders the nested dictionary key as an associative array (php style dict) - ''' + """ depth, out_str = 0, '' for x in parents: s = "[%s]" if depth > 0 or isinstance(x, int) else "%s" @@ -155,25 +144,25 @@ def render_key(parents): return out_str def get_data_from_content_type(data, content_type): - ''' + """ Get request data according to content type - ''' + """ converted_data = data files = {} - if content_type == ContentType.URLENCODED: + if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) - elif content_type == ContentType.FORMDATA: + elif content_type == ContentType.FORMDATA.value: converted_data = r_urlencode(list(), dict(), data) files = {(None, None)} - elif content_type == ContentType.JSON: + elif content_type == ContentType.JSON.value: converted_data = json.dumps(data) return converted_data, files def get_metrics(): - ''' fetch metrics - ''' + """ fetch metrics + """ sdk_name_version = "skyflow-python@" + SDK_VERSION try: @@ -320,6 +309,38 @@ def parse_query_response(api_response: V1GetQueryResponse): query_response.fields = fields return query_response +def parse_invoke_connection_response(api_response: requests.Response): + invoke_connection_response = InvokeConnectionResponse() + + status_code = api_response.status_code + content = api_response.content.decode('utf-8') + try: + api_response.raise_for_status() + try: + json_content = json.loads(content) + if 'x-request-id' in api_response.headers: + request_id = api_response.headers['x-request-id'] + json_content['request_id'] = request_id + + invoke_connection_response.response = json_content + return invoke_connection_response + except: + raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) + except HTTPError: + message = SkyflowMessages.Error.API_ERROR.value.format(status_code) + if api_response and api_response.content: + try: + error_response = json.loads(content) + if isinstance(error_response.get('error'), dict) and 'message' in error_response['error']: + message = error_response['error']['message'] + except json.JSONDecodeError: + message = SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content) + + if 'x-request-id' in api_response.headers: + message += ' - request id: ' + api_response.headers['x-request-id'] + + raise SkyflowError(message, status_code) + def log_and_reject_error(description, status_code, request_id, http_status=None, grpc_code=None, details=None, logger = None): log_error(description, status_code, request_id, grpc_code, http_status, details, logger= logger) @@ -329,7 +350,6 @@ def handle_exception(error, logger): content_type = error.headers.get('content-type') data = error.body - # Call relevant handler based on content type if content_type: if 'application/json' in content_type: handle_json_error(error, data, request_id, logger) @@ -354,8 +374,8 @@ def handle_json_error(err, data, request_id, logger): log_and_reject_error("Invalid JSON response received.", err, request_id, logger = logger) def handle_text_error(err, data, request_id, logger): - log_and_reject_error(data, err, request_id, logger = logger) + log_and_reject_error(data, err.status, request_id, logger = logger) def handle_generic_error(err, request_id, logger): description = "An error occurred." - log_and_reject_error(description, err, request_id, logger = logger) + log_and_reject_error(description, err.status, request_id, logger = logger) diff --git a/skyflow/utils/enums/__init__.py b/skyflow/utils/enums/__init__.py index 63e6e65..6a010ba 100644 --- a/skyflow/utils/enums/__init__.py +++ b/skyflow/utils/enums/__init__.py @@ -1,8 +1,6 @@ -from .env import Env +from .env import Env, EnvUrls from .log_level import LogLevel from .content_types import ContentType -from .interface_name import InterfaceName from .token_strict import TokenStrict -from .batch_method import BatchMethod -from .redaction_type import Redaction -from .order_by import OrderBy \ No newline at end of file +from .method import Method +from .redaction_type import Redaction \ No newline at end of file diff --git a/skyflow/utils/enums/env.py b/skyflow/utils/enums/env.py index 4ffdb00..97120c4 100644 --- a/skyflow/utils/enums/env.py +++ b/skyflow/utils/enums/env.py @@ -3,4 +3,9 @@ class Env(Enum): DEV = 'DEV', SANDBOX = 'SANDBOX', - PROD = 'PROD' \ No newline at end of file + PROD = 'PROD' + +class EnvUrls(Enum): + PROD = "vault.skyflowapis.com", + SANDBOX = "vault.skyflowapis-preview.com", + DEV = "vault.skyflowapis.dev" \ No newline at end of file diff --git a/skyflow/utils/enums/interface_name.py b/skyflow/utils/enums/interface_name.py deleted file mode 100644 index cf49dc9..0000000 --- a/skyflow/utils/enums/interface_name.py +++ /dev/null @@ -1,15 +0,0 @@ -from enum import Enum - -class InterfaceName(Enum): - CLIENT = "client" - INSERT = "client.insert" - DETOKENIZE = "client.detokenize" - GET_BY_ID = "client.get_by_id" - GET = "client.get" - UPDATE = "client.update" - INVOKE_CONNECTION = "client.invoke_connection" - QUERY = "client.query" - GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token" - IS_TOKEN_VALID = "service_account.isTokenValid" - IS_EXPIRED = "service_account.is_expired" - DELETE = "client.delete" \ No newline at end of file diff --git a/skyflow/utils/enums/batch_method.py b/skyflow/utils/enums/method.py similarity index 81% rename from skyflow/utils/enums/batch_method.py rename to skyflow/utils/enums/method.py index 18300af..0561e8f 100644 --- a/skyflow/utils/enums/batch_method.py +++ b/skyflow/utils/enums/method.py @@ -1,6 +1,6 @@ from enum import Enum -class BatchMethod(Enum): +class Method(Enum): GET = "GET" POST = "POST" PUT = "PUT" diff --git a/skyflow/utils/enums/order_by.py b/skyflow/utils/enums/order_by.py deleted file mode 100644 index 4c93731..0000000 --- a/skyflow/utils/enums/order_by.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - -class OrderBy(Enum): - ASCENDING = "ASCENDING" - DESCENDING = "DESCENDING" - NONE = "NONE" \ No newline at end of file diff --git a/skyflow/utils/enums/token_strict.py b/skyflow/utils/enums/token_strict.py index 92c0043..b211a86 100644 --- a/skyflow/utils/enums/token_strict.py +++ b/skyflow/utils/enums/token_strict.py @@ -1,6 +1,9 @@ from enum import Enum +from skyflow.generated.rest import V1BYOT + + class TokenStrict(Enum): - DISABLE = "DISABLE" - ENABLE = "ENABLE" - ENABLE_STRICT = "ENABLE_STRICT" \ No newline at end of file + DISABLE = V1BYOT.DISABLE + ENABLE = V1BYOT.ENABLE + ENABLE_STRICT = V1BYOT.ENABLE_STRICT \ No newline at end of file diff --git a/skyflow/utils/logger/__init__.py b/skyflow/utils/logger/__init__.py new file mode 100644 index 0000000..465e249 --- /dev/null +++ b/skyflow/utils/logger/__init__.py @@ -0,0 +1,2 @@ +from ._logger import Logger +from ._log_helpers import log_error, log_info \ No newline at end of file diff --git a/skyflow/utils/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py similarity index 96% rename from skyflow/utils/_log_helpers.py rename to skyflow/utils/logger/_log_helpers.py index bf4b6af..40fcf09 100644 --- a/skyflow/utils/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -1,4 +1,4 @@ -from .enums import LogLevel +from ..enums import LogLevel from . import Logger diff --git a/skyflow/utils/_logger.py b/skyflow/utils/logger/_logger.py similarity index 97% rename from skyflow/utils/_logger.py rename to skyflow/utils/logger/_logger.py index 0827abd..e41e152 100644 --- a/skyflow/utils/_logger.py +++ b/skyflow/utils/logger/_logger.py @@ -1,6 +1,5 @@ import logging -from .enums.log_level import LogLevel - +from ..enums.log_level import LogLevel class Logger: def __init__(self, level=LogLevel.ERROR): diff --git a/skyflow/utils/validations/__init__.py b/skyflow/utils/validations/__init__.py index d78e625..17bc49a 100644 --- a/skyflow/utils/validations/__init__.py +++ b/skyflow/utils/validations/__init__.py @@ -12,5 +12,5 @@ validate_update_request, validate_detokenize_request, validate_tokenize_request, - validate_invoke_connection_params + validate_invoke_connection_params, ) \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index bb839ff..7118f88 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -1,9 +1,8 @@ import json - -from skyflow.service_account import is_expired +from skyflow.service_account import is_expired, validate_api_key from skyflow.utils.enums import LogLevel, TokenStrict, Redaction, Env from skyflow.error import SkyflowError -from skyflow.utils import SkyflowMessages, log_error +from skyflow.utils import SkyflowMessages valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] @@ -14,71 +13,102 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err field_value = config.get(field_name) if field_name not in config or not isinstance(field_value, expected_type): - raise SkyflowError(invalid_error, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(invalid_error, invalid_input_error_code, logger = logger) if isinstance(field_value, str) and not field_value.strip(): - raise SkyflowError(empty_error, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(empty_error, invalid_input_error_code, logger = logger) + -def validate_credentials(logger, credentials): +def validate_credentials(logger, credentials, config_id_type=None, config_id=None): key_present = [k for k in ["path", "token", "credentials_string", "api_key"] if credentials.get(k)] + if len(key_present) == 0: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + error_message = ( + SkyflowMessages.Error.INVALID_CREDENTIALS_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else + SkyflowMessages.Error.INVALID_CREDENTIALS.value + ) + raise SkyflowError(error_message, invalid_input_error_code, logger=logger) elif len(key_present) > 1: - raise SkyflowError(SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value, invalid_input_error_code, logger = logger, logger_method=log_error) + error_message = ( + SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else + SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value + ) + raise SkyflowError(error_message, invalid_input_error_code, logger=logger) if "roles" in credentials: validate_required_field( logger, credentials, "roles", list, - SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE.value, - SkyflowMessages.Error.EMPTY_ROLES.value + SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE.value, + SkyflowMessages.Error.EMPTY_ROLES_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.EMPTY_ROLES.value ) if "context" in credentials: validate_required_field( logger, credentials, "context", str, - SkyflowMessages.Error.EMPTY_CONTEXT.value, - SkyflowMessages.Error.INVALID_CONTEXT.value + SkyflowMessages.Error.EMPTY_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CONTEXT.value, + SkyflowMessages.Error.INVALID_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_CONTEXT.value ) if "credentials_string" in credentials: validate_required_field( logger, credentials, "credentials_string", str, - SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING.value, - SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value + SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING.value, + SkyflowMessages.Error.INVALID_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value ) elif "path" in credentials: validate_required_field( logger, credentials, "path", str, - SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH.value, - SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value + SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH.value, + SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value ) elif "token" in credentials: validate_required_field( logger, credentials, "token", str, - SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value, - SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value + SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value, + SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) - if is_expired(credentials.get("token"), logger): - raise SkyflowError() + raise SkyflowError( + SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, + invalid_input_error_code, logger=logger + ) elif "api_key" in credentials: validate_required_field( logger, credentials, "api_key", str, - SkyflowMessages.Error.EMPTY_API_KEY.value, - SkyflowMessages.Error.INVALID_API_KEY.value + SkyflowMessages.Error.EMPTY_API_KEY.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.EMPTY_API_KEY.value, + SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value ) + if not validate_api_key(credentials.get("api_key")): + raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) + if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, + invalid_input_error_code, logger=logger) def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): - raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code, logger = logger) if log_level is None: - raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code, logger = logger) def validate_keys(logger, config, config_keys): for key in config.keys(): if key not in config_keys: - raise SkyflowError(SkyflowMessages.Error.INVALID_KEY.value.format(key), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_KEY.value.format(key), invalid_input_error_code, logger = logger) def validate_vault_config(logger, config): @@ -90,23 +120,23 @@ def validate_vault_config(logger, config): SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - + vault_id = config.get("vault_id") # Validate cluster_id (string, not empty) validate_required_field( logger, config, "cluster_id", str, - SkyflowMessages.Error.EMPTY_CLUSTER_ID.value, - SkyflowMessages.Error.INVALID_CLUSTER_ID.value + SkyflowMessages.Error.EMPTY_CLUSTER_ID.value.format(vault_id), + SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id) ) # Validate credentials (dict, not empty) if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code, logger = logger) - validate_credentials(logger, config.get("credentials")) + validate_credentials(logger, config.get("credentials"), "vault", vault_id) # Validate env (optional, should be one of LogLevel values) if "env" in config and config.get("env") not in Env: - raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code, logger = logger) return True @@ -121,16 +151,18 @@ def validate_update_vault_config(logger, config): SkyflowMessages.Error.INVALID_VAULT_ID.value ) + vault_id = config.get("vault_id") + if "cluster_id" in config and not config.get("cluster_id"): - raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id), invalid_input_error_code, logger = logger) if "env" in config and config.get("env") not in LogLevel: - raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code, logger = logger) if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code, logger = logger) - validate_credentials(logger, config.get("credentials")) + validate_credentials(logger, config.get("credentials"), "vault", vault_id) return True @@ -143,16 +175,18 @@ def validate_connection_config(logger, config): SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) + connection_id = config.get("connection_id") + validate_required_field( logger, config, "connection_url", str, - SkyflowMessages.Error.EMPTY_CONNECTION_URL.value, - SkyflowMessages.Error.INVALID_CONNECTION_URL.value + SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), + SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code, logger = logger) - validate_credentials(logger, config.get("credentials")) + validate_credentials(logger, config.get("credentials"), "connection", connection_id) return True @@ -166,14 +200,16 @@ def validate_update_connection_config(logger, config): SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) + connection_id = config.get("connection_id") + validate_required_field( logger, config, "connection_url", str, - SkyflowMessages.Error.EMPTY_CONNECTION_URL.value, - SkyflowMessages.Error.INVALID_CONNECTION_URL.value + SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), + SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code, logger = logger) validate_credentials(logger, config.get("credentials")) return True @@ -181,73 +217,73 @@ def validate_update_connection_config(logger, config): def validate_insert_request(logger, request): if not isinstance(request.table_name, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code, logger = logger) if not request.table_name.strip(): - raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code, logger = logger) if not isinstance(request.values, list) or not all(isinstance(v, dict) for v in request.values): - raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code, logger = logger) if not len(request.values): - raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code, logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code, logger = logger) if not isinstance(request.homogeneous, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_HOMOGENEOUS_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_HOMOGENEOUS_TYPE.value, invalid_input_error_code, logger = logger) if request.token_strict is not None: if not isinstance(request.token_strict, TokenStrict): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_STRICT_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_STRICT_TYPE.value, invalid_input_error_code, logger = logger) if not isinstance(request.return_tokens, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code, logger = logger) if not isinstance(request.continue_on_error, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code, logger = logger) if request.tokens: if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code, - logger=logger, logger_method=log_error) + logger=logger) if request.token_strict == TokenStrict.ENABLE and not request.tokens: - raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_Strict), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_Strict), invalid_input_error_code, logger = logger) if request.token_strict == TokenStrict.DISABLE and request.tokens: - raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_STRICT_DISABLE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_STRICT_DISABLE.value, invalid_input_error_code, logger = logger) if request.token_strict == TokenStrict.ENABLE_STRICT: if len(request.values) != len(request.tokens): - raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, invalid_input_error_code, logger = logger) for v, t in zip(request.values, request.tokens): if set(v.keys()) != set(t.keys()): - raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, invalid_input_error_code, logger = logger) def validate_delete_request(logger, request): if not isinstance(request.table, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code, logger = logger) if not request.table.strip(): - raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code, logger = logger) if not request.ids: - raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code, logger = logger) def validate_query_request(logger, request): if not isinstance(request.query, str): query_type = str(type(request.query)) - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code, logger = logger) if not request.query.strip(): - raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code, logger = logger) if not request.query.upper().startswith("SELECT"): command = request.query - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code, logger = logger) def validate_get_request(logger, request): redaction_type = request.redaction_type @@ -260,140 +296,160 @@ def validate_get_request(logger, request): download_url = request.download_url if skyflow_ids and (not isinstance(skyflow_ids, list) or not skyflow_ids): - raise SkyflowError(SkyflowMessages.Error.INVALID_IDS_TYPE.value.format(type(skyflow_ids)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_IDS_TYPE.value.format(type(skyflow_ids)), invalid_input_error_code, logger = logger) if not isinstance(request.return_tokens, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code, logger = logger) if redaction_type is not None and not isinstance(redaction_type, Redaction): - raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code, logger = logger) if fields is not None and (not isinstance(fields, list) or not fields): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(fields)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(fields)), invalid_input_error_code, logger = logger) if offset is not None and limit is not None: raise SkyflowError( SkyflowMessages.Error.BOTH_OFFSET_AND_LIMIT_SPECIFIED.value, - invalid_input_error_code, logger=logger, logger_method=log_error) + invalid_input_error_code, logger=logger) if offset is not None and not isinstance(offset, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code, logger = logger) if limit is not None and not isinstance(limit, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code, logger = logger) if download_url is not None and not isinstance(download_url, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code, logger = logger) if column_name is not None and (not isinstance(column_name, str) or not column_name.strip()): - raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code, logger = logger) if column_values is not None and ( not isinstance(column_values, list) or not column_values or not all( isinstance(val, str) for val in column_values)): - raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code, logger = logger) if request.return_tokens and redaction_type: - raise SkyflowError(SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value, invalid_input_error_code, logger = logger) if (column_name or column_values) and request.return_tokens: - raise SkyflowError(SkyflowMessages.Error.TOKENS_GET_COLUMN_NOT_SUPPORTED.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.TOKENS_GET_COLUMN_NOT_SUPPORTED.value, invalid_input_error_code, logger = logger) if column_values and not column_name: - raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code, logger = logger) if column_name and not column_values: - SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code, logger = logger, logger_method=log_error) + SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code, logger = logger) if (column_name or column_values) and skyflow_ids: - raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code, logger = logger) def validate_update_request(logger, request): + field = [{key: value for key, value in request.data.items() if key != "id"}] if not isinstance(request.table, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code, logger = logger) if not request.table.strip(): - raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code, logger = logger) if not isinstance(request.return_tokens, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code, logger = logger) if not isinstance(request.data, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code, logger = logger) if not len(request.data.items()): - raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code, logger = logger) if request.token_strict is not None: if not isinstance(request.token_strict, TokenStrict): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_STRICT_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_STRICT_TYPE.value, invalid_input_error_code, logger = logger) if request.tokens: if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code, - logger=logger, logger_method=log_error) + logger=logger) + + if request.token_strict == TokenStrict.ENABLE and not request.tokens: + raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_Strict), + invalid_input_error_code, logger=logger) + + if request.token_strict == TokenStrict.DISABLE and request.tokens: + raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_STRICT_DISABLE.value, invalid_input_error_code, + logger=logger) + + if request.token_strict == TokenStrict.ENABLE_STRICT: + if len(field) != len(request.tokens): + raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, + invalid_input_error_code, logger=logger) + + for v, t in zip(field, request.tokens): + if set(v.keys()) != set(t.keys()): + raise SkyflowError( + SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, + invalid_input_error_code, logger=logger) if 'id' not in request.data: - raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code, logger = logger) def validate_detokenize_request(logger, request): if not isinstance(request.redaction_type, Redaction): - raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(request.redaction_type)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(request.redaction_type)), invalid_input_error_code, logger = logger) if not isinstance(request.continue_on_error, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code, logger = logger) if not len(request.tokens): - raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code, logger = logger) if not isinstance(request.tokens, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.tokens)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.tokens)), invalid_input_error_code, logger = logger) def validate_tokenize_request(logger, request): parameters = request.tokenize_parameters if not isinstance(parameters, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code, logger = logger) if not len(parameters): - raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code, logger = logger) for i, param in enumerate(parameters): if not isinstance(param, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER.value.format(i, type(param)), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER.value.format(i, type(param)), invalid_input_error_code, logger = logger) allowed_keys = {"value", "column_group"} if set(param.keys()) != allowed_keys: - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(i), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(i), invalid_input_error_code, logger = logger) if not param.get("value"): - raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_VALUE.value.format(i), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_VALUE.value.format(i), invalid_input_error_code, logger = logger) if not param.get("column_group"): - raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_COLUMN_GROUP.value.format(i), invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_COLUMN_GROUP.value.format(i), invalid_input_error_code, logger = logger) def validate_invoke_connection_params(logger, query_params, path_params): if not isinstance(path_params, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_PATH_PARAMS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_PATH_PARAMS.value, invalid_input_error_code, logger = logger) if not isinstance(query_params, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code, logger = logger) for param, value in path_params.items(): if not(isinstance(param, str) and isinstance(value, str)): - raise SkyflowError(SkyflowMessages.Error.INVALID_PATH_PARAMS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_PATH_PARAMS.value, invalid_input_error_code, logger = logger) for param, value in query_params.items(): if not isinstance(param, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code, logger = logger) try: json.dumps(query_params) except TypeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code, logger = logger, logger_method=log_error) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code, logger = logger) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index 23a503b..c492ed0 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,6 +1,7 @@ from skyflow.generated.rest import Configuration, RecordsApi, ApiClient, TokensApi, QueryApi from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds -from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages, log_info +from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages +from skyflow.utils.logger import log_info class VaultClient: @@ -11,6 +12,8 @@ def __init__(self, config): self.__client_configuration = None self.__api_client = None self.__logger = None + self.__is_config_updated = False + self.__bearer_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -20,11 +23,13 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials) - bearer_token = self.get_bearer_token(credentials) + credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) + token = self.get_bearer_token(credentials) vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env")) - self.__client_configuration = Configuration(host=vault_url, access_token=bearer_token) + self.__config.get("env"), + self.__config.get("vault_id"), + logger = self.__logger) + self.__client_configuration = Configuration(host=vault_url, access_token=token) self.initialize_api_client(self.__client_configuration) def initialize_api_client(self, config): @@ -48,29 +53,36 @@ def get_bearer_token(self, credentials): return credentials.get('api_key') elif 'token' in credentials: return credentials.get("token") - elif 'path' in credentials: - credentials = self.__config.get("credentials") - options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") - } - log_info(self.__logger, SkyflowMessages.Info.GENERATE_BEARER_TOKEN_TRIGGERED, interface) - token, _ = generate_bearer_token(credentials.get("path"), options, self.__logger) - log_info(self.__logger, SkyflowMessages.Info.GENERATE_BEARER_TOKEN_SUCCESS, interface) - return token - else: - credentials = self.__config.get("credentials") - options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") - } - log_info(self.__logger, SkyflowMessages.Info.GENERATE_BEARER_TOKEN_TRIGGERED, interface) - token, _ = generate_bearer_token_from_creds(credentials.get("credentials_string"), options, self.__logger) - log_info(self.__logger, SkyflowMessages.Info.GENERATE_BEARER_TOKEN_SUCCESS, interface) - return token + + options = { + "role_ids": self.__config.get("roles"), + "ctx": self.__config.get("ctx") + } + + log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_TRIGGERED, interface, self.__logger) + + if self.__bearer_token is None or self.__is_config_updated: + if 'path' in credentials: + path = credentials.get("path") + self.__bearer_token, _ = generate_bearer_token( + path, + options, + self.__logger + ) + else: + credentials_string = credentials.get('credentials_string') + self.__bearer_token, _ = generate_bearer_token_from_creds( + credentials_string, + options, + self.__logger + ) + self.__is_config_updated = False + log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_SUCCESS, interface, self.__logger) + return self.__bearer_token def update_config(self, config): self.__config.update(config) + self.__is_config_updated = True def get_config(self): return self.__config diff --git a/skyflow/vault/connection/_invoke_connection_response.py b/skyflow/vault/connection/_invoke_connection_response.py index 67b2882..661b61d 100644 --- a/skyflow/vault/connection/_invoke_connection_response.py +++ b/skyflow/vault/connection/_invoke_connection_response.py @@ -6,7 +6,4 @@ def __repr__(self): return f"ConnectionResponse({self.response})" def __str__(self): - return self.__repr__() - - def parse_invoke_connection_response(self, response): - self.response = response \ No newline at end of file + return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 508339a..8175270 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -2,18 +2,19 @@ import requests from skyflow.error import SkyflowError -from skyflow.utils import construct_invoke_connection_request, log_info, SkyflowMessages, get_metrics, log_error -from skyflow.vault.connection import InvokeConnectionRequest, InvokeConnectionResponse +from skyflow.utils import construct_invoke_connection_request, SkyflowMessages, get_metrics, \ + parse_invoke_connection_response +from skyflow.utils.logger import log_info +from skyflow.vault.connection import InvokeConnectionRequest class Connection: def __init__(self, vault_client): self.__vault_client = vault_client - self.logger = self.__vault_client.get_logger() def invoke(self, request: InvokeConnectionRequest): interface = SkyflowMessages.InterfaceName.INVOKE_CONNECTION.value - log_info(self.logger, SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED, interface) + log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED, interface, self.__vault_client.get_logger()) session = requests.Session() @@ -21,19 +22,20 @@ def invoke(self, request: InvokeConnectionRequest): bearer_token = self.__vault_client.get_bearer_token(config.get("credentials")) connection_url = config.get("connection_url") - invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.logger) + invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: - invoke_connection_request.headers['x-skyflow-authorization'] = f"Bearer ${bearer_token}" + invoke_connection_request.headers['x-skyflow-authorization'] = bearer_token invoke_connection_request.headers['sky-metadata'] = json.dumps(get_metrics()) - log_info(self.logger, SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED, interface) + log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED, interface, self.__vault_client.get_logger()) try: response = session.send(invoke_connection_request) session.close() - invoke_connection_response = InvokeConnectionResponse() - return invoke_connection_response.parse_invoke_connection_response(response) - except: - raise SkyflowError(SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value, SkyflowMessages.ErrorCodes.SERVER_ERROR, logger = self.logger, logger_method=log_error) \ No newline at end of file + invoke_connection_response = parse_invoke_connection_response(response) + return invoke_connection_response + except Exception as e: + print(e) + raise SkyflowError(SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value, SkyflowMessages.ErrorCodes.SERVER_ERROR, logger = self.__vault_client.get_logger()) \ No newline at end of file diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 5c73804..921c46d 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -2,13 +2,14 @@ V1DetokenizePayload, V1TokenizeRecordRequest, V1TokenizePayload, QueryServiceExecuteQueryBody, \ RecordServiceBulkDeleteRecordBody, RecordServiceUpdateRecordBody, RecordServiceBatchOperationBody, V1BatchRecord, \ BatchRecordMethod -from skyflow.generated.rest.exceptions import BadRequestException -from skyflow.utils import log_info, SkyflowMessages, parse_insert_response, \ +from skyflow.generated.rest.exceptions import BadRequestException, UnauthorizedException +from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response +from skyflow.utils.logger import log_info from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \ validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request -from skyflow.vault.data import InsertRequest, UpdateRequest, DeleteRequest, GetRequest, QueryRequest, GetResponse +from skyflow.vault.data import InsertRequest, UpdateRequest, DeleteRequest, GetRequest, QueryRequest from skyflow.vault.tokens import DetokenizeRequest, TokenizeRequest class Vault: @@ -30,6 +31,7 @@ def __build_batch_field_records(self, values, tokens, table_name, return_tokens, table_name=table_name, method=BatchRecordMethod.POST, tokenization=return_tokens, + upsert=upsert ) if token is not None: batch_record.tokens = token @@ -48,7 +50,7 @@ def __build_insert_body(self, request: InsertRequest): body = RecordServiceBatchOperationBody( records=records_list, continue_on_error=request.continue_on_error, - byot=request.token_strict + byot=request.token_strict.value ) return body else: @@ -86,6 +88,8 @@ def insert(self, request: InsertRequest): except BadRequestException as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) def update(self, request: UpdateRequest): interface = SkyflowMessages.InterfaceName.UPDATE @@ -109,6 +113,8 @@ def update(self, request: UpdateRequest): return update_response except Exception as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) def delete(self, request: DeleteRequest): interface = SkyflowMessages.InterfaceName.DELETE.value @@ -129,6 +135,8 @@ def delete(self, request: DeleteRequest): return delete_response except Exception as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) def get(self, request: GetRequest): interface = SkyflowMessages.InterfaceName.GET.value @@ -156,6 +164,8 @@ def get(self, request: GetRequest): return get_response except Exception as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) def query(self, request: QueryRequest): interface = SkyflowMessages.InterfaceName.QUERY.value @@ -174,6 +184,8 @@ def query(self, request: QueryRequest): return query_response except Exception as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) def detokenize(self, request: DetokenizeRequest): interface = SkyflowMessages.InterfaceName.DETOKENIZE.value @@ -197,6 +209,8 @@ def detokenize(self, request: DetokenizeRequest): return detokenize_response except Exception as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) def tokenize(self, request: TokenizeRequest): validate_tokenize_request(self.__vault_client.get_logger(), request) @@ -217,3 +231,5 @@ def tokenize(self, request: TokenizeRequest): return tokenize_response except Exception as e: handle_exception(e, self.__vault_client.get_logger()) + except UnauthorizedException as e: + handle_exception(e, self.__vault_client.get_logger()) diff --git a/skyflow/vault/data/_get_request.py b/skyflow/vault/data/_get_request.py index 84a617c..308d21f 100644 --- a/skyflow/vault/data/_get_request.py +++ b/skyflow/vault/data/_get_request.py @@ -1,6 +1,3 @@ -from skyflow.utils.enums import Redaction, OrderBy - - class GetRequest: def __init__(self, table,