diff --git a/samples/get_sample.py b/samples/get_sample.py new file mode 100644 index 0000000..867d9aa --- /dev/null +++ b/samples/get_sample.py @@ -0,0 +1,43 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors import SkyflowError +from skyflow.service_account import generate_bearer_token, is_expired +from skyflow.vault import Client, Configuration, RedactionType + + +# cache token for reuse +bearerToken = '' + + +def token_provider(): + global bearerToken + if is_expired(bearerToken): + bearerToken, _ = generate_bearer_token('') + return bearerToken + + +try: + config = Configuration( + '', '', token_provider) + client = Client(config) + + data = {"records": [ + { + "ids": ["", "", ""], + "table": "", + "redaction": RedactionType.PLAIN_TEXT + }, + #To get records using unique column name and values. + { + "redaction" : "", + "table": "", + "columnName": "", + "columnValues": "[,]", + } + ]} + + response = client.get(data) + print('Response:', response) +except SkyflowError as e: + print('Error Occurred:', e) diff --git a/samples/update_sample.py b/samples/update_sample.py new file mode 100644 index 0000000..9c6ea90 --- /dev/null +++ b/samples/update_sample.py @@ -0,0 +1,39 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors import SkyflowError +from skyflow.service_account import generate_bearer_token, is_expired +from skyflow.vault import Client, UpdateOptions, Configuration + +# cache token for reuse +bearerToken = '' + +def token_provider(): + global bearerToken + if is_expired(bearerToken): + bearerToken, _ = generate_bearer_token('') + return bearerToken + + +try: + config = Configuration( + '', '', token_provider) + client = Client(config) + + options = UpdateOptions(True) + + data = { + "records": [ + { + "id": "", + "table": "", + "fields": { + "": "" + } + } + ] + } + response = client.update(data, options=options) + print('Response:', response) +except SkyflowError as e: + print('Error Occurred:', e) diff --git a/skyflow/_utils.py b/skyflow/_utils.py index fa32b69..5186dbe 100644 --- a/skyflow/_utils.py +++ b/skyflow/_utils.py @@ -68,6 +68,10 @@ class InfoMessages(Enum): IS_EXPIRED_TRIGGERED = "is_expired() triggered" EMPTY_ACCESS_TOKEN = "Give access token is empty" INVALID_TOKEN = "Given token is invalid" + UPDATE_TRIGGERED = "Update method triggered" + UPDATE_DATA_SUCCESS = "Data has been updated successfully" + GET_TRIGGERED = "Get triggered." + GET_SUCCESS = "Data fetched successfully." class InterfaceName(Enum): @@ -75,6 +79,8 @@ class InterfaceName(Enum): INSERT = "client.insert" DETOKENIZE = "client.detokenize" GET_BY_ID = "client.get_by_id" + GET = "client.get" + UPDATE = "client.update" INVOKE_CONNECTION = "client.invoke_connection" GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token" diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index d617898..a2dbbf0 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -34,16 +34,20 @@ class SkyflowErrorMessages(Enum): FIELDS_KEY_ERROR = "Fields key is missing from payload" TABLE_KEY_ERROR = "Table key is missing from payload" TOKEN_KEY_ERROR = "Token key is missing from payload" - IDS_KEY_ERROR = "Ids key is missing from payload" + IDS_KEY_ERROR = "Id(s) key is missing from payload" REDACTION_KEY_ERROR = "Redaction key is missing from payload" + UNIQUE_COLUMN_OR_IDS_KEY_ERROR = "Ids or Unique column key is missing from payload" + UPDATE_FIELD_KEY_ERROR = "Atleast one field should be provided to update" INVALID_JSON = "Given %s is invalid JSON" INVALID_RECORDS_TYPE = "Records key has value of type %s, expected list" - INVALID_FIELDS_TYPE = "Fields key has value of type %s, expected string" + INVALID_FIELDS_TYPE = "Fields key has value of type %s, expected dict" INVALID_TABLE_TYPE = "Table key has value of type %s, expected string" INVALID_IDS_TYPE = "Ids key has value of type %s, expected list" INVALID_ID_TYPE = "Id key has value of type %s, expected string" INVALID_REDACTION_TYPE = "Redaction key has value of type %s, expected Skyflow.Redaction" + INVALID_COLUMN_NAME = "Column name has value of type %s, expected string" + INVALID_COLUMN_VALUE = "Column values has value of type %s, expected list" INVALID_REQUEST_BODY = "Given request body is not valid" INVALID_RESPONSE_BODY = "Given response body is not valid" diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 8aeeb8f..ff7aaa4 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -4,11 +4,13 @@ import types import requests from ._insert import getInsertRequestBody, processResponse, convertResponse +from ._update import sendUpdateRequests, createUpdateResponseBody from ._config import Configuration -from ._config import InsertOptions, ConnectionConfig +from ._config import InsertOptions, ConnectionConfig, UpdateOptions from ._connection import createRequest from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody -from ._get_by_id import sendGetByIdRequests, createGetByIdResponseBody +from ._get_by_id import sendGetByIdRequests, createGetResponseBody +from ._get import sendGetRequests import asyncio from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow._utils import log_info, InfoMessages, InterfaceName @@ -78,6 +80,25 @@ def detokenize(self, records): log_info(InfoMessages.DETOKENIZE_SUCCESS.value, interface) return result + def get(self, records): + interface = InterfaceName.GET.value + log_info(InfoMessages.GET_TRIGGERED.value, interface) + + self._checkConfig(interface) + self.storedToken = tokenProviderWrapper( + self.storedToken, self.tokenProvider, interface) + url = self._get_complete_vault_url() + responses = asyncio.run(sendGetRequests( + records, url, self.storedToken)) + result, partial = createGetResponseBody(responses) + if partial: + raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, + SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) + else: + log_info(InfoMessages.GET_SUCCESS.value, interface) + + return result + def get_by_id(self, records): interface = InterfaceName.GET_BY_ID.value log_info(InfoMessages.GET_BY_ID_TRIGGERED.value, interface) @@ -88,7 +109,7 @@ def get_by_id(self, records): url = self._get_complete_vault_url() responses = asyncio.run(sendGetByIdRequests( records, url, self.storedToken)) - result, partial = createGetByIdResponseBody(responses) + result, partial = createGetResponseBody(responses) if partial: raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) @@ -130,3 +151,21 @@ def _get_complete_vault_url(self): Get the complete vault url from given vault url and vault id ''' return self.vaultURL + "/v1/vaults/" + self.vaultID + + def update(self, updateInput, options: UpdateOptions = UpdateOptions()): + interface = InterfaceName.UPDATE.value + log_info(InfoMessages.UPDATE_TRIGGERED.value, interface=interface) + + self._checkConfig(interface) + self.storedToken = tokenProviderWrapper( + self.storedToken, self.tokenProvider, interface) + url = self._get_complete_vault_url() + responses = asyncio.run(sendUpdateRequests( + updateInput, options, url, self.storedToken)) + result, partial = createUpdateResponseBody(responses) + if partial: + raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, + SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) + else: + log_info(InfoMessages.UPDATE_DATA_SUCCESS.value, interface) + return result \ No newline at end of file diff --git a/skyflow/vault/_config.py b/skyflow/vault/_config.py index 741a99a..25fe3cd 100644 --- a/skyflow/vault/_config.py +++ b/skyflow/vault/_config.py @@ -34,6 +34,10 @@ def __init__(self, tokens: bool=True,upsert :List[UpsertOption]=None): self.tokens = tokens self.upsert = upsert +class UpdateOptions: + def __init__(self, tokens: bool=True): + self.tokens = tokens + class RequestMethod(Enum): GET = 'GET' POST = 'POST' diff --git a/skyflow/vault/_get.py b/skyflow/vault/_get.py new file mode 100644 index 0000000..0670206 --- /dev/null +++ b/skyflow/vault/_get.py @@ -0,0 +1,99 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +import asyncio +from aiohttp import ClientSession +from ._config import RedactionType +from skyflow._utils import InterfaceName +from ._get_by_id import get + +interface = InterfaceName.GET.value + +def getGetRequestBody(data): + ids = None + if "ids" in data: + ids = data["ids"] + if not isinstance(ids, list): + idsType = str(type(ids)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface) + for id in ids: + if not isinstance(id, str): + idType = str(type(id)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( + idType), interface=interface) + try: + table = data["table"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface) + if not isinstance(table, str): + tableType = str(type(table)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % ( + tableType), interface=interface) + try: + redaction = data["redaction"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.REDACTION_KEY_ERROR, interface=interface) + if not isinstance(redaction, RedactionType): + redactionType = str(type(redaction)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % ( + redactionType), interface=interface) + + columnName = None + if "columnName" in data: + columnName = data["columnName"] + if not isinstance(columnName, str): + columnNameType = str(type(columnName)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % ( + columnNameType), interface=interface) + + columnValues = None + if columnName is not None and "columnValues" in data: + columnValues = data["columnValues"] + if not isinstance(columnValues, list): + columnValuesType= str(type(columnValues)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % ( + columnValuesType), interface=interface) + + if(ids is None and (columnName is None or columnValues is None)): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface= interface) + return ids, table, redaction.value, columnName, columnValues + + +async def sendGetRequests(data, url, token): + tasks = [] + try: + records = data["records"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.RECORDS_KEY_ERROR, interface=interface) + if not isinstance(records, list): + recordsType = str(type(records)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % ( + recordsType), interface=interface) + + validatedRecords = [] + for record in records: + ids, table, redaction, columnName, columnValues = getGetRequestBody(record) + validatedRecords.append((ids, table, redaction, columnName, columnValues)) + async with ClientSession() as session: + for record in validatedRecords: + headers = { + "Authorization": "Bearer " + token + } + params = {"redaction": redaction} + if ids is not None: + params["skyflow_ids"] = ids + if columnName is not None: + params["column_name"] = columnName + params["column_values"] = columnValues + task = asyncio.ensure_future( + get(url, headers, params, session, record[1])) + tasks.append(task) + await asyncio.gather(*tasks) + await session.close() + return tasks \ No newline at end of file diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index fbda2b8..7cd3adc 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -85,7 +85,7 @@ async def get(url, headers, params, session, table): return (await response.read(), response.status, table) -def createGetByIdResponseBody(responses): +def createGetResponseBody(responses): result = { "records": [], "errors": [] @@ -114,6 +114,5 @@ def createGetByIdResponseBody(responses): if len(r) > 3 and r[3] != None: temp["error"]["description"] += ' - Request ID: ' + str(r[3]) result["errors"].append(temp) - result["errors"].append(temp) partial = True return result, partial diff --git a/skyflow/vault/_insert.py b/skyflow/vault/_insert.py index 94660f7..e0150b4 100644 --- a/skyflow/vault/_insert.py +++ b/skyflow/vault/_insert.py @@ -74,7 +74,7 @@ def getTableAndFields(record): SkyflowErrorMessages.FIELDS_KEY_ERROR, interface=interface) if not isinstance(fields, dict): - fieldsType = str(type(table)) + fieldsType = str(type(fields)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_FIELDS_TYPE.value % ( fieldsType), interface=interface) diff --git a/skyflow/vault/_update.py b/skyflow/vault/_update.py new file mode 100644 index 0000000..ade08b2 --- /dev/null +++ b/skyflow/vault/_update.py @@ -0,0 +1,105 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import json + +import asyncio +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from ._insert import getTableAndFields +from skyflow._utils import InterfaceName +from aiohttp import ClientSession +from ._config import UpdateOptions + +interface = InterfaceName.UPDATE.value + +async def sendUpdateRequests(data,options: UpdateOptions,url,token): + tasks = [] + + try: + records = data["records"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.RECORDS_KEY_ERROR, interface=interface) + if not isinstance(records, list): + recordsType = str(type(records)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % ( + recordsType), interface=interface) + + validatedRecords = [] + for record in records: + tableName = validateUpdateRecord(record) + validatedRecords.append(record) + async with ClientSession() as session: + for record in validatedRecords: + recordUrl = url +'/'+ tableName +'/'+ record["id"] + reqBody = { + "record": { + "fields": record["fields"] + }, + "tokenization": options.tokens + } + reqBody = json.dumps(reqBody) + headers = { + "Authorization": "Bearer " + token + } + task = asyncio.ensure_future(put(recordUrl, reqBody, headers, session)) + tasks.append(task) + await asyncio.gather(*tasks) + await session.close() + return tasks + +def validateUpdateRecord(record): + try: + id = record["id"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.IDS_KEY_ERROR, interface=interface) + if not isinstance(id, str): + idType = str(type(id)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.INVALID_ID_TYPE.value % (idType), interface=interface) + table, fields = getTableAndFields(record) + keysLength = len(fields.keys()) + if(keysLength < 1): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.UPDATE_FIELD_KEY_ERROR, interface= interface) + return table + +async def put(url, data, headers, session): + async with session.put(url, data=data, headers=headers, ssl=False) as response: + try: + return (await response.read(), response.status, response.headers['x-request-id']) + except KeyError: + return (await response.read(), response.status) + + +def createUpdateResponseBody(responses): + result = { + "records": [], + "errors": [] + } + partial = False + for response in responses: + r = response.result() + status = r[1] + try: + jsonRes = json.loads(r[0].decode('utf-8')) + except: + raise SkyflowError(status, + SkyflowErrorMessages.RESPONSE_NOT_JSON.value % r[0].decode('utf-8'), interface=interface) + + if status == 200: + temp = {} + temp["id"] = jsonRes["skyflow_id"] + if "tokens" in jsonRes: + temp["fields"] = jsonRes["tokens"] + result["records"].append(temp) + else: + temp = {"error": {}} + temp["error"]["code"] = jsonRes["error"]["http_code"] + temp["error"]["description"] = jsonRes["error"]["message"] + if len(r) > 2 and r[2] != None: + temp["error"]["description"] += ' - Request ID: ' + str(r[2]) + result["errors"].append(temp) + partial = True + return result, partial diff --git a/tests/vault/test_get.py b/tests/vault/test_get.py new file mode 100644 index 0000000..95171a0 --- /dev/null +++ b/tests/vault/test_get.py @@ -0,0 +1,188 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import unittest +import os + +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from skyflow.vault import Client, Configuration, RedactionType +from skyflow.service_account import generate_bearer_token +from dotenv import dotenv_values +import warnings +import asyncio +import json + + +class TestGet(unittest.TestCase): + + def setUp(self) -> None: + self.envValues = dotenv_values(".env") + self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/') + self.event_loop = asyncio.new_event_loop() + self.mocked_futures = [] + + def tokenProvider(): + token, type = generate_bearer_token( + self.envValues["CREDENTIALS_FILE_PATH"]) + return token + + config = Configuration( + self.envValues["VAULT_ID"], self.envValues["VAULT_URL"], tokenProvider) + self.client = Client(config) + warnings.filterwarnings( + action="ignore", message="unclosed", category=ResourceWarning) + return super().setUp() + + def add_mock_response(self, response, statusCode, table, encode=True): + future = asyncio.Future(loop=self.event_loop) + if encode: + future.set_result( + (json.dumps(response).encode(), statusCode, table)) + else: + future.set_result((response, statusCode, table)) + future.done() + self.mocked_futures.append(future) + + def getDataPath(self, file): + return self.dataPath + file + '.json' + + def testGetByIdNoRecords(self): + invalidData = {"invalidKey": "invalid"} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.RECORDS_KEY_ERROR.value) + + def testGetByIdRecordsInvalidType(self): + invalidData = {"records": "invalid"} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % (str)) + + def testGetByIdNoIds(self): + invalidData = {"records": [ + {"invalid": "invalid", "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidIdsType(self): + invalidData = {"records": [ + {"ids": "invalid", "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_IDS_TYPE.value % (str)) + + def testGetByIdInvalidIdsType2(self): + invalidData = {"records": [ + {"ids": ["123", 123], "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_ID_TYPE.value % (int)) + + def testGetByIdNoTable(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "invalid": "invalid", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.TABLE_KEY_ERROR.value) + + def testGetByIdInvalidTableType(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "table": ["invalid"], "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (list)) + + def testGetByIdNoRedaction(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "table": "pii_fields", "invalid": "invalid"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.REDACTION_KEY_ERROR.value) + + def testGetByIdInvalidRedactionType(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (str)) + + def testGetByIdNoColumnName(self): + invalidData = {"records": [ + {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidColumnName(self): + invalidData = {"records": [ + {"ids": ["123", "456"],"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": ["invalid"]}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % (list)) + + def testGetByIdNoColumnValues(self): + invalidData = {"records": [ + {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidColumnValues(self): + invalidData = {"records": [ + {"ids": ["123", "456"], "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name", "columnValues": "invalid"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str)) diff --git a/tests/vault/test_get_by_id.py b/tests/vault/test_get_by_id.py index 9e87428..d967623 100644 --- a/tests/vault/test_get_by_id.py +++ b/tests/vault/test_get_by_id.py @@ -4,10 +4,9 @@ import unittest import os -import aiohttp from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow.vault import Client, Configuration, RedactionType -from skyflow.vault._get_by_id import createGetByIdResponseBody +from skyflow.vault._get_by_id import createGetResponseBody from skyflow.service_account import generate_bearer_token from dotenv import dotenv_values import warnings @@ -149,7 +148,7 @@ def testCreateResponseBodySuccess(self): response = {"records": [ {"fields": {"card_number": "4111-1111-1111-1111"}}]} self.add_mock_response(response, 200, "table") - result, partial = createGetByIdResponseBody(self.mocked_futures) + result, partial = createGetResponseBody(self.mocked_futures) self.assertFalse(partial) self.assertEqual(len(result["records"]), 1) @@ -168,7 +167,7 @@ def testCreateResponseBodyPartialSuccess(self): }} self.add_mock_response(failed_response, 404, "ok") - result, partial = createGetByIdResponseBody(self.mocked_futures) + result, partial = createGetResponseBody(self.mocked_futures) self.assertTrue(partial) self.assertEqual(len(result["records"]), 1) @@ -187,8 +186,8 @@ def testCreateResponseBodyInvalidJson(self): self.add_mock_response(response.encode(), 200, 'table', encode=False) try: - createGetByIdResponseBody(self.mocked_futures) + createGetResponseBody(self.mocked_futures) except SkyflowError as error: expectedError = SkyflowErrorMessages.RESPONSE_NOT_JSON self.assertEqual(error.code, 200) - self.assertEqual(error.message, expectedError.value % response) + self.assertEqual(error.message, expectedError.value % response) \ No newline at end of file diff --git a/tests/vault/test_update.py b/tests/vault/test_update.py new file mode 100644 index 0000000..c6a00ef --- /dev/null +++ b/tests/vault/test_update.py @@ -0,0 +1,184 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import json +import unittest +import os +import asyncio +import warnings + +from dotenv import dotenv_values +from skyflow.vault._client import Client, Configuration +from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from skyflow.service_account import generate_bearer_token +from skyflow.vault._client import Client +from skyflow.vault._config import UpdateOptions + + +class TestUpdate(unittest.TestCase): + + def setUp(self) -> None: + self.envValues = dotenv_values(".env") + self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/') + self.mocked_futures = [] + self.event_loop = asyncio.new_event_loop() + + def tokenProvider(): + token, _ = generate_bearer_token( + self.envValues["CREDENTIALS_FILE_PATH"]) + return token + + config = Configuration( + self.envValues["VAULT_ID"], self.envValues["VAULT_URL"], tokenProvider) + self.client = Client(config) + warnings.filterwarnings( + action="ignore", message="unclosed", category=ResourceWarning) + return super().setUp() + + def add_mock_response(self, response, statusCode, encode=True): + future = asyncio.Future(loop=self.event_loop) + if encode: + future.set_result((json.dumps(response).encode(), statusCode)) + else: + future.set_result((response, statusCode)) + future.done() + self.mocked_futures.append(future) + + def getDataPath(self, file): + return self.dataPath + file + '.json' + + def testUpdateNoRecords(self): + invalidData = {} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.RECORDS_KEY_ERROR.value) + + def testUpdateInvalidType(self): + invalidData = {"records": "invalid"} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % (str)) + + def testUpdateNoIds(self): + invalidData = {"records": [ + {"table": "pii_fields"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.IDS_KEY_ERROR.value) + + def testUpdateInvalidIdType(self): + invalidData = {"records": [ + {"id": ["123"], "table": "pii_fields"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_ID_TYPE.value % (list)) + + def testUpdateNoTable(self): + invalidData = {"records": [ + {"id": "id"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.TABLE_KEY_ERROR.value) + + def testUpdateInvalidTableType(self): + invalidData = {"records": [ + {"id": "id1", "table": ["invalid"]}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (list)) + + def testUpdateNoFields(self): + invalidData = {"records": [ + {"id": "id", "table": "pii_fields"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.FIELDS_KEY_ERROR.value) + + def testUpdateInvalidFieldsType(self): + invalidData = {"records": [ + {"id": "id1", "table": "pii_fields", "fields": "invalid"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_FIELDS_TYPE.value % (str)) + + def testUpdateInvalidFieldsType2(self): + invalidData = {"records": [ + {"id": "id1", "table": "pii_fields", "fields": {}}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UPDATE_FIELD_KEY_ERROR.value) + + def testResponseBodySuccess(self): + response = {"skyflow_id": "123", "tokens": {"first_name": "John"}} + mock_response = [{"id": "123", "fields": {"first_name": "John"}}] + self.add_mock_response(response, 200) + print("Seld.mockedFuturs", self.mocked_futures) + res, partial = createUpdateResponseBody(self.mocked_futures) + self.assertEqual(partial, False) + self.assertEqual(res, {"records": mock_response, "errors": []}) + + def testResponseBodyPartialSuccess(self): + success_response = {"skyflow_id": "123", "tokens": {"first_name": "John"}} + mock_success_response = [{"id": "123", "fields": {"first_name": "John"}}] + error_response = {"error": {"http_code": 404, "message": "not found"}} + self.add_mock_response(success_response, 200) + self.add_mock_response(error_response, 404) + res, partial = createUpdateResponseBody(self.mocked_futures) + self.assertTrue(partial) + self.assertEqual(res["records"], mock_success_response) + errors = res["errors"] + + self.assertIsNotNone(errors) + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0]["error"]["code"], + error_response["error"]["http_code"]) + self.assertEqual( + errors[0]["error"]["description"], error_response["error"]["message"]) + + def testResponseNotJson(self): + response = "not a valid json".encode() + self.add_mock_response(response, 200, encode=False) + try: + createUpdateResponseBody(self.mocked_futures) + except SkyflowError as error: + expectedError = SkyflowErrorMessages.RESPONSE_NOT_JSON + self.assertEqual(error.code, 200) + self.assertEqual(error.message, expectedError.value % + response.decode('utf-8'))