Skip to content

Commit

Permalink
SK-1731: Fixed unit test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Nov 25, 2024
1 parent c0b03d5 commit b00b62f
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 162 deletions.
260 changes: 131 additions & 129 deletions skyflow/utils/_skyflow_messages.py

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if config_level_creds:
return config_level_creds, False
return config_level_creds
if common_skyflow_creds:
return common_skyflow_creds, False
return common_skyflow_creds
if env_skyflow_credentials:
env_skyflow_credentials.strip()
try:
env_creds = json.loads(env_skyflow_credentials.replace('\n', '\\n'))
return env_creds, True
env_creds = env_skyflow_credentials.replace('\n', '\\n')
return {
'credentials_string': env_creds
}
except json.JSONDecodeError:
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
else:
Expand Down
19 changes: 6 additions & 13 deletions skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def set_logger(self, log_level, logger):
self.__logger = logger

def initialize_client_configuration(self):
credentials, env_creds = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials, env_creds)
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials)
vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
self.__config.get("vault_id"),
Expand All @@ -48,7 +48,7 @@ def get_query_api(self):
def get_vault_id(self):
return self.__config.get("vault_id")

def get_bearer_token(self, credentials, env_creds):
def get_bearer_token(self, credentials):
if 'api_key' in credentials:
return credentials.get('api_key')
elif 'token' in credentials:
Expand All @@ -60,15 +60,7 @@ def get_bearer_token(self, credentials, env_creds):
}

if self.__bearer_token is None or self.__is_config_updated:
if env_creds:
log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value,
self.__logger)
self.__bearer_token, _ = generate_bearer_token_from_creds(
json.dumps(credentials),
options,
self.__logger
)
elif 'path' in credentials:
if 'path' in credentials:
path = credentials.get("path")
self.__bearer_token, _ = generate_bearer_token(
path,
Expand All @@ -84,12 +76,13 @@ def get_bearer_token(self, credentials, env_creds):
self.__logger
)
self.__is_config_updated = False
else:
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)

if is_expired(self.__bearer_token):
self.__is_config_updated = True
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)
return self.__bearer_token

def update_config(self, config):
Expand Down
4 changes: 3 additions & 1 deletion tests/constants/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
EMPTY_URL = ""
SCOPES_LIST = ["admin", "user", "viewer"]
FORMATTED_SCOPES = "role:admin role:user role:viewer"
INVALID_JSON_FORMAT = '{"invalid": json}'
INVALID_JSON_FORMAT = '[{"invalid": "json"}]'

TEST_ERROR_MESSAGE = "Test error message."

Expand All @@ -90,6 +90,8 @@
CREDENTIALS_WITH_PATH = {"path": "/path/to/creds.json"}
CREDENTIALS_WITH_STRING = {"credentials_string": "dummy_credentials_string"}

VALID_ENV_CREDENTIALS = {"clientID":"CLIENT_ID","clientName":"test_V2","tokenURI":"TOKEN_URI","keyID":"KEY_ID","privateKey":"PRIVATE_KEY","keyValidAfterTime":"2024-10-21T18:06:26.000Z","keyValidBeforeTime":"2025-10-21T18:06:26.000Z","keyAlgorithm":"KEY_ALG_RSA_2048"}


# connection controller constants

Expand Down
21 changes: 8 additions & 13 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,28 @@
from skyflow.vault.connection import InvokeConnectionResponse
from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse
from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse
from tests.constants.test_constants import VALID_CREDENTIALS_STRING, INVALID_JSON_FORMAT, TEST_ERROR_MESSAGE
from tests.constants.test_constants import VALID_CREDENTIALS_STRING, INVALID_JSON_FORMAT, TEST_ERROR_MESSAGE, \
VALID_ENV_CREDENTIALS


class TestUtils(unittest.TestCase):

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": VALID_CREDENTIALS_STRING})
@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)})
def test_get_credentials_env_variable(self):
creds, _ = get_credentials()
VALID_CREDENTIALS_STRING.strip()
self.assertEqual(creds, json.loads(VALID_CREDENTIALS_STRING.replace('\n', '\\n')))
credentials = get_credentials()
credentials_string = credentials.get('credentials_string')
self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n'))

def test_get_credentials_with_config_level_creds(self):
test_creds = {"authToken": "test_token"}
creds, _ = get_credentials(config_level_creds=test_creds)
creds = get_credentials(config_level_creds=test_creds)
self.assertEqual(creds, test_creds)

def test_get_credentials_with_common_creds(self):
test_creds = {"authToken": "test_token"}
creds, _ = get_credentials(common_skyflow_creds=test_creds)
creds = get_credentials(common_skyflow_creds=test_creds)
self.assertEqual(creds, test_creds)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": INVALID_JSON_FORMAT})
def test_get_credentials_invalid_json_format(self):
with self.assertRaises(SkyflowError) as context:
get_credentials()
self.assertIn(context.exception.message, SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value)

def test_get_vault_url_valid(self):
valid_cluster_id = "testCluster"
valid_env = Env.DEV
Expand Down
4 changes: 2 additions & 2 deletions tests/vault/client/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_set_logger(self):
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
def test_initialize_client_configuration(self, mock_init_api_client, mock_config, mock_get_vault_url,
mock_get_credentials):
mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY, False)
mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY)
mock_get_vault_url.return_value = "https://test-vault-url.com"

self.vault_client.initialize_client_configuration()
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_get_vault_id(self):
@patch("skyflow.vault.client.client.log_info")
def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token,
mock_generate_bearer_token_from_creds):
token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY, False)
token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY)
self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"])

def test_update_config(self):
Expand Down

0 comments on commit b00b62f

Please sign in to comment.