diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index bc0b8ba..2e79281 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -72,6 +72,8 @@ class SkyflowErrorMessages(Enum): INVALID_TOKEN_TYPE = "Token key has value of type %s, expected string" REDACTION_WITH_TOKENS_NOT_SUPPORTED = "Redaction cannot be used when tokens are true in options" TOKENS_GET_COLUMN_NOT_SUPPORTED = "Column_name or column_values cannot be used with tokens in options" + BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = "Both skyflow ids and column details (name and/or values) are specified in payload" + PARTIAL_SUCCESS = "Server returned errors, check SkyflowError.data for more" VAULT_ID_INVALID_TYPE = "Expected Vault ID to be str, got %s" diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index dd2e320..e426f59 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -4,32 +4,19 @@ import json import types import requests +import asyncio from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody -from skyflow.vault._config import Configuration, GetOptions -from skyflow.vault._config import InsertOptions, ConnectionConfig, UpdateOptions +from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions from skyflow.vault._connection import createRequest from skyflow.vault._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody from skyflow.vault._get_by_id import sendGetByIdRequests, createGetResponseBody from skyflow.vault._get import sendGetRequests -import asyncio -from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages -from skyflow._utils import log_info, InfoMessages, InterfaceName, getMetrics -from skyflow.vault._token import tokenProviderWrapper - -from ._delete import deleteProcessResponse -from ._insert import getInsertRequestBody, processResponse, convertResponse -from ._update import sendUpdateRequests, createUpdateResponseBody -from ._config import Configuration, DeleteOptions, DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions -from ._connection import createRequest -from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody -from ._get_by_id import sendGetByIdRequests, createGetResponseBody -from ._get import sendGetRequests -import asyncio +from skyflow.vault._delete import deleteProcessResponse +from skyflow.vault._query import getQueryRequestBody, getQueryResponse from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow._utils import log_info, log_error, InfoMessages, InterfaceName, getMetrics -from ._token import tokenProviderWrapper -from ._query import getQueryRequestBody, getQueryResponse +from skyflow.vault._token import tokenProviderWrapper class Client: def __init__(self, config: Configuration): @@ -109,7 +96,7 @@ def get(self, records, options: GetOptions = GetOptions()): self.storedToken, self.tokenProvider, interface) url = self._get_complete_vault_url() responses = asyncio.run(sendGetRequests( - records, options,url, self.storedToken)) + records, options, url, self.storedToken)) result, partial = createGetResponseBody(responses) if partial: raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, diff --git a/skyflow/vault/_get.py b/skyflow/vault/_get.py index abe3c5f..f00ed2e 100644 --- a/skyflow/vault/_get.py +++ b/skyflow/vault/_get.py @@ -11,8 +11,8 @@ interface = InterfaceName.GET.value - def getGetRequestBody(data, options: GetOptions): + requestBody = {} ids = None if "ids" in data: ids = data["ids"] @@ -25,6 +25,7 @@ def getGetRequestBody(data, options: GetOptions): idType = str(type(id)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( idType), interface=interface) + requestBody["skyflow_ids"] = ids try: table = data["table"] except KeyError: @@ -32,18 +33,20 @@ def getGetRequestBody(data, options: GetOptions): SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface) if not isinstance(table, str): tableType = str(type(table)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % ( tableType), interface=interface) + else: + requestBody["tableName"] = table - if options.tokens and data.get("redaction"): - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface) - if options.tokens and (data.get('columnName') or data.get('columnValues')): - raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED, + if options.tokens: + if data.get("redaction"): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface) + if (data.get('columnName') or data.get('columnValues')): + raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED, SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED, interface=interface) - - if not options.tokens: + requestBody["tokenization"] = options.tokens + else: try: redaction = data["redaction"] except KeyError: @@ -53,6 +56,8 @@ def getGetRequestBody(data, options: GetOptions): redactionType = str(type(redaction)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % ( redactionType), interface=interface) + else: + requestBody["redaction"] = redaction.value columnName = None if "columnName" in data: @@ -69,13 +74,17 @@ def getGetRequestBody(data, options: GetOptions): columnValuesType = str(type(columnValues)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % ( columnValuesType), interface=interface) + else: + requestBody["column_name"] = columnName + requestBody["column_values"] = columnValues if (ids is None and (columnName is None or columnValues is None)): raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface=interface) - return ids, table, redaction.value, columnName, columnValues - return ids, table, "DEFAULT", None, None - + SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR, interface=interface) + elif (ids != None and (columnName != None or columnValues != None)): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED, interface=interface) + return requestBody async def sendGetRequests(data, options: GetOptions, url, token): tasks = [] @@ -97,27 +106,22 @@ async def sendGetRequests(data, options: GetOptions, url, token): validatedRecords = [] for record in records: - ids, table, redaction, columnName, columnValues = getGetRequestBody(record, options) - validatedRecords.append((ids, table, redaction, columnName, columnValues)) + requestBody = getGetRequestBody(record, options) + validatedRecords.append(requestBody) async with ClientSession() as session: for record in validatedRecords: - ids, table, redaction, columnName, columnValues = record headers = { "Authorization": "Bearer " + token, "sky-metadata": json.dumps(getMetrics()) } - params = {"redaction": redaction} - - if ids is not None: - params["skyflow_ids"] = ids - if columnName is not None: - params["column_name"] = columnName - params["column_values"] = columnValues + table = record.pop("tableName") + params = record + if options.tokens: + params["tokenization"] = json.dumps(record["tokenization"]) task = asyncio.ensure_future( - get(url, headers, params, session, record[1], options.tokens) + get(url, headers, params, session, table) ) tasks.append(task) await asyncio.gather(*tasks) await session.close() - return tasks \ No newline at end of file diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index b8cb7af..26e412f 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -11,27 +11,6 @@ interface = InterfaceName.GET_BY_ID.value -def encrypt_data(data, token): - if token: - key = Fernet.generate_key() - fernet = Fernet(key) - encrypted_data = data.copy() - fields = encrypted_data["records"][0]["fields"] - for record in encrypted_data["records"]: - fields = record["fields"] - for key, value in fields.items(): - if isinstance(value, str): - encrypted_value = fernet.encrypt(value.encode()).decode() - fields[key] = encrypted_value - - serialized_data = json.dumps(encrypted_data) - encrypted_bytes = serialized_data.encode() - - return encrypted_bytes - else: - return data, None - - def getGetByIdRequestBody(data): try: ids = data["ids"] @@ -98,21 +77,13 @@ async def sendGetByIdRequests(data, url, token): await session.close() return tasks - -async def get(url, headers, params, session, table,token=False): +async def get(url, headers, params, session, table): async with session.get(url + "/" + table, headers=headers, params=params, ssl=False) as response: try: - response_data = await response.text() - - if token: - data = json.loads(response_data) - return (encrypt_data(data,token), response.status, table, response.headers['x-request-id']) - return (await response.read(), response.status, table, response.headers['x-request-id']) except KeyError: return (await response.read(), response.status, table) - def createGetResponseBody(responses): result = { "records": [], diff --git a/tests/vault/test_get.py b/tests/vault/test_get.py index b179f09..bd98efc 100644 --- a/tests/vault/test_get.py +++ b/tests/vault/test_get.py @@ -4,16 +4,14 @@ import unittest import os -from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages -from skyflow.vault import Client, Configuration, RedactionType, GetOptions -from skyflow.vault._get_by_id import encrypt_data -from skyflow.service_account import generate_bearer_token -from dotenv import dotenv_values import warnings import asyncio import json -from cryptography.fernet import Fernet - +from dotenv import dotenv_values +from skyflow.service_account import generate_bearer_token +from skyflow.vault import Client, Configuration, RedactionType, GetOptions +from skyflow.vault._get import getGetRequestBody +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages class TestGet(unittest.TestCase): @@ -169,7 +167,6 @@ def testGetByIdInvalidColumnValues(self): self.assertEqual( e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str) ) - def testGetByTokenAndRedaction(self): invalidData = {"records": [ {"ids": ["123","456"], @@ -184,8 +181,7 @@ def testGetByTokenAndRedaction(self): e.message, SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value) def testGetByNoOptionAndRedaction(self): - invalidData = {"records":[ - {"ids":["123","456"],"table":"newstripe"}]} + invalidData = {"records":[{"ids":["123", "456"], "table":"newstripe"}]} options = GetOptions(False) try: self.client.get(invalidData,options=options) @@ -196,10 +192,13 @@ def testGetByNoOptionAndRedaction(self): e.message,SkyflowErrorMessages.REDACTION_KEY_ERROR.value) def testGetByOptionAndUniqueColumnRedaction(self): - invalidData ={"records":[ - {"table":"newstripe","columnName":"card_number","columnValues":["456","980"],} - ]} - + invalidData ={ + "records":[{ + "table":"newstripe", + "columnName":"card_number", + "columnValues":["456","980"], + }] + } options = GetOptions(True) try: self.client.get(invalidData, options=options) @@ -210,9 +209,13 @@ def testGetByOptionAndUniqueColumnRedaction(self): e.message, SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED.value) def testInvalidRedactionTypeWithNoOption(self): - invalidData = {"records": [ - {"ids": ["123","456"], - "table": "stripe", "redaction": "invalid_redaction"}]} + invalidData = { + "records": [{ + "ids": ["123","456"], + "table": "stripe", + "redaction": "invalid_redaction" + }] + } options = GetOptions(False) try: self.client.get(invalidData, options=options) @@ -221,36 +224,36 @@ def testInvalidRedactionTypeWithNoOption(self): self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) self.assertEqual(e.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (str)) - def test_encrypt_data_with_token(self): - data = { + def testBothSkyflowIdsAndColumnDetailsPassed(self): + invalidData = { "records": [ { - "fields": { - "ids": ["123","456"], - "table": "stripe", - } + "ids": ["123", "456"], + "table": "stripe", + "redaction": RedactionType.PLAIN_TEXT, + "columnName": "email", + "columnValues": ["email1@gmail.com", "email2@gmail.co"] } ] } - token = "secret_token" - encrypted_bytes = encrypt_data(data, token) - self.assertIsNotNone(encrypted_bytes) - - def test_encrypt_data_without_token(self): - data = { - "records": [ - { - "fields": { - "ids": ["123", "456"], - "table": "stripe", - } - } - ] + options = GetOptions(False) + try: + self.client.get(invalidData, options=options) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual(e.message, SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value) + + def testGetRequestBodyReturnsRequestBodyWithIds(self): + validData = { + "records": [{ + "ids": ["123", "456"], + "table": "stripe", + }] } - token = None - encrypted_data, key = encrypt_data(data, token) - self.assertEqual(encrypted_data, data) - self.assertIsNone(key) - - - + options = GetOptions(True) + try: + requestBody = getGetRequestBody(validData["records"][0], options) + self.assertTrue(requestBody["tokenization"]) + except SkyflowError as e: + self.fail('Should not have thrown an error') \ No newline at end of file