Skip to content

Commit

Permalink
SK-1308 Fix bug in Get method in Python SDK
Browse files Browse the repository at this point in the history
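- Fix bug when calling the Get method with the `tokens` option set to true.
- Remove redundant import statements from the client file (skyflow/vault/_client.py).
- Add missing validation case for the Get method: reject records that specify both skyflow ids and column details.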
  • Loading branch information
skyflow-vivek committed Dec 6, 2023
1 parent 295f0ce commit 19d9668
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 118 deletions.
2 changes: 2 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class SkyflowErrorMessages(Enum):
INVALID_TOKEN_TYPE = "Token key has value of type %s, expected string"
REDACTION_WITH_TOKENS_NOT_SUPPORTED = "Redaction cannot be used when tokens are true in options"
TOKENS_GET_COLUMN_NOT_SUPPORTED = "Column_name or column_values cannot be used with tokens in options"
BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = "Both skyflow ids and column details (name and/or values) are specified in payload"

PARTIAL_SUCCESS = "Server returned errors, check SkyflowError.data for more"

VAULT_ID_INVALID_TYPE = "Expected Vault ID to be str, got %s"
Expand Down
25 changes: 6 additions & 19 deletions skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,19 @@
import json
import types
import requests
import asyncio
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
from skyflow.vault._config import Configuration, GetOptions
from skyflow.vault._config import InsertOptions, ConnectionConfig, UpdateOptions
from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions
from skyflow.vault._connection import createRequest
from skyflow.vault._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from skyflow.vault._get_by_id import sendGetByIdRequests, createGetResponseBody
from skyflow.vault._get import sendGetRequests
import asyncio
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, InfoMessages, InterfaceName, getMetrics
from skyflow.vault._token import tokenProviderWrapper

from ._delete import deleteProcessResponse
from ._insert import getInsertRequestBody, processResponse, convertResponse
from ._update import sendUpdateRequests, createUpdateResponseBody
from ._config import Configuration, DeleteOptions, DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions
from ._connection import createRequest
from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from ._get_by_id import sendGetByIdRequests, createGetResponseBody
from ._get import sendGetRequests
import asyncio
from skyflow.vault._delete import deleteProcessResponse
from skyflow.vault._query import getQueryRequestBody, getQueryResponse
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, log_error, InfoMessages, InterfaceName, getMetrics
from ._token import tokenProviderWrapper
from ._query import getQueryRequestBody, getQueryResponse
from skyflow.vault._token import tokenProviderWrapper

class Client:
def __init__(self, config: Configuration):
Expand Down Expand Up @@ -109,7 +96,7 @@ def get(self, records, options: GetOptions = GetOptions()):
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url()
responses = asyncio.run(sendGetRequests(
records, options,url, self.storedToken))
records, options, url, self.storedToken))
result, partial = createGetResponseBody(responses)
if partial:
raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS,
Expand Down
54 changes: 29 additions & 25 deletions skyflow/vault/_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

interface = InterfaceName.GET.value


def getGetRequestBody(data, options: GetOptions):
requestBody = {}
ids = None
if "ids" in data:
ids = data["ids"]
Expand All @@ -25,25 +25,28 @@ def getGetRequestBody(data, options: GetOptions):
idType = str(type(id))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % (
idType), interface=interface)
requestBody["skyflow_ids"] = ids
try:
table = data["table"]
except KeyError:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface)
if not isinstance(table, str):
tableType = str(type(table))

raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (
tableType), interface=interface)
else:
requestBody["tableName"] = table

if options.tokens and data.get("redaction"):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
if options.tokens and (data.get('columnName') or data.get('columnValues')):
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
if options.tokens:
if data.get("redaction"):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
if (data.get('columnName') or data.get('columnValues')):
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED, interface=interface)

if not options.tokens:
requestBody["tokenization"] = options.tokens
else:
try:
redaction = data["redaction"]
except KeyError:
Expand All @@ -53,6 +56,8 @@ def getGetRequestBody(data, options: GetOptions):
redactionType = str(type(redaction))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (
redactionType), interface=interface)
else:
requestBody["redaction"] = redaction.value

columnName = None
if "columnName" in data:
Expand All @@ -69,13 +74,17 @@ def getGetRequestBody(data, options: GetOptions):
columnValuesType = str(type(columnValues))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (
columnValuesType), interface=interface)
else:
requestBody["column_name"] = columnName
requestBody["column_values"] = columnValues

if (ids is None and (columnName is None or columnValues is None)):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface=interface)
return ids, table, redaction.value, columnName, columnValues
return ids, table, "DEFAULT", None, None

SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR, interface=interface)
elif (ids != None and (columnName != None or columnValues != None)):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED, interface=interface)
return requestBody

async def sendGetRequests(data, options: GetOptions, url, token):
tasks = []
Expand All @@ -97,27 +106,22 @@ async def sendGetRequests(data, options: GetOptions, url, token):

validatedRecords = []
for record in records:
ids, table, redaction, columnName, columnValues = getGetRequestBody(record, options)
validatedRecords.append((ids, table, redaction, columnName, columnValues))
requestBody = getGetRequestBody(record, options)
validatedRecords.append(requestBody)

Check warning on line 110 in skyflow/vault/_get.py

View check run for this annotation

Codecov / codecov/patch

skyflow/vault/_get.py#L110

Added line #L110 was not covered by tests
async with ClientSession() as session:
for record in validatedRecords:
ids, table, redaction, columnName, columnValues = record
headers = {
"Authorization": "Bearer " + token,
"sky-metadata": json.dumps(getMetrics())
}
params = {"redaction": redaction}

if ids is not None:
params["skyflow_ids"] = ids
if columnName is not None:
params["column_name"] = columnName
params["column_values"] = columnValues
table = record.pop("tableName")
params = record
if options.tokens:
params["tokenization"] = json.dumps(record["tokenization"])

Check warning on line 120 in skyflow/vault/_get.py

View check run for this annotation

Codecov / codecov/patch

skyflow/vault/_get.py#L117-L120

Added lines #L117 - L120 were not covered by tests
task = asyncio.ensure_future(
get(url, headers, params, session, record[1], options.tokens)
get(url, headers, params, session, table)
)
tasks.append(task)
await asyncio.gather(*tasks)
await session.close()

return tasks
31 changes: 1 addition & 30 deletions skyflow/vault/_get_by_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,6 @@

interface = InterfaceName.GET_BY_ID.value

def encrypt_data(data, token):
    """Encrypt every string field of a parsed response when ``token`` is truthy.

    Args:
        data: Parsed response body. On the encrypting path it must contain a
            ``"records"`` list whose items each carry a ``"fields"`` dict.
        token: Selects the encrypting path; only its truthiness is used.

    Returns:
        When ``token`` is truthy: the JSON-serialized, UTF-8 encoded copy of
        ``data`` in which every ``str`` field value is replaced by its Fernet
        ciphertext (as text). Otherwise: the tuple ``(data, None)`` with
        ``data`` returned untouched. The two return shapes differ; this is
        kept as-is for backward compatibility.

    Raises:
        KeyError: On the encrypting path, if ``"records"`` or a record's
            ``"fields"`` entry is missing.
    """
    if not token:
        return data, None

    # NOTE(review): the key is generated per call and is never returned or
    # stored, so the ciphertext produced here cannot be decrypted afterwards.
    # Confirm whether that is intended before relying on this output.
    fernet = Fernet(Fernet.generate_key())

    def encrypt_value(value):
        # Only strings are encrypted; every other value passes through as-is.
        if isinstance(value, str):
            return fernet.encrypt(value.encode()).decode()
        return value

    # Build fresh record/field dicts instead of writing through a shallow
    # copy, so the caller's ``data`` keeps its plaintext values. An empty
    # ``records`` list is handled naturally (no first-element lookup).
    encrypted_data = data.copy()
    encrypted_data["records"] = [
        {
            **record,
            "fields": {
                name: encrypt_value(value)
                for name, value in record["fields"].items()
            },
        }
        for record in data["records"]
    ]

    return json.dumps(encrypted_data).encode()


def getGetByIdRequestBody(data):
try:
ids = data["ids"]
Expand Down Expand Up @@ -98,21 +77,13 @@ async def sendGetByIdRequests(data, url, token):
await session.close()
return tasks


async def get(url, headers, params, session, table,token=False):
async def get(url, headers, params, session, table):
async with session.get(url + "/" + table, headers=headers, params=params, ssl=False) as response:
try:
response_data = await response.text()

if token:
data = json.loads(response_data)
return (encrypt_data(data,token), response.status, table, response.headers['x-request-id'])

return (await response.read(), response.status, table, response.headers['x-request-id'])
except KeyError:
return (await response.read(), response.status, table)


def createGetResponseBody(responses):
result = {
"records": [],
Expand Down
91 changes: 47 additions & 44 deletions tests/vault/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
import unittest
import os

from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.vault import Client, Configuration, RedactionType, GetOptions
from skyflow.vault._get_by_id import encrypt_data
from skyflow.service_account import generate_bearer_token
from dotenv import dotenv_values
import warnings
import asyncio
import json
from cryptography.fernet import Fernet

from dotenv import dotenv_values
from skyflow.service_account import generate_bearer_token
from skyflow.vault import Client, Configuration, RedactionType, GetOptions
from skyflow.vault._get import getGetRequestBody
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages

class TestGet(unittest.TestCase):

Expand Down Expand Up @@ -169,7 +167,6 @@ def testGetByIdInvalidColumnValues(self):
self.assertEqual(
e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str) )


def testGetByTokenAndRedaction(self):
invalidData = {"records": [
{"ids": ["123","456"],
Expand All @@ -184,8 +181,7 @@ def testGetByTokenAndRedaction(self):
e.message, SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value)

def testGetByNoOptionAndRedaction(self):
invalidData = {"records":[
{"ids":["123","456"],"table":"newstripe"}]}
invalidData = {"records":[{"ids":["123", "456"], "table":"newstripe"}]}
options = GetOptions(False)
try:
self.client.get(invalidData,options=options)
Expand All @@ -196,10 +192,13 @@ def testGetByNoOptionAndRedaction(self):
e.message,SkyflowErrorMessages.REDACTION_KEY_ERROR.value)

def testGetByOptionAndUniqueColumnRedaction(self):
invalidData ={"records":[
{"table":"newstripe","columnName":"card_number","columnValues":["456","980"],}
]}

invalidData ={
"records":[{
"table":"newstripe",
"columnName":"card_number",
"columnValues":["456","980"],
}]
}
options = GetOptions(True)
try:
self.client.get(invalidData, options=options)
Expand All @@ -210,9 +209,13 @@ def testGetByOptionAndUniqueColumnRedaction(self):
e.message, SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED.value)

def testInvalidRedactionTypeWithNoOption(self):
invalidData = {"records": [
{"ids": ["123","456"],
"table": "stripe", "redaction": "invalid_redaction"}]}
invalidData = {
"records": [{
"ids": ["123","456"],
"table": "stripe",
"redaction": "invalid_redaction"
}]
}
options = GetOptions(False)
try:
self.client.get(invalidData, options=options)
Expand All @@ -221,36 +224,36 @@ def testInvalidRedactionTypeWithNoOption(self):
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(e.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (str))

def test_encrypt_data_with_token(self):
data = {
def testBothSkyflowIdsAndColumnDetailsPassed(self):
invalidData = {
"records": [
{
"fields": {
"ids": ["123","456"],
"table": "stripe",
}
"ids": ["123", "456"],
"table": "stripe",
"redaction": RedactionType.PLAIN_TEXT,
"columnName": "email",
"columnValues": ["[email protected]", "[email protected]"]
}
]
}
token = "secret_token"
encrypted_bytes = encrypt_data(data, token)
self.assertIsNotNone(encrypted_bytes)

def test_encrypt_data_without_token(self):
data = {
"records": [
{
"fields": {
"ids": ["123", "456"],
"table": "stripe",
}
}
]
options = GetOptions(False)
try:
self.client.get(invalidData, options=options)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(e.message, SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value)

def testGetRequestBodyReturnsRequestBodyWithIds(self):
validData = {
"records": [{
"ids": ["123", "456"],
"table": "stripe",
}]
}
token = None
encrypted_data, key = encrypt_data(data, token)
self.assertEqual(encrypted_data, data)
self.assertIsNone(key)



options = GetOptions(True)
try:
requestBody = getGetRequestBody(validData["records"][0], options)
self.assertTrue(requestBody["tokenization"])
except SkyflowError as e:
self.fail('Should not have thrown an error')

0 comments on commit 19d9668

Please sign in to comment.