Skip to content

Commit

Permalink
SK-1731: Updated error, info logs and Fixes (#139)
Browse files Browse the repository at this point in the history
* SK-1731: Updated error, info logs and Fixes
  • Loading branch information
saileshwar-skyflow authored Nov 25, 2024
1 parent 847973b commit 7d27637
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 170 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ DateTime~=5.5
PyJWT~=2.9.0
requests~=2.32.3
coverage
cryptography
cryptography
python-dotenv~=1.0.1
273 changes: 131 additions & 142 deletions skyflow/utils/_skyflow_messages.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import urllib.parse
from dotenv import load_dotenv
from requests.sessions import PreparedRequest
from requests.models import HTTPError
import requests
Expand All @@ -22,6 +23,9 @@
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env")
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if config_level_creds:
return config_level_creds
Expand All @@ -30,8 +34,10 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg
if env_skyflow_credentials:
env_skyflow_credentials.strip()
try:
env_creds = json.loads(env_skyflow_credentials.replace('\n', '\\n'))
return env_creds
env_creds = env_skyflow_credentials.replace('\n', '\\n')
return {
'credentials_string': env_creds
}
except json.JSONDecodeError:
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
else:
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def validate_update_vault_config(logger, config):
return True

def validate_connection_config(logger, config):
log_info(SkyflowMessages.Info.VALIDATE_CONNECTION_CONFIG.value, logger)
log_info(SkyflowMessages.Info.VALIDATING_CONNECTION_CONFIG.value, logger)
validate_keys(logger, config, valid_connection_config_keys)

validate_required_field(
Expand Down
8 changes: 4 additions & 4 deletions skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from skyflow.generated.rest import Configuration, RecordsApi, ApiClient, TokensApi, QueryApi
from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired
from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages
Expand All @@ -23,7 +24,7 @@ def set_logger(self, log_level, logger):
self.__logger = logger

def initialize_client_configuration(self):
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials)
vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
Expand Down Expand Up @@ -58,8 +59,6 @@ def get_bearer_token(self, credentials):
"ctx": self.__config.get("ctx")
}

log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_TRIGGERED, self.__logger)

if self.__bearer_token is None or self.__is_config_updated:
if 'path' in credentials:
path = credentials.get("path")
Expand All @@ -77,12 +76,13 @@ def get_bearer_token(self, credentials):
self.__logger
)
self.__is_config_updated = False
else:
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)

if is_expired(self.__bearer_token):
self.__is_config_updated = True
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)
return self.__bearer_token

def update_config(self, config):
Expand Down
2 changes: 0 additions & 2 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from venv import logger

from skyflow.generated.rest import V1FieldRecords, RecordServiceInsertRecordBody, V1DetokenizeRecordRequest, \
V1DetokenizePayload, V1TokenizeRecordRequest, V1TokenizePayload, QueryServiceExecuteQueryBody, \
RecordServiceBulkDeleteRecordBody, RecordServiceUpdateRecordBody, RecordServiceBatchOperationBody, V1BatchRecord, \
Expand Down
4 changes: 3 additions & 1 deletion tests/constants/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
EMPTY_URL = ""
SCOPES_LIST = ["admin", "user", "viewer"]
FORMATTED_SCOPES = "role:admin role:user role:viewer"
INVALID_JSON_FORMAT = '{"invalid": json}'
INVALID_JSON_FORMAT = '[{"invalid": "json"}]'

TEST_ERROR_MESSAGE = "Test error message."

Expand All @@ -90,6 +90,8 @@
CREDENTIALS_WITH_PATH = {"path": "/path/to/creds.json"}
CREDENTIALS_WITH_STRING = {"credentials_string": "dummy_credentials_string"}

VALID_ENV_CREDENTIALS = {"clientID":"CLIENT_ID","clientName":"test_V2","tokenURI":"TOKEN_URI","keyID":"KEY_ID","privateKey":"PRIVATE_KEY","keyValidAfterTime":"2024-10-21T18:06:26.000Z","keyValidBeforeTime":"2025-10-21T18:06:26.000Z","keyAlgorithm":"KEY_ALG_RSA_2048"}


# connection controller constants

Expand Down
22 changes: 6 additions & 16 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@
from skyflow.vault.connection import InvokeConnectionResponse
from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse
from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse
from tests.constants.test_constants import VALID_CREDENTIALS_STRING, INVALID_JSON_FORMAT, TEST_ERROR_MESSAGE
from tests.constants.test_constants import VALID_CREDENTIALS_STRING, INVALID_JSON_FORMAT, TEST_ERROR_MESSAGE, \
VALID_ENV_CREDENTIALS


class TestUtils(unittest.TestCase):
# def test_get_credentials_empty_credentials(self):
# with self.assertRaises(SkyflowError) as context:
# get_credentials()
# self.assertIn(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": VALID_CREDENTIALS_STRING})
@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)})
def test_get_credentials_env_variable(self):
creds = get_credentials()
VALID_CREDENTIALS_STRING.strip()
print(type(creds))
self.assertEqual(creds, json.loads(VALID_CREDENTIALS_STRING.replace('\n', '\\n')))
credentials = get_credentials()
credentials_string = credentials.get('credentials_string')
self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n'))

def test_get_credentials_with_config_level_creds(self):
test_creds = {"authToken": "test_token"}
Expand All @@ -41,12 +37,6 @@ def test_get_credentials_with_common_creds(self):
creds = get_credentials(common_skyflow_creds=test_creds)
self.assertEqual(creds, test_creds)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": INVALID_JSON_FORMAT})
def test_get_credentials_invalid_json_format(self):
with self.assertRaises(SkyflowError) as context:
get_credentials()
self.assertIn(context.exception.message, SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value)

def test_get_vault_url_valid(self):
valid_cluster_id = "testCluster"
valid_env = Env.DEV
Expand Down
2 changes: 1 addition & 1 deletion tests/vault/client/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_set_logger(self):
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
def test_initialize_client_configuration(self, mock_init_api_client, mock_config, mock_get_vault_url,
mock_get_credentials):
mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY
mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY)
mock_get_vault_url.return_value = "https://test-vault-url.com"

self.vault_client.initialize_client_configuration()
Expand Down

0 comments on commit 7d27637

Please sign in to comment.