Skip to content

Commit

Permalink
Merge pull request #103 from skyflowapi/release/23.8.2
Browse files Browse the repository at this point in the history
SK-976/Release/23.8.2
  • Loading branch information
skyflow-vivek authored Sep 4, 2023
2 parents 4a69139 + cd70527 commit 37ebed7
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 3 deletions.
3 changes: 3 additions & 0 deletions skyflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ class InfoMessages(Enum):
INSERT_DATA_SUCCESS = "Data has been inserted successfully."
DETOKENIZE_SUCCESS = "Data has been detokenized successfully."
GET_BY_ID_SUCCESS = "Data fetched from ID successfully."
QUERY_SUCCESS = "Query executed successfully."
BEARER_TOKEN_RECEIVED = "tokenProvider returned token successfully."
INSERT_TRIGGERED = "Insert method triggered."
DETOKENIZE_TRIGGERED = "Detokenize method triggered."
GET_BY_ID_TRIGGERED = "Get by ID triggered."
INVOKE_CONNECTION_TRIGGERED = "Invoke connection triggered."
QUERY_TRIGGERED = "Query method triggered."
GENERATE_BEARER_TOKEN_TRIGGERED = "Generate bearer token triggered"
GENERATE_BEARER_TOKEN_SUCCESS = "Generate bearer token returned successfully"
IS_TOKEN_VALID_TRIGGERED = "isTokenValid() triggered"
Expand All @@ -87,6 +89,7 @@ class InterfaceName(Enum):
GET = "client.get"
UPDATE = "client.update"
INVOKE_CONNECTION = "client.invoke_connection"
QUERY = "client.query"
GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token"

IS_TOKEN_VALID = "service_account.isTokenValid"
Expand Down
4 changes: 4 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class SkyflowErrorMessages(Enum):
INVALID_UPSERT_COLUMN_TYPE = "upsert object column key has value of type %s, expected string"
EMPTY_UPSERT_OPTION_TABLE = "upsert object table value is empty string at index %s, expected non-empty string"
EMPTY_UPSERT_OPTION_COLUMN = "upsert object column value is empty string at index %s, expected non-empty string"
QUERY_KEY_ERROR = "Query key is missing from payload"
INVALID_QUERY_TYPE = "Query key has value of type %s, expected string"
EMPTY_QUERY = "Query key cannot be empty"
INVALID_QUERY_COMMAND = "only SELECT commands are supported, %s command was passed instead"
SERVER_ERROR = "Server returned errors, check SkyflowError.data for more"

BATCH_INSERT_PARTIAL_SUCCESS = "Insert Operation is partially successful"
Expand Down
27 changes: 24 additions & 3 deletions skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from ._delete import deleteProcessResponse
from ._insert import getInsertRequestBody, processResponse, convertResponse
from ._update import sendUpdateRequests, createUpdateResponseBody
from ._config import Configuration, DeleteOptions
from ._config import DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions
from ._config import Configuration, DeleteOptions, DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions
from ._connection import createRequest
from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from ._get_by_id import sendGetByIdRequests, createGetResponseBody
Expand All @@ -18,7 +17,7 @@
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, log_error, InfoMessages, InterfaceName, getMetrics
from ._token import tokenProviderWrapper

from ._query import getQueryRequestBody, getQueryResponse

class Client:
def __init__(self, config: Configuration):
Expand Down Expand Up @@ -147,6 +146,28 @@ def invoke_connection(self, config: ConnectionConfig):
session.close()
return processResponse(response, interface=interface)

def query(self, queryInput, options: QueryOptions = QueryOptions()):
interface = InterfaceName.QUERY.value
log_info(InfoMessages.QUERY_TRIGGERED.value, interface=interface)

self._checkConfig(interface)

jsonBody = getQueryRequestBody(queryInput, options)
requestURL = self._get_complete_vault_url() + "/query"
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.storedToken,
"sky-metadata": json.dumps(getMetrics())
}

response = requests.post(requestURL, data=jsonBody, headers=headers)
result = getQueryResponse(response)

log_info(InfoMessages.QUERY_SUCCESS.value, interface)
return result

