diff --git a/skyflow/_utils.py b/skyflow/_utils.py index 301e013..83bf54a 100644 --- a/skyflow/_utils.py +++ b/skyflow/_utils.py @@ -60,11 +60,13 @@ class InfoMessages(Enum): INSERT_DATA_SUCCESS = "Data has been inserted successfully." DETOKENIZE_SUCCESS = "Data has been detokenized successfully." GET_BY_ID_SUCCESS = "Data fetched from ID successfully." + QUERY_SUCCESS = "Query executed successfully." BEARER_TOKEN_RECEIVED = "tokenProvider returned token successfully." INSERT_TRIGGERED = "Insert method triggered." DETOKENIZE_TRIGGERED = "Detokenize method triggered." GET_BY_ID_TRIGGERED = "Get by ID triggered." INVOKE_CONNECTION_TRIGGERED = "Invoke connection triggered." + QUERY_TRIGGERED = "Query method triggered." GENERATE_BEARER_TOKEN_TRIGGERED = "Generate bearer token triggered" GENERATE_BEARER_TOKEN_SUCCESS = "Generate bearer token returned successfully" IS_TOKEN_VALID_TRIGGERED = "isTokenValid() triggered" @@ -87,6 +89,7 @@ class InterfaceName(Enum): GET = "client.get" UPDATE = "client.update" INVOKE_CONNECTION = "client.invoke_connection" + QUERY = "client.query" GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token" IS_TOKEN_VALID = "service_account.isTokenValid" diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index 410383f..efd4fc4 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -85,6 +85,10 @@ class SkyflowErrorMessages(Enum): INVALID_UPSERT_COLUMN_TYPE = "upsert object column key has value of type %s, expected string" EMPTY_UPSERT_OPTION_TABLE = "upsert object table value is empty string at index %s, expected non-empty string" EMPTY_UPSERT_OPTION_COLUMN = "upsert object column value is empty string at index %s, expected non-empty string" + QUERY_KEY_ERROR = "Query key is missing from payload" + INVALID_QUERY_TYPE = "Query key has value of type %s, expected string" + EMPTY_QUERY = "Query key cannot be empty" + INVALID_QUERY_COMMAND = "only SELECT commands are supported, %s command was passed instead" SERVER_ERROR = "Server returned errors, check SkyflowError.data for more" BATCH_INSERT_PARTIAL_SUCCESS = "Insert Operation is partially successful" diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index ba6974e..57a97e8 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -8,8 +8,7 @@ from ._delete import deleteProcessResponse from ._insert import getInsertRequestBody, processResponse, convertResponse from ._update import sendUpdateRequests, createUpdateResponseBody -from ._config import Configuration, DeleteOptions -from ._config import DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions +from ._config import Configuration, DeleteOptions, DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions from ._connection import createRequest from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody from ._get_by_id import sendGetByIdRequests, createGetResponseBody @@ -18,7 +17,7 @@ from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow._utils import log_info, log_error, InfoMessages, InterfaceName, getMetrics from ._token import tokenProviderWrapper - +from ._query import getQueryRequestBody, getQueryResponse class Client: def __init__(self, config: Configuration): @@ -147,6 +146,28 @@ def invoke_connection(self, config: ConnectionConfig): session.close() return processResponse(response, interface=interface) + def query(self, queryInput, options: QueryOptions = QueryOptions()): + interface = InterfaceName.QUERY.value + log_info(InfoMessages.QUERY_TRIGGERED.value, interface=interface) + + self._checkConfig(interface) + + jsonBody = getQueryRequestBody(queryInput, options) + requestURL = self._get_complete_vault_url() + "/query" + self.storedToken = tokenProviderWrapper( + self.storedToken, self.tokenProvider, interface) + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.storedToken, + "sky-metadata": json.dumps(getMetrics()) + } + + response = requests.post(requestURL, data=jsonBody, headers=headers) + result = getQueryResponse(response) + + log_info(InfoMessages.QUERY_SUCCESS.value, interface) + return result + def _checkConfig(self, interface): ''' Performs basic check on the given client config diff --git a/skyflow/vault/_config.py b/skyflow/vault/_config.py index 2de331a..d7bc2f4 100644 --- a/skyflow/vault/_config.py +++ b/skyflow/vault/_config.py @@ -42,6 +42,10 @@ def __init__(self, tokens: bool=True): class DeleteOptions: def __init__(self, tokens: bool=False): self.tokens = tokens + +class QueryOptions: + def __init__(self): + pass class DetokenizeOptions: def __init__(self, continueOnError: bool=True): diff --git a/skyflow/vault/_query.py b/skyflow/vault/_query.py new file mode 100644 index 0000000..373264f --- /dev/null +++ b/skyflow/vault/_query.py @@ -0,0 +1,62 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import json + +import requests +from ._config import QueryOptions +from requests.models import HTTPError +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from skyflow._utils import InterfaceName + +interface = InterfaceName.QUERY.value + + +def getQueryRequestBody(data, options): + try: + query = data["query"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.QUERY_KEY_ERROR, interface=interface) + + if not isinstance(query, str): + queryType = str(type(query)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_QUERY_TYPE.value % queryType, interface=interface) + + if not query.strip(): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,SkyflowErrorMessages.EMPTY_QUERY.value, interface=interface) + + requestBody = {"query": query} + try: + jsonBody = json.dumps(requestBody) + except Exception as e: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_JSON.value % ( + 'query payload'), interface=interface) + + return jsonBody + +def getQueryResponse(response: requests.Response, interface=interface): + statusCode = response.status_code + content = response.content.decode('utf-8') + try: + response.raise_for_status() + try: + return json.loads(content) + except: + raise SkyflowError( + statusCode, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % content, interface=interface) + except HTTPError: + message = SkyflowErrorMessages.API_ERROR.value % statusCode + if response != None and response.content != None: + try: + errorResponse = json.loads(content) + if 'error' in errorResponse and type(errorResponse['error']) == type({}) and 'message' in errorResponse['error']: + message = errorResponse['error']['message'] + except: + message = SkyflowErrorMessages.RESPONSE_NOT_JSON.value % content + raise SkyflowError(SkyflowErrorCodes.INVALID_INDEX, message, interface=interface) + error = {"error": {}} + if 'x-request-id' in response.headers: + message += ' - request id: ' + response.headers['x-request-id'] + error['error'].update({"code": statusCode, "description": message}) + raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.SERVER_ERROR.value, error, interface=interface) diff --git a/tests/vault/test_query.py b/tests/vault/test_query.py new file mode 100644 index 0000000..63f9079 --- /dev/null +++ b/tests/vault/test_query.py @@ -0,0 +1,175 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import json +import unittest +import os +from unittest import mock +import requests +from requests.models import Response +from skyflow.vault._query import getQueryRequestBody, getQueryResponse +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from skyflow.vault._client import Client +from skyflow.vault._config import Configuration, QueryOptions + +class TestQuery(unittest.TestCase): + + def setUp(self) -> None: + self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/') + query = "SELECT * FROM pii_fields WHERE skyflow_id='3ea3861-x107-40w8-la98-106sp08ea83f'" + self.data = {"query": query} + self.mockRequest = {"records": [query]} + + self.mockResponse = { + "records": [ + { + "fields": { + "card_number": "XXXXXXXXXXXX1111", + "card_pin": "*REDACTED*", + "cvv": "", + "expiration_date": "*REDACTED*", + "expiration_month": "*REDACTED*", + "expiration_year": "*REDACTED*", + "name": "a***te", + "skyflow_id": "3ea3861-x107-40w8-la98-106sp08ea83f", + "ssn": "XXX-XX-6789", + "zip_code": None + }, + "tokens": None + } + ] + } + + self.requestId = '5d5d7e21-c789-9fcc-ba31-2a279d3a28ef' + + self.mockApiError = { + "error": { + "grpc_code": 13, + "http_code": 500, + "message": "ERROR (internal_error): Could not find Notebook Mapping Notebook Name was not found", + "http_status": "Internal Server Error", + "details": [] + } + } + + self.mockFailResponse = { + "error": { + "code": 500, + "description": "ERROR (internal_error): Could not find Notebook Mapping Notebook Name was not found - request id: 5d5d7e21-c789-9fcc-ba31-2a279d3a28ef" + } + } + + self.queryOptions = QueryOptions() + + return super().setUp() + + def getDataPath(self, file): + return self.dataPath + file + '.json' + + def testGetQueryRequestBodyWithValidBody(self): + body = json.loads(getQueryRequestBody(self.data, self.queryOptions)) + expectedOutput = { + "query": "SELECT * FROM pii_fields WHERE skyflow_id='3ea3861-x107-40w8-la98-106sp08ea83f'", + } + self.assertEqual(body, expectedOutput) + + def testGetQueryRequestBodyNoQuery(self): + invalidData = {"invalidKey": self.data["query"]} + try: + getQueryRequestBody(invalidData, self.queryOptions) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.QUERY_KEY_ERROR.value) + + def testGetQueryRequestBodyInvalidType(self): + invalidData = {"query": ['SELECT * FROM table_name']} + try: + getQueryRequestBody(invalidData, self.queryOptions) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_QUERY_TYPE.value % (str(type(invalidData["query"])))) + + def testGetQueryRequestBodyEmptyBody(self): + invalidData = {"query": ''} + try: + getQueryRequestBody(invalidData, self.queryOptions) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.EMPTY_QUERY.value) + + def testGetQueryValidResponse(self): + response = Response() + response.status_code = 200 + response._content = b'{"key": "value"}' + try: + responseDict = getQueryResponse(response) + self.assertDictEqual(responseDict, {'key': 'value'}) + except SkyflowError as e: + self.fail() + + def testClientInit(self): + config = Configuration( + 'vaultid', 'https://skyflow.com', lambda: 'test') + client = Client(config) + self.assertEqual(client.vaultURL, 'https://skyflow.com') + self.assertEqual(client.vaultID, 'vaultid') + self.assertEqual(client.tokenProvider(), 'test') + + def testGetQueryResponseSuccessInvalidJson(self): + invalid_response = Response() + invalid_response.status_code = 200 + invalid_response._content = b'invalid-json' + try: + getQueryResponse(invalid_response) + self.fail('not failing on invalid json') + except SkyflowError as se: + self.assertEqual(se.code, 200) + self.assertEqual( + se.message, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % 'invalid-json') + + def testGetQueryResponseFailInvalidJson(self): + invalid_response = mock.Mock( + spec=requests.Response, + status_code=404, + content=b'error' + ) + invalid_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Not found") + try: + getQueryResponse(invalid_response) + self.fail('Not failing on invalid error json') + except SkyflowError as se: + self.assertEqual(se.code, 404) + self.assertEqual( + se.message, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % 'error') + + def testGetQueryResponseFail(self): + response = mock.Mock( + spec=requests.Response, + status_code=500, + content=json.dumps(self.mockApiError).encode('utf-8') + ) + response.headers = {"x-request-id": self.requestId} + response.raise_for_status.side_effect = requests.exceptions.HTTPError("Server Error") + try: + getQueryResponse(response) + self.fail('not throwing exception when error code is 500') + except SkyflowError as e: + self.assertEqual(e.code, 500) + self.assertEqual(e.message, SkyflowErrorMessages.SERVER_ERROR.value) + self.assertDictEqual(e.data, self.mockFailResponse) + + def testQueryInvalidToken(self): + config = Configuration('id', 'url', lambda: 'invalid-token') + try: + Client(config).query({'query': 'SELECT * FROM table_name'}) + self.fail() + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.TOKEN_PROVIDER_INVALID_TOKEN.value)