Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SK-1308 Fix bug in Get method in Python SDK #114

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

All notable changes to this project will be documented in this file.

## [1.15.1] - 2023-12-07
## Fixed
- Not receiving tokens when calling Get with options tokens as true.

## [1.15.0] - 2023-10-30
## Added
- options tokens support for Get method.
Expand Down
2 changes: 2 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class SkyflowErrorMessages(Enum):
INVALID_TOKEN_TYPE = "Token key has value of type %s, expected string"
REDACTION_WITH_TOKENS_NOT_SUPPORTED = "Redaction cannot be used when tokens are true in options"
TOKENS_GET_COLUMN_NOT_SUPPORTED = "Column_name or column_values cannot be used with tokens in options"
BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = "Both skyflow ids and column details (name and/or values) are specified in payload"

PARTIAL_SUCCESS = "Server returned errors, check SkyflowError.data for more"

VAULT_ID_INVALID_TYPE = "Expected Vault ID to be str, got %s"
Expand Down
25 changes: 6 additions & 19 deletions skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,19 @@
import json
import types
import requests
import asyncio
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
from skyflow.vault._config import Configuration, GetOptions
from skyflow.vault._config import InsertOptions, ConnectionConfig, UpdateOptions
from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions
from skyflow.vault._connection import createRequest
from skyflow.vault._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from skyflow.vault._get_by_id import sendGetByIdRequests, createGetResponseBody
from skyflow.vault._get import sendGetRequests
import asyncio
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, InfoMessages, InterfaceName, getMetrics
from skyflow.vault._token import tokenProviderWrapper

from ._delete import deleteProcessResponse
from ._insert import getInsertRequestBody, processResponse, convertResponse
from ._update import sendUpdateRequests, createUpdateResponseBody
from ._config import Configuration, DeleteOptions, DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions
from ._connection import createRequest
from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from ._get_by_id import sendGetByIdRequests, createGetResponseBody
from ._get import sendGetRequests
import asyncio
from skyflow.vault._delete import deleteProcessResponse
from skyflow.vault._query import getQueryRequestBody, getQueryResponse
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, log_error, InfoMessages, InterfaceName, getMetrics
from ._token import tokenProviderWrapper
from ._query import getQueryRequestBody, getQueryResponse
from skyflow.vault._token import tokenProviderWrapper

class Client:
def __init__(self, config: Configuration):
Expand Down Expand Up @@ -109,7 +96,7 @@ def get(self, records, options: GetOptions = GetOptions()):
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url()
responses = asyncio.run(sendGetRequests(
records, options,url, self.storedToken))
records, options, url, self.storedToken))
result, partial = createGetResponseBody(responses)
if partial:
raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS,
Expand Down
54 changes: 29 additions & 25 deletions skyflow/vault/_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

interface = InterfaceName.GET.value


def getGetRequestBody(data, options: GetOptions):
requestBody = {}
ids = None
if "ids" in data:
ids = data["ids"]
Expand All @@ -25,25 +25,28 @@
idType = str(type(id))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % (
idType), interface=interface)
requestBody["skyflow_ids"] = ids
try:
table = data["table"]
except KeyError:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface)
if not isinstance(table, str):
tableType = str(type(table))

raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (
tableType), interface=interface)
else:
requestBody["tableName"] = table

if options.tokens and data.get("redaction"):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
if options.tokens and (data.get('columnName') or data.get('columnValues')):
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
if options.tokens:
if data.get("redaction"):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
if (data.get('columnName') or data.get('columnValues')):
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED, interface=interface)

if not options.tokens:
requestBody["tokenization"] = options.tokens
else:
try:
redaction = data["redaction"]
except KeyError:
Expand All @@ -53,6 +56,8 @@
redactionType = str(type(redaction))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (
redactionType), interface=interface)
else:
requestBody["redaction"] = redaction.value

columnName = None
if "columnName" in data:
Expand All @@ -69,13 +74,17 @@
columnValuesType = str(type(columnValues))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (
columnValuesType), interface=interface)
else:
requestBody["column_name"] = columnName
requestBody["column_values"] = columnValues

if (ids is None and (columnName is None or columnValues is None)):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface=interface)
return ids, table, redaction.value, columnName, columnValues
return ids, table, "DEFAULT", None, None

SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR, interface=interface)
elif (ids != None and (columnName != None or columnValues != None)):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED, interface=interface)
return requestBody

async def sendGetRequests(data, options: GetOptions, url, token):
tasks = []
Expand All @@ -97,27 +106,22 @@

validatedRecords = []
for record in records:
ids, table, redaction, columnName, columnValues = getGetRequestBody(record, options)
validatedRecords.append((ids, table, redaction, columnName, columnValues))
requestBody = getGetRequestBody(record, options)
validatedRecords.append(requestBody)

Check warning on line 110 in skyflow/vault/_get.py

View check run for this annotation

Codecov / codecov/patch

skyflow/vault/_get.py#L110

Added line #L110 was not covered by tests
async with ClientSession() as session:
for record in validatedRecords:
ids, table, redaction, columnName, columnValues = record
headers = {
"Authorization": "Bearer " + token,
"sky-metadata": json.dumps(getMetrics())
}
params = {"redaction": redaction}

if ids is not None:
params["skyflow_ids"] = ids
if columnName is not None:
params["column_name"] = columnName
params["column_values"] = columnValues
table = record.pop("tableName")
params = record
if options.tokens:
params["tokenization"] = json.dumps(record["tokenization"])

Check warning on line 120 in skyflow/vault/_get.py

View check run for this annotation

Codecov / codecov/patch

skyflow/vault/_get.py#L117-L120