def _checkConfig(self, interface):
'''
Performs basic check on the given client config
Expand Down
4 changes: 4 additions & 0 deletions skyflow/vault/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, tokens: bool=True):
class DeleteOptions:
def __init__(self, tokens: bool=False):
self.tokens = tokens

class QueryOptions:
def __init__(self):
pass

class DetokenizeOptions:
def __init__(self, continueOnError: bool=True):
Expand Down
62 changes: 62 additions & 0 deletions skyflow/vault/_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
'''
Copyright (c) 2022 Skyflow, Inc.
'''
import json

import requests
from ._config import QueryOptions
from requests.models import HTTPError
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import InterfaceName

interface = InterfaceName.QUERY.value


def getQueryRequestBody(data, options):
try:
query = data["query"]
except KeyError:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.QUERY_KEY_ERROR, interface=interface)

if not isinstance(query, str):
queryType = str(type(query))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_QUERY_TYPE.value % queryType, interface=interface)

if not query.strip():
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,SkyflowErrorMessages.EMPTY_QUERY.value, interface=interface)

requestBody = {"query": query}
try:
jsonBody = json.dumps(requestBody)
except Exception as e:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_JSON.value % (
'query payload'), interface=interface)

return jsonBody

def getQueryResponse(response: requests.Response, interface=interface):
statusCode = response.status_code
content = response.content.decode('utf-8')
try:
response.raise_for_status()
try:
return json.loads(content)
except:
raise SkyflowError(
statusCode, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % content, interface=interface)
except HTTPError:
message = SkyflowErrorMessages.API_ERROR.value % statusCode
if response != None and response.content != None:
try:
errorResponse = json.loads(content)
if 'error' in errorResponse and type(errorResponse['error']) == type({}) and 'message' in errorResponse['error']:
message = errorResponse['error']['message']
except:
message = SkyflowErrorMessages.RESPONSE_NOT_JSON.value % content
raise SkyflowError(SkyflowErrorCodes.INVALID_INDEX, message, interface=interface)
error = {"error": {}}
if 'x-request-id' in response.headers:
message += ' - request id: ' + response.headers['x-request-id']
error['error'].update({"code": statusCode, "description": message})
raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.SERVER_ERROR.value, error, interface=interface)
175 changes: 175 additions & 0 deletions tests/vault/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
'''
Copyright (c) 2022 Skyflow, Inc.
'''
import json
import unittest
import os
from unittest import mock
import requests
from requests.models import Response
from skyflow.vault._query import getQueryRequestBody, getQueryResponse
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.vault._client import Client
from skyflow.vault._config import Configuration, QueryOptions

class TestQuery(unittest.TestCase):

def setUp(self) -> None:
self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/')
query = "SELECT * FROM pii_fields WHERE skyflow_id='3ea3861-x107-40w8-la98-106sp08ea83f'"
self.data = {"query": query}
self.mockRequest = {"records": [query]}

self.mockResponse = {
"records": [
{
"fields": {
"card_number": "XXXXXXXXXXXX1111",
"card_pin": "*REDACTED*",
"cvv": "",
"expiration_date": "*REDACTED*",
"expiration_month": "*REDACTED*",
"expiration_year": "*REDACTED*",
"name": "a***te",
"skyflow_id": "3ea3861-x107-40w8-la98-106sp08ea83f",
"ssn": "XXX-XX-6789",
"zip_code": None
},
"tokens": None
}
]
}

self.requestId = '5d5d7e21-c789-9fcc-ba31-2a279d3a28ef'

self.mockApiError = {
"error": {
"grpc_code": 13,
"http_code": 500,
"message": "ERROR (internal_error): Could not find Notebook Mapping Notebook Name was not found",
"http_status": "Internal Server Error",
"details": []
}
}

self.mockFailResponse = {
"error": {
"code": 500,
"description": "ERROR (internal_error): Could not find Notebook Mapping Notebook Name was not found - request id: 5d5d7e21-c789-9fcc-ba31-2a279d3a28ef"
}
}

self.queryOptions = QueryOptions()

return super().setUp()

def getDataPath(self, file):
return self.dataPath + file + '.json'

def testGetQueryRequestBodyWithValidBody(self):
body = json.loads(getQueryRequestBody(self.data, self.queryOptions))
expectedOutput = {
"query": "SELECT * FROM pii_fields WHERE skyflow_id='3ea3861-x107-40w8-la98-106sp08ea83f'",
}
self.assertEqual(body, expectedOutput)

def testGetQueryRequestBodyNoQuery(self):
invalidData = {"invalidKey": self.data["query"]}
try:
getQueryRequestBody(invalidData, self.queryOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.QUERY_KEY_ERROR.value)

def testGetQueryRequestBodyInvalidType(self):
invalidData = {"query": ['SELECT * FROM table_name']}
try:
getQueryRequestBody(invalidData, self.queryOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.INVALID_QUERY_TYPE.value % (str(type(invalidData["query"]))))

def testGetQueryRequestBodyEmptyBody(self):
invalidData = {"query": ''}
try:
getQueryRequestBody(invalidData, self.queryOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.EMPTY_QUERY.value)

def testGetQueryValidResponse(self):
response = Response()
response.status_code = 200
response._content = b'{"key": "value"}'
try:
responseDict = getQueryResponse(response)
self.assertDictEqual(responseDict, {'key': 'value'})
except SkyflowError as e:
self.fail()

def testClientInit(self):
config = Configuration(
'vaultid', 'https://skyflow.com', lambda: 'test')
client = Client(config)
self.assertEqual(client.vaultURL, 'https://skyflow.com')
self.assertEqual(client.vaultID, 'vaultid')
self.assertEqual(client.tokenProvider(), 'test')

def testGetQueryResponseSuccessInvalidJson(self):
invalid_response = Response()
invalid_response.status_code = 200
invalid_response._content = b'invalid-json'
try:
getQueryResponse(invalid_response)
self.fail('not failing on invalid json')
except SkyflowError as se:
self.assertEqual(se.code, 200)
self.assertEqual(
se.message, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % 'invalid-json')

def testGetQueryResponseFailInvalidJson(self):
invalid_response = mock.Mock(
spec=requests.Response,
status_code=404,
content=b'error'
)
invalid_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Not found")
try:
getQueryResponse(invalid_response)
self.fail('Not failing on invalid error json')
except SkyflowError as se:
self.assertEqual(se.code, 404)
self.assertEqual(
se.message, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % 'error')

def testGetQueryResponseFail(self):
response = mock.Mock(
spec=requests.Response,
status_code=500,
content=json.dumps(self.mockApiError).encode('utf-8')
)
response.headers = {"x-request-id": self.requestId}
response.raise_for_status.side_effect = requests.exceptions.HTTPError("Server Error")
try:
getQueryResponse(response)
self.fail('not throwing exception when error code is 500')
except SkyflowError as e:
self.assertEqual(e.code, 500)
self.assertEqual(e.message, SkyflowErrorMessages.SERVER_ERROR.value)
self.assertDictEqual(e.data, self.mockFailResponse)

def testQueryInvalidToken(self):
config = Configuration('id', 'url', lambda: 'invalid-token')
try:
Client(config).query({'query': 'SELECT * FROM table_name'})
self.fail()
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.TOKEN_PROVIDER_INVALID_TOKEN.value)

0 comments on commit 37ebed7

Please sign in to comment.