Skip to content

Commit

Permalink
SK-1731: Updated error, info logs and Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Nov 22, 2024
1 parent 847973b commit c0b03d5
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 165 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ DateTime~=5.5
PyJWT~=2.9.0
requests~=2.32.3
coverage
cryptography
cryptography
python-dotenv~=1.0.1
271 changes: 129 additions & 142 deletions skyflow/utils/_skyflow_messages.py

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import urllib.parse
from dotenv import load_dotenv
from requests.sessions import PreparedRequest
from requests.models import HTTPError
import requests
Expand All @@ -22,16 +23,19 @@
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env")
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if config_level_creds:
return config_level_creds
return config_level_creds, False
if common_skyflow_creds:
return common_skyflow_creds
return common_skyflow_creds, False
if env_skyflow_credentials:
env_skyflow_credentials.strip()
try:
env_creds = json.loads(env_skyflow_credentials.replace('\n', '\\n'))
return env_creds
return env_creds, True
except json.JSONDecodeError:
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
else:
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def validate_update_vault_config(logger, config):
return True

def validate_connection_config(logger, config):
log_info(SkyflowMessages.Info.VALIDATE_CONNECTION_CONFIG.value, logger)
log_info(SkyflowMessages.Info.VALIDATING_CONNECTION_CONFIG.value, logger)
validate_keys(logger, config, valid_connection_config_keys)

validate_required_field(
Expand Down
19 changes: 13 additions & 6 deletions skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from skyflow.generated.rest import Configuration, RecordsApi, ApiClient, TokensApi, QueryApi
from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired
from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages
Expand All @@ -23,8 +24,8 @@ def set_logger(self, log_level, logger):
self.__logger = logger

def initialize_client_configuration(self):
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials)
credentials, env_creds = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials, env_creds)
vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
self.__config.get("vault_id"),
Expand All @@ -47,7 +48,7 @@ def get_query_api(self):
def get_vault_id(self):
return self.__config.get("vault_id")

def get_bearer_token(self, credentials):
def get_bearer_token(self, credentials, env_creds):
if 'api_key' in credentials:
return credentials.get('api_key')
elif 'token' in credentials:
Expand All @@ -58,10 +59,16 @@ def get_bearer_token(self, credentials):
"ctx": self.__config.get("ctx")
}

log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_TRIGGERED, self.__logger)

if self.__bearer_token is None or self.__is_config_updated:
if 'path' in credentials:
if env_creds:
log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value,
self.__logger)
self.__bearer_token, _ = generate_bearer_token_from_creds(
json.dumps(credentials),
options,
self.__logger
)
elif 'path' in credentials:
path = credentials.get("path")
self.__bearer_token, _ = generate_bearer_token(
path,
Expand Down
2 changes: 0 additions & 2 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from venv import logger

from skyflow.generated.rest import V1FieldRecords, RecordServiceInsertRecordBody, V1DetokenizeRecordRequest, \
V1DetokenizePayload, V1TokenizeRecordRequest, V1TokenizePayload, QueryServiceExecuteQueryBody, \
RecordServiceBulkDeleteRecordBody, RecordServiceUpdateRecordBody, RecordServiceBatchOperationBody, V1BatchRecord, \
Expand Down
11 changes: 3 additions & 8 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,21 @@


class TestUtils(unittest.TestCase):
# def test_get_credentials_empty_credentials(self):
# with self.assertRaises(SkyflowError) as context:
# get_credentials()
# self.assertIn(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": VALID_CREDENTIALS_STRING})
def test_get_credentials_env_variable(self):
creds = get_credentials()
creds, _ = get_credentials()
VALID_CREDENTIALS_STRING.strip()
print(type(creds))
self.assertEqual(creds, json.loads(VALID_CREDENTIALS_STRING.replace('\n', '\\n')))

def test_get_credentials_with_config_level_creds(self):
test_creds = {"authToken": "test_token"}
creds = get_credentials(config_level_creds=test_creds)
creds, _ = get_credentials(config_level_creds=test_creds)
self.assertEqual(creds, test_creds)

def test_get_credentials_with_common_creds(self):
test_creds = {"authToken": "test_token"}
creds = get_credentials(common_skyflow_creds=test_creds)
creds, _ = get_credentials(common_skyflow_creds=test_creds)
self.assertEqual(creds, test_creds)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": INVALID_JSON_FORMAT})
Expand Down
4 changes: 2 additions & 2 deletions tests/vault/client/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_set_logger(self):
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
def test_initialize_client_configuration(self, mock_init_api_client, mock_config, mock_get_vault_url,
mock_get_credentials):
mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY
mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY, False)
mock_get_vault_url.return_value = "https://test-vault-url.com"

self.vault_client.initialize_client_configuration()
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_get_vault_id(self):
@patch("skyflow.vault.client.client.log_info")
def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token,
mock_generate_bearer_token_from_creds):
token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY)
token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY, False)
self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"])

def test_update_config(self):
Expand Down

0 comments on commit c0b03d5

Please sign in to comment.