Added lines #L117 - L120 were not covered by tests
task = asyncio.ensure_future(
get(url, headers, params, session, record[1], options.tokens)
get(url, headers, params, session, table)
)
tasks.append(task)
await asyncio.gather(*tasks)
await session.close()

return tasks
34 changes: 2 additions & 32 deletions skyflow/vault/_get_by_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,9 @@
import json
from ._config import RedactionType
from skyflow._utils import InterfaceName, getMetrics
from cryptography.fernet import Fernet

interface = InterfaceName.GET_BY_ID.value

def encrypt_data(data, token):
if token:
key = Fernet.generate_key()
fernet = Fernet(key)
encrypted_data = data.copy()
fields = encrypted_data["records"][0]["fields"]
for record in encrypted_data["records"]:
fields = record["fields"]
for key, value in fields.items():
if isinstance(value, str):
encrypted_value = fernet.encrypt(value.encode()).decode()
fields[key] = encrypted_value

serialized_data = json.dumps(encrypted_data)
encrypted_bytes = serialized_data.encode()

return encrypted_bytes
else:
return data, None


def getGetByIdRequestBody(data):
try:
ids = data["ids"]
Expand Down Expand Up @@ -98,28 +76,20 @@ async def sendGetByIdRequests(data, url, token):
await session.close()
return tasks


async def get(url, headers, params, session, table,token=False):
async def get(url, headers, params, session, table):
async with session.get(url + "/" + table, headers=headers, params=params, ssl=False) as response:
try:
response_data = await response.text()

if token:
data = json.loads(response_data)
return (encrypt_data(data,token), response.status, table, response.headers['x-request-id'])

return (await response.read(), response.status, table, response.headers['x-request-id'])
except KeyError:
return (await response.read(), response.status, table)


def createGetResponseBody(responses):
result = {
"records": [],
"errors": []
}
partial = False
for response in responses:
partial = False
r = response.result()
status = r[1]
try:
Expand Down
91 changes: 47 additions & 44 deletions tests/vault/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
import unittest
import os

from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.vault import Client, Configuration, RedactionType, GetOptions
from skyflow.vault._get_by_id import encrypt_data
from skyflow.service_account import generate_bearer_token
from dotenv import dotenv_values
import warnings
import asyncio
import json
from cryptography.fernet import Fernet

from dotenv import dotenv_values
from skyflow.service_account import generate_bearer_token
from skyflow.vault import Client, Configuration, RedactionType, GetOptions
from skyflow.vault._get import getGetRequestBody
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages

class TestGet(unittest.TestCase):

Expand Down Expand Up @@ -169,7 +167,6 @@ def testGetByIdInvalidColumnValues(self):
self.assertEqual(
e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str) )


def testGetByTokenAndRedaction(self):
invalidData = {"records": [
{"ids": ["123","456"],
Expand All @@ -184,8 +181,7 @@ def testGetByTokenAndRedaction(self):
e.message, SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value)

def testGetByNoOptionAndRedaction(self):
invalidData = {"records":[
{"ids":["123","456"],"table":"newstripe"}]}
invalidData = {"records":[{"ids":["123", "456"], "table":"newstripe"}]}
options = GetOptions(False)
try:
self.client.get(invalidData,options=options)
Expand All @@ -196,10 +192,13 @@ def testGetByNoOptionAndRedaction(self):
e.message,SkyflowErrorMessages.REDACTION_KEY_ERROR.value)

def testGetByOptionAndUniqueColumnRedaction(self):
invalidData ={"records":[
{"table":"newstripe","columnName":"card_number","columnValues":["456","980"],}
]}

invalidData ={
"records":[{
"table":"newstripe",
"columnName":"card_number",
"columnValues":["456","980"],
}]
}
options = GetOptions(True)
try:
self.client.get(invalidData, options=options)
Expand All @@ -210,9 +209,13 @@ def testGetByOptionAndUniqueColumnRedaction(self):
e.message, SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED.value)

def testInvalidRedactionTypeWithNoOption(self):
invalidData = {"records": [
{"ids": ["123","456"],
"table": "stripe", "redaction": "invalid_redaction"}]}
invalidData = {
"records": [{
"ids": ["123","456"],
"table": "stripe",
"redaction": "invalid_redaction"
}]
}
options = GetOptions(False)
try:
self.client.get(invalidData, options=options)
Expand All @@ -221,36 +224,36 @@ def testInvalidRedactionTypeWithNoOption(self):
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(e.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (str))

def test_encrypt_data_with_token(self):
data = {
def testBothSkyflowIdsAndColumnDetailsPassed(self):
invalidData = {
"records": [
{
"fields": {
"ids": ["123","456"],
"table": "stripe",
}
"ids": ["123", "456"],
"table": "stripe",
"redaction": RedactionType.PLAIN_TEXT,
"columnName": "email",
"columnValues": ["[email protected]", "[email protected]"]
}
]
}
token = "secret_token"
encrypted_bytes = encrypt_data(data, token)
self.assertIsNotNone(encrypted_bytes)

def test_encrypt_data_without_token(self):
data = {
"records": [
{
"fields": {
"ids": ["123", "456"],
"table": "stripe",
}
}
]
options = GetOptions(False)
try:
self.client.get(invalidData, options=options)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(e.message, SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value)

def testGetRequestBodyReturnsRequestBodyWithIds(self):
validData = {
"records": [{
"ids": ["123", "456"],
"table": "stripe",
}]
}
token = None
encrypted_data, key = encrypt_data(data, token)
self.assertEqual(encrypted_data, data)
self.assertIsNone(key)



options = GetOptions(True)
try:
requestBody = getGetRequestBody(validData["records"][0], options)
self.assertTrue(requestBody["tokenization"])
except SkyflowError as e:
self.fail('Should not have thrown an error')
Loading