From 3046dc4a66c9fcbee9c11dddc74bf3f498e7a714 Mon Sep 17 00:00:00 2001 From: yaswanth-pula-skyflow Date: Fri, 30 Dec 2022 11:21:39 +0530 Subject: [PATCH 01/15] SK-260 add unique columns support in getByID method. --- skyflow/vault/_get_by_id.py | 61 ++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index fbda2b8..8d951a9 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -12,20 +12,23 @@ def getGetByIdRequestBody(data): - try: + # try: + # ids = data["ids"] + # except KeyError: + # raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + # SkyflowErrorMessages.IDS_KEY_ERROR, interface=interface) + ids = None + if "ids" in data: ids = data["ids"] - except KeyError: - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.IDS_KEY_ERROR, interface=interface) - if not isinstance(ids, list): - idsType = str(type(ids)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface) - for id in ids: - if not isinstance(id, str): - idType = str(type(id)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( - idType), interface=interface) + if not isinstance(ids, list): + idsType = str(type(ids)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface) + for id in ids: + if not isinstance(id, str): + idType = str(type(id)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( + idType), interface=interface) try: table = data["table"] except KeyError: @@ -44,7 +47,21 @@ def getGetByIdRequestBody(data): redactionType = str(type(redaction)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % ( redactionType), interface=interface) - return ids, table, redaction.value + + columnName = None + if "columnName" in data: + columnName = data["columnName"] + if not isinstance(columnName, str): + columnName = str(type(columnName)) + # changes error text + # raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % ( + # tableType), interface=interface) + + columnValues = None + if "columnValues" in data: + columnValues = data["columnValues"] + + return ids, table, redaction.value, columnName, columnValues async def sendGetByIdRequests(data, url, token): @@ -61,14 +78,22 @@ async def sendGetByIdRequests(data, url, token): validatedRecords = [] for record in records: - ids, table, redaction = getGetByIdRequestBody(record) - validatedRecords.append((ids, table, redaction)) + ids, table, redaction, columnName, columnValues = getGetByIdRequestBody(record) + validatedRecords.append((ids, table, redaction,columnName,columnValues)) + print(validatedRecords.__len__) + print(validatedRecords) async with ClientSession() as session: for record in validatedRecords: headers = { "Authorization": "Bearer " + token } - params = {"skyflow_ids": record[0], "redaction": record[2]} + params = {"redaction": record[2]} + if record[0] is not None: + params["skyflow_ids"] = record[0] + if record[3] is not None: + params["column_name"] = record[3] + params["column_values"] = record[4] + print(params) task = asyncio.ensure_future( get(url, headers, params, session, record[1])) tasks.append(task) @@ -93,6 +118,7 @@ def createGetByIdResponseBody(responses): for response in responses: partial = False r = response.result() + print(r) status = r[1] try: jsonRes = json.loads(r[0].decode('utf-8')) @@ -114,6 +140,5 @@ def createGetByIdResponseBody(responses): if len(r) > 3 and r[3] != None: temp["error"]["description"] += ' - Request ID: ' + str(r[3]) result["errors"].append(temp) - result["errors"].append(temp) partial = True return result, partial From 09c660b8b38987c79e1b41ff819afce5b86de71a Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Fri, 30 Dec 2022 17:00:38 +0530 Subject: [PATCH 02/15] SK-262 added sample for get records by unique column values --- samples/get_by_ids_sample.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/samples/get_by_ids_sample.py b/samples/get_by_ids_sample.py index 9eeece8..5834a63 100644 --- a/samples/get_by_ids_sample.py +++ b/samples/get_by_ids_sample.py @@ -27,6 +27,13 @@ def token_provider(): "ids": ["", "", ""], "table": "", "redaction": RedactionType.PLAIN_TEXT + }, + #To get records using unique column name and values. + { + "redaction" : "", + "table": "", + "columnName": "", + "columnValues": "[,]", } ]} From 5723359b58cfff0d33e28b5d0526fe0bcfb2b6e2 Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Fri, 30 Dec 2022 17:06:39 +0530 Subject: [PATCH 03/15] SK-261 added testcases for get records by unique column values --- skyflow/errors/_skyflow_errors.py | 3 ++ skyflow/vault/_get_by_id.py | 18 ++++++++---- tests/vault/test_get_by_id.py | 49 +++++++++++++++++++++++++++++-- 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index d617898..29594a5 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -36,6 +36,7 @@ class SkyflowErrorMessages(Enum): TOKEN_KEY_ERROR = "Token key is missing from payload" IDS_KEY_ERROR = "Ids key is missing from payload" REDACTION_KEY_ERROR = "Redaction key is missing from payload" + UNIQUE_COLUMN_OR_IDS_KEY_ERROR = "Ids or Unique column key is missing from payload" INVALID_JSON = "Given %s is invalid JSON" INVALID_RECORDS_TYPE = "Records key has value of type %s, expected list" @@ -44,6 +45,8 @@ class SkyflowErrorMessages(Enum): INVALID_IDS_TYPE = "Ids key has value of type %s, expected list" INVALID_ID_TYPE = "Id key has value of type %s, expected string" INVALID_REDACTION_TYPE = "Redaction key has value of type %s, expected Skyflow.Redaction" + INVALID_COLUMN_NAME = "Column name has value of type %s, expected string" + INVALID_COLUMN_VALUE = "Column values has value of type %s, expected list" INVALID_REQUEST_BODY = "Given request body is not valid" INVALID_RESPONSE_BODY = "Given response body is not valid" diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index 8d951a9..e943a89 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -52,15 +52,21 @@ def getGetByIdRequestBody(data): if "columnName" in data: columnName = data["columnName"] if not isinstance(columnName, str): - columnName = str(type(columnName)) - # changes error text - # raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % ( - # tableType), interface=interface) + columnNameType = str(type(columnName)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % ( + columnNameType), interface=interface) columnValues = None if "columnValues" in data: - columnValues = data["columnValues"] - + columnValues = data["columnValues"] + if not isinstance(columnValues, list): + columnValuesType= str(type(columnValues)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % ( + columnValuesType), interface=interface) + + if(ids is None and (columnName is None or columnValues is None)): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface= interface) return ids, table, redaction.value, columnName, columnValues diff --git a/tests/vault/test_get_by_id.py b/tests/vault/test_get_by_id.py index 9e87428..13b45fb 100644 --- a/tests/vault/test_get_by_id.py +++ b/tests/vault/test_get_by_id.py @@ -70,14 +70,14 @@ def testGetByIdRecordsInvalidType(self): def testGetByIdNoIds(self): invalidData = {"records": [ - {"invalid": "invalid", "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + {"invalid": "invalid", "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} try: self.client.get_by_id(invalidData) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) self.assertEqual( - e.message, SkyflowErrorMessages.IDS_KEY_ERROR.value) + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) def testGetByIdInvalidIdsType(self): invalidData = {"records": [ @@ -192,3 +192,48 @@ def testCreateResponseBodyInvalidJson(self): expectedError = SkyflowErrorMessages.RESPONSE_NOT_JSON self.assertEqual(error.code, 200) self.assertEqual(error.message, expectedError.value % response) + + def testGetByIdNoColumnName(self): + invalidData = {"records": [ + {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} + try: + self.client.get_by_id(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidColumnName(self): + invalidData = {"records": [ + {"ids": ["123", "456"],"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": ["invalid"]}]} + try: + self.client.get_by_id(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % (list)) + + def testGetByIdNoColumnValues(self): + invalidData = {"records": [ + {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name"}]} + try: + self.client.get_by_id(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidColumnValues(self): + invalidData = {"records": [ + {"ids": ["123", "456"], "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name", "columnValues": "invalid"}]} + try: + self.client.get_by_id(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str)) + \ No newline at end of file From b4be710509b890dfcdd48ff0696ea830351cd5f8 Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Mon, 2 Jan 2023 11:54:20 +0530 Subject: [PATCH 04/15] SK-260 updated get interface --- skyflow/_utils.py | 1 + skyflow/vault/_client.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/skyflow/_utils.py b/skyflow/_utils.py index fa32b69..0ac73c8 100644 --- a/skyflow/_utils.py +++ b/skyflow/_utils.py @@ -75,6 +75,7 @@ class InterfaceName(Enum): INSERT = "client.insert" DETOKENIZE = "client.detokenize" GET_BY_ID = "client.get_by_id" + GET = "client.get" INVOKE_CONNECTION = "client.invoke_connection" GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token" diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 8aeeb8f..85fb1c6 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -78,6 +78,11 @@ def detokenize(self, records): log_info(InfoMessages.DETOKENIZE_SUCCESS.value, interface) return result + def get(self, records): + interface = InterfaceName.GET.value + log_info(InfoMessages.GET_BY_ID_TRIGGERED.value, interface) + self.get_by_id(self, records) + def get_by_id(self, records): interface = InterfaceName.GET_BY_ID.value log_info(InfoMessages.GET_BY_ID_TRIGGERED.value, interface) From a4fd19f7b50a2dd7070a3b0a8d908e0b7be6630a Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Mon, 2 Jan 2023 15:45:22 +0530 Subject: [PATCH 05/15] SK-260 removed print statements --- skyflow/vault/_client.py | 2 +- skyflow/vault/_get_by_id.py | 14 +++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 85fb1c6..20d616f 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -81,7 +81,7 @@ def detokenize(self, records): def get(self, records): interface = InterfaceName.GET.value log_info(InfoMessages.GET_BY_ID_TRIGGERED.value, interface) - self.get_by_id(self, records) + self.get_by_id(records) def get_by_id(self, records): interface = InterfaceName.GET_BY_ID.value diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index e943a89..20e161f 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -86,20 +86,17 @@ async def sendGetByIdRequests(data, url, token): for record in records: ids, table, redaction, columnName, columnValues = getGetByIdRequestBody(record) validatedRecords.append((ids, table, redaction,columnName,columnValues)) - print(validatedRecords.__len__) - print(validatedRecords) async with ClientSession() as session: for record in validatedRecords: headers = { "Authorization": "Bearer " + token } params = {"redaction": record[2]} - if record[0] is not None: - params["skyflow_ids"] = record[0] - if record[3] is not None: - params["column_name"] = record[3] - params["column_values"] = record[4] - print(params) + if ids is not None: + params["skyflow_ids"] = ids + if columnName is not None: + params["column_name"] = columnName + params["column_values"] = columnValues task = asyncio.ensure_future( get(url, headers, params, session, record[1])) tasks.append(task) @@ -124,7 +121,6 @@ def createGetByIdResponseBody(responses): for response in responses: partial = False r = response.result() - print(r) status = r[1] try: jsonRes = json.loads(r[0].decode('utf-8')) From fab08e64e4c60626f8e00e0e082900de871b461a Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Mon, 2 Jan 2023 16:49:39 +0530 Subject: [PATCH 06/15] SK-260 refactored code --- skyflow/vault/_get_by_id.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index 20e161f..2a1e9a4 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -85,13 +85,13 @@ async def sendGetByIdRequests(data, url, token): validatedRecords = [] for record in records: ids, table, redaction, columnName, columnValues = getGetByIdRequestBody(record) - validatedRecords.append((ids, table, redaction,columnName,columnValues)) + validatedRecords.append((ids, table, redaction, columnName, columnValues)) async with ClientSession() as session: for record in validatedRecords: headers = { "Authorization": "Bearer " + token } - params = {"redaction": record[2]} + params = {"redaction": redaction} if ids is not None: params["skyflow_ids"] = ids if columnName is not None: From f754e60eb06dc417e339d334518fce516fc46b16 Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Mon, 2 Jan 2023 17:47:53 +0530 Subject: [PATCH 07/15] SK-260 fixed logic error --- skyflow/vault/_client.py | 2 +- skyflow/vault/_get_by_id.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 20d616f..14aa6b4 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -81,7 +81,7 @@ def detokenize(self, records): def get(self, records): interface = InterfaceName.GET.value log_info(InfoMessages.GET_BY_ID_TRIGGERED.value, interface) - self.get_by_id(records) + return self.get_by_id(records) def get_by_id(self, records): interface = InterfaceName.GET_BY_ID.value diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index 2a1e9a4..d115147 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -57,7 +57,7 @@ def getGetByIdRequestBody(data): columnNameType), interface=interface) columnValues = None - if "columnValues" in data: + if columnName is not None and "columnValues" in data: columnValues = data["columnValues"] if not isinstance(columnValues, list): columnValuesType= str(type(columnValues)) From c949b531f902933b3a1f11a0e2e55569fcd0f72f Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Mon, 2 Jan 2023 18:10:07 +0530 Subject: [PATCH 08/15] SK-260 added testcase for get method --- tests/vault/test_get_by_id.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/vault/test_get_by_id.py b/tests/vault/test_get_by_id.py index 13b45fb..9fdbc7a 100644 --- a/tests/vault/test_get_by_id.py +++ b/tests/vault/test_get_by_id.py @@ -236,4 +236,14 @@ def testGetByIdInvalidColumnValues(self): self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) self.assertEqual( e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str)) - \ No newline at end of file + + def testGet(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "invalid": "invalid", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.TABLE_KEY_ERROR.value) \ No newline at end of file From f27725b1b20af8a705aa3c474ed8985a0549e78f Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Tue, 3 Jan 2023 11:15:38 +0530 Subject: [PATCH 09/15] SK-260 resolved review comments --- samples/get_by_ids_sample.py | 2 +- skyflow/vault/_get_by_id.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/samples/get_by_ids_sample.py b/samples/get_by_ids_sample.py index 5834a63..6c6035c 100644 --- a/samples/get_by_ids_sample.py +++ b/samples/get_by_ids_sample.py @@ -32,7 +32,7 @@ def token_provider(): { "redaction" : "", "table": "", - "columnName": "", + "columnName": "", "columnValues": "[,]", } ]} diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index d115147..e2c89d7 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -12,11 +12,6 @@ def getGetByIdRequestBody(data): - # try: - # ids = data["ids"] - # except KeyError: - # raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - # SkyflowErrorMessages.IDS_KEY_ERROR, interface=interface) ids = None if "ids" in data: ids = data["ids"] From fa511d3624e127cb3ce2b8b5776afa9133998ef5 Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Tue, 3 Jan 2023 17:51:22 +0530 Subject: [PATCH 10/15] SK-263 added update interface --- skyflow/_utils.py | 3 + skyflow/errors/_skyflow_errors.py | 5 +- skyflow/vault/_client.py | 21 +++++- skyflow/vault/_config.py | 4 ++ skyflow/vault/_insert.py | 2 +- skyflow/vault/_update.py | 104 ++++++++++++++++++++++++++++++ 6 files changed, 135 insertions(+), 4 deletions(-) create mode 100644 skyflow/vault/_update.py diff --git a/skyflow/_utils.py b/skyflow/_utils.py index 0ac73c8..50735e7 100644 --- a/skyflow/_utils.py +++ b/skyflow/_utils.py @@ -68,6 +68,8 @@ class InfoMessages(Enum): IS_EXPIRED_TRIGGERED = "is_expired() triggered" EMPTY_ACCESS_TOKEN = "Give access token is empty" INVALID_TOKEN = "Given token is invalid" + UPDATE_TRIGGERED = "Update method triggered" + UPDATE_DATA_SUCCESS = "Data has been updated successfully" class InterfaceName(Enum): @@ -76,6 +78,7 @@ class InterfaceName(Enum): DETOKENIZE = "client.detokenize" GET_BY_ID = "client.get_by_id" GET = "client.get" + UPDATE = "client.update" INVOKE_CONNECTION = "client.invoke_connection" GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token" diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index 29594a5..a2dbbf0 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -34,13 +34,14 @@ class SkyflowErrorMessages(Enum): FIELDS_KEY_ERROR = "Fields key is missing from payload" TABLE_KEY_ERROR = "Table key is missing from payload" TOKEN_KEY_ERROR = "Token key is missing from payload" - IDS_KEY_ERROR = "Ids key is missing from payload" + IDS_KEY_ERROR = "Id(s) key is missing from payload" REDACTION_KEY_ERROR = "Redaction key is missing from payload" UNIQUE_COLUMN_OR_IDS_KEY_ERROR = "Ids or Unique column key is missing from payload" + UPDATE_FIELD_KEY_ERROR = "Atleast one field should be provided to update" INVALID_JSON = "Given %s is invalid JSON" INVALID_RECORDS_TYPE = "Records key has value of type %s, expected list" - INVALID_FIELDS_TYPE = "Fields key has value of type %s, expected string" + INVALID_FIELDS_TYPE = "Fields key has value of type %s, expected dict" INVALID_TABLE_TYPE = "Table key has value of type %s, expected string" INVALID_IDS_TYPE = "Ids key has value of type %s, expected list" INVALID_ID_TYPE = "Id key has value of type %s, expected string" diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 14aa6b4..2c540f2 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -4,8 +4,9 @@ import types import requests from ._insert import getInsertRequestBody, processResponse, convertResponse +from ._update import sendUpdateRequests, createUpdateResponseBody from ._config import Configuration -from ._config import InsertOptions, ConnectionConfig +from ._config import InsertOptions, ConnectionConfig, UpdateOptions from ._connection import createRequest from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody from ._get_by_id import sendGetByIdRequests, createGetByIdResponseBody @@ -135,3 +136,21 @@ def _get_complete_vault_url(self): Get the complete vault url from given vault url and vault id ''' return self.vaultURL + "/v1/vaults/" + self.vaultID + + def update(self, updateInput, options: UpdateOptions = UpdateOptions()): + interface = InterfaceName.UPDATE.value + log_info(InfoMessages.UPDATE_TRIGGERED.value, interface=interface) + + self._checkConfig(interface) + self.storedToken = tokenProviderWrapper( + self.storedToken, self.tokenProvider, interface) + url = self._get_complete_vault_url() + responses = asyncio.run(sendUpdateRequests( + updateInput, options, url, self.storedToken)) + result, partial = createUpdateResponseBody(responses, interface) + if partial: + raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, + SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) + else: + log_info(InfoMessages.UPDATE_DATA_SUCCESS.value, interface) + return result \ No newline at end of file diff --git a/skyflow/vault/_config.py b/skyflow/vault/_config.py index 741a99a..25fe3cd 100644 --- a/skyflow/vault/_config.py +++ b/skyflow/vault/_config.py @@ -34,6 +34,10 @@ def __init__(self, tokens: bool=True,upsert :List[UpsertOption]=None): self.tokens = tokens self.upsert = upsert +class UpdateOptions: + def __init__(self, tokens: bool=True): + self.tokens = tokens + class RequestMethod(Enum): GET = 'GET' POST = 'POST' diff --git a/skyflow/vault/_insert.py b/skyflow/vault/_insert.py index 94660f7..e0150b4 100644 --- a/skyflow/vault/_insert.py +++ b/skyflow/vault/_insert.py @@ -74,7 +74,7 @@ def getTableAndFields(record): SkyflowErrorMessages.FIELDS_KEY_ERROR, interface=interface) if not isinstance(fields, dict): - fieldsType = str(type(table)) + fieldsType = str(type(fields)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_FIELDS_TYPE.value % ( fieldsType), interface=interface) diff --git a/skyflow/vault/_update.py b/skyflow/vault/_update.py new file mode 100644 index 0000000..8d3d7aa --- /dev/null +++ b/skyflow/vault/_update.py @@ -0,0 +1,104 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import json + +import asyncio +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from ._insert import getTableAndFields +from skyflow._utils import InterfaceName +from aiohttp import ClientSession, request + +interface = InterfaceName.UPDATE.value + +async def sendUpdateRequests(data,options,url,token): + tasks = [] + + try: + records = data["records"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.RECORDS_KEY_ERROR, interface=interface) + if not isinstance(records, list): + recordsType = str(type(records)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % ( + recordsType), interface=interface) + + validatedRecords = [] + for record in records: + tableName = validateUpdateRecord(record) + validatedRecords.append(record) + async with ClientSession() as session: + for record in validatedRecords: + recordUrl = url +'/'+ tableName +'/'+ record["id"] + reqBody = { + "record": { + "fields": record["fields"] + }, + "tokenization": options["tokens"] + } + reqBody = json.dumps(reqBody) + headers = { + "Authorization": "Bearer " + token + } + task = asyncio.ensure_future(put(recordUrl, reqBody, headers, session)) + tasks.append(task) + await asyncio.gather(*tasks) + await session.close() + return tasks + +def validateUpdateRecord(record): + try: + id = record["id"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.IDS_KEY_ERROR, interface=interface) + if not isinstance(id, str): + idType = str(type(id)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.INVALID_ID_TYPE.value % (idType), interface=interface) + table, fields = getTableAndFields(record) + keysLength = len(fields.keys()) + if(keysLength < 1): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.UPDATE_FIELD_KEY_ERROR, interface= interface) + return table + +async def put(url, data, headers, session): + async with session.put(url, data=data, headers=headers, ssl=False) as response: + try: + return (await response.read(), response.status, response.headers['x-request-id']) + except KeyError: + return (await response.read(), response.status) + + +def createUpdateResponseBody(responses): + result = { + "records": [], + "errors": [] + } + partial = False + for response in responses: + r = response.result() + status = r[1] + try: + jsonRes = json.loads(r[0].decode('utf-8')) + except: + raise SkyflowError(status, + SkyflowErrorMessages.RESPONSE_NOT_JSON.value % r[0].decode('utf-8'), interface=interface) + + if status == 200: + temp = {} + temp["id"] = jsonRes["skyflow_id"] + if "tokens" in jsonRes: + temp["fields"] = jsonRes["tokens"] + result["records"].append(temp) + else: + temp = {"error": {}} + temp["error"]["code"] = jsonRes["error"]["http_code"] + temp["error"]["description"] = jsonRes["error"]["message"] + if len(r) > 2 and r[2] != None: + temp["error"]["description"] += ' - Request ID: ' + str(r[2]) + result["errors"].append(temp) + partial = True + return result, partial From 6585a543e52fb3c8da79da3cfb90f32cf911eeba Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Wed, 4 Jan 2023 10:53:11 +0530 Subject: [PATCH 11/15] SK-263 fixed logic error --- skyflow/vault/_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 2c540f2..b2c29f7 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -147,7 +147,7 @@ def update(self, updateInput, options: UpdateOptions = UpdateOptions()): url = self._get_complete_vault_url() responses = asyncio.run(sendUpdateRequests( updateInput, options, url, self.storedToken)) - result, partial = createUpdateResponseBody(responses, interface) + result, partial = createUpdateResponseBody(responses) if partial: raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) From bf582ec002089c86fce89b2f58ab2ddc85ada5ba Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Thu, 5 Jan 2023 10:24:02 +0530 Subject: [PATCH 12/15] SK-264 added testcases for update interface --- tests/vault/test_update.py | 184 +++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 tests/vault/test_update.py diff --git a/tests/vault/test_update.py b/tests/vault/test_update.py new file mode 100644 index 0000000..c6a00ef --- /dev/null +++ b/tests/vault/test_update.py @@ -0,0 +1,184 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import json +import unittest +import os +import asyncio +import warnings + +from dotenv import dotenv_values +from skyflow.vault._client import Client, Configuration +from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from skyflow.service_account import generate_bearer_token +from skyflow.vault._client import Client +from skyflow.vault._config import UpdateOptions + + +class TestUpdate(unittest.TestCase): + + def setUp(self) -> None: + self.envValues = dotenv_values(".env") + self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/') + self.mocked_futures = [] + self.event_loop = asyncio.new_event_loop() + + def tokenProvider(): + token, _ = generate_bearer_token( + self.envValues["CREDENTIALS_FILE_PATH"]) + return token + + config = Configuration( + self.envValues["VAULT_ID"], self.envValues["VAULT_URL"], tokenProvider) + self.client = Client(config) + warnings.filterwarnings( + action="ignore", message="unclosed", category=ResourceWarning) + return super().setUp() + + def add_mock_response(self, response, statusCode, encode=True): + future = asyncio.Future(loop=self.event_loop) + if encode: + future.set_result((json.dumps(response).encode(), statusCode)) + else: + future.set_result((response, statusCode)) + future.done() + self.mocked_futures.append(future) + + def getDataPath(self, file): + return self.dataPath + file + '.json' + + def testUpdateNoRecords(self): + invalidData = {} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.RECORDS_KEY_ERROR.value) + + def testUpdateInvalidType(self): + invalidData = {"records": "invalid"} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % (str)) + + def testUpdateNoIds(self): + invalidData = {"records": [ + {"table": "pii_fields"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.IDS_KEY_ERROR.value) + + def testUpdateInvalidIdType(self): + invalidData = {"records": [ + {"id": ["123"], "table": "pii_fields"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_ID_TYPE.value % (list)) + + def testUpdateNoTable(self): + invalidData = {"records": [ + {"id": "id"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.TABLE_KEY_ERROR.value) + + def testUpdateInvalidTableType(self): + invalidData = {"records": [ + {"id": "id1", "table": ["invalid"]}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (list)) + + def testUpdateNoFields(self): + invalidData = {"records": [ + {"id": "id", "table": "pii_fields"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.FIELDS_KEY_ERROR.value) + + def testUpdateInvalidFieldsType(self): + invalidData = {"records": [ + {"id": "id1", "table": "pii_fields", "fields": "invalid"}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_FIELDS_TYPE.value % (str)) + + def testUpdateInvalidFieldsType2(self): + invalidData = {"records": [ + {"id": "id1", "table": "pii_fields", "fields": {}}]} + try: + self.client.update(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UPDATE_FIELD_KEY_ERROR.value) + + def testResponseBodySuccess(self): + response = {"skyflow_id": "123", "tokens": {"first_name": "John"}} + mock_response = [{"id": "123", "fields": {"first_name": "John"}}] + self.add_mock_response(response, 200) + print("Seld.mockedFuturs", self.mocked_futures) + res, partial = createUpdateResponseBody(self.mocked_futures) + self.assertEqual(partial, False) + self.assertEqual(res, {"records": mock_response, "errors": []}) + + def testResponseBodyPartialSuccess(self): + success_response = {"skyflow_id": "123", "tokens": {"first_name": "John"}} + mock_success_response = [{"id": "123", "fields": {"first_name": "John"}}] + error_response = {"error": {"http_code": 404, "message": "not found"}} + self.add_mock_response(success_response, 200) + self.add_mock_response(error_response, 404) + res, partial = createUpdateResponseBody(self.mocked_futures) + self.assertTrue(partial) + self.assertEqual(res["records"], mock_success_response) + errors = res["errors"] + + self.assertIsNotNone(errors) + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0]["error"]["code"], + error_response["error"]["http_code"]) + self.assertEqual( + errors[0]["error"]["description"], error_response["error"]["message"]) + + def testResponseNotJson(self): + response = "not a valid json".encode() + self.add_mock_response(response, 200, encode=False) + try: + createUpdateResponseBody(self.mocked_futures) + except SkyflowError as error: + expectedError = SkyflowErrorMessages.RESPONSE_NOT_JSON + self.assertEqual(error.code, 200) + self.assertEqual(error.message, expectedError.value % + response.decode('utf-8')) From d957a6b85c93fd6e9a54958835cb9d4e67c680fc Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Thu, 5 Jan 2023 10:29:21 +0530 Subject: [PATCH 13/15] SK-265 added samples for update interface --- samples/update_sample.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 samples/update_sample.py diff --git a/samples/update_sample.py b/samples/update_sample.py new file mode 100644 index 0000000..9c6ea90 --- /dev/null +++ b/samples/update_sample.py @@ -0,0 +1,39 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors import SkyflowError +from skyflow.service_account import generate_bearer_token, is_expired +from skyflow.vault import Client, UpdateOptions, Configuration + +# cache token for reuse +bearerToken = '' + +def token_provider(): + global bearerToken + if is_expired(bearerToken): + bearerToken, _ = generate_bearer_token('') + return bearerToken + + +try: + config = Configuration( + '', '', token_provider) + client = Client(config) + + options = UpdateOptions(True) + + data = { + "records": [ + { + "id": "", + "table": "", + "fields": { + "": "" + } + } + ] + } + response = client.update(data, options=options) + print('Response:', response) +except SkyflowError as e: + print('Error Occurred:', e) From 336b02cfab56c1e2622cc92389d139b6db6adacf Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Thu, 5 Jan 2023 21:25:29 +0530 Subject: [PATCH 14/15] SK-263 fixed logic error --- skyflow/vault/_update.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/skyflow/vault/_update.py b/skyflow/vault/_update.py index 8d3d7aa..ade08b2 100644 --- a/skyflow/vault/_update.py +++ b/skyflow/vault/_update.py @@ -7,11 +7,12 @@ from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from ._insert import getTableAndFields from skyflow._utils import InterfaceName -from aiohttp import ClientSession, request +from aiohttp import ClientSession +from ._config import UpdateOptions interface = InterfaceName.UPDATE.value -async def sendUpdateRequests(data,options,url,token): +async def sendUpdateRequests(data,options: UpdateOptions,url,token): tasks = [] try: @@ -35,7 +36,7 @@ async def sendUpdateRequests(data,options,url,token): "record": { "fields": record["fields"] }, - "tokenization": options["tokens"] + "tokenization": options.tokens } reqBody = json.dumps(reqBody) headers = { From a492e631bc4c18a65cdf95a9317b6074c03b7883 Mon Sep 17 00:00:00 2001 From: skyflow-srivyshnavi Date: Fri, 6 Jan 2023 16:12:07 +0530 Subject: [PATCH 15/15] SK-260 reverting get_by_id changes and added new get interface --- samples/get_by_ids_sample.py | 7 -- samples/get_sample.py | 43 ++++++++ skyflow/_utils.py | 2 + skyflow/vault/_client.py | 23 ++++- skyflow/vault/_get.py | 99 ++++++++++++++++++ skyflow/vault/_get_by_id.py | 59 ++++------- tests/vault/test_get.py | 188 ++++++++++++++++++++++++++++++++++ tests/vault/test_get_by_id.py | 70 ++----------- 8 files changed, 376 insertions(+), 115 deletions(-) create mode 100644 samples/get_sample.py create mode 100644 skyflow/vault/_get.py create mode 100644 tests/vault/test_get.py diff --git a/samples/get_by_ids_sample.py b/samples/get_by_ids_sample.py index 6c6035c..9eeece8 100644 --- a/samples/get_by_ids_sample.py +++ b/samples/get_by_ids_sample.py @@ -27,13 +27,6 @@ def token_provider(): "ids": ["", "", ""], "table": "", "redaction": RedactionType.PLAIN_TEXT - }, - #To get records using unique column name and values. - { - "redaction" : "", - "table": "", - "columnName": "", - "columnValues": "[,]", } ]} diff --git a/samples/get_sample.py b/samples/get_sample.py new file mode 100644 index 0000000..867d9aa --- /dev/null +++ b/samples/get_sample.py @@ -0,0 +1,43 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors import SkyflowError +from skyflow.service_account import generate_bearer_token, is_expired +from skyflow.vault import Client, Configuration, RedactionType + + +# cache token for reuse +bearerToken = '' + + +def token_provider(): + global bearerToken + if is_expired(bearerToken): + bearerToken, _ = generate_bearer_token('') + return bearerToken + + +try: + config = Configuration( + '', '', token_provider) + client = Client(config) + + data = {"records": [ + { + "ids": ["", "", ""], + "table": "", + "redaction": RedactionType.PLAIN_TEXT + }, + #To get records using unique column name and values. + { + "redaction" : "", + "table": "", + "columnName": "", + "columnValues": "[,]", + } + ]} + + response = client.get(data) + print('Response:', response) +except SkyflowError as e: + print('Error Occurred:', e) diff --git a/skyflow/_utils.py b/skyflow/_utils.py index 50735e7..5186dbe 100644 --- a/skyflow/_utils.py +++ b/skyflow/_utils.py @@ -70,6 +70,8 @@ class InfoMessages(Enum): INVALID_TOKEN = "Given token is invalid" UPDATE_TRIGGERED = "Update method triggered" UPDATE_DATA_SUCCESS = "Data has been updated successfully" + GET_TRIGGERED = "Get triggered." + GET_SUCCESS = "Data fetched successfully." class InterfaceName(Enum): diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index b2c29f7..ff7aaa4 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -9,7 +9,8 @@ from ._config import InsertOptions, ConnectionConfig, UpdateOptions from ._connection import createRequest from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody -from ._get_by_id import sendGetByIdRequests, createGetByIdResponseBody +from ._get_by_id import sendGetByIdRequests, createGetResponseBody +from ._get import sendGetRequests import asyncio from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow._utils import log_info, InfoMessages, InterfaceName @@ -81,8 +82,22 @@ def detokenize(self, records): def get(self, records): interface = InterfaceName.GET.value - log_info(InfoMessages.GET_BY_ID_TRIGGERED.value, interface) - return self.get_by_id(records) + log_info(InfoMessages.GET_TRIGGERED.value, interface) + + self._checkConfig(interface) + self.storedToken = tokenProviderWrapper( + self.storedToken, self.tokenProvider, interface) + url = self._get_complete_vault_url() + responses = asyncio.run(sendGetRequests( + records, url, self.storedToken)) + result, partial = createGetResponseBody(responses) + if partial: + raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, + SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) + else: + log_info(InfoMessages.GET_SUCCESS.value, interface) + + return result def get_by_id(self, records): interface = InterfaceName.GET_BY_ID.value @@ -94,7 +109,7 @@ def get_by_id(self, records): url = self._get_complete_vault_url() responses = asyncio.run(sendGetByIdRequests( records, url, self.storedToken)) - result, partial = createGetByIdResponseBody(responses) + result, partial = createGetResponseBody(responses) if partial: raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.PARTIAL_SUCCESS, result, interface=interface) diff --git a/skyflow/vault/_get.py b/skyflow/vault/_get.py new file mode 100644 index 0000000..0670206 --- /dev/null +++ b/skyflow/vault/_get.py @@ -0,0 +1,99 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +import asyncio +from aiohttp import ClientSession +from ._config import RedactionType +from skyflow._utils import InterfaceName +from ._get_by_id import get + +interface = InterfaceName.GET.value + +def getGetRequestBody(data): + ids = None + if "ids" in data: + ids = data["ids"] + if not isinstance(ids, list): + idsType = str(type(ids)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface) + for id in ids: + if not isinstance(id, str): + idType = str(type(id)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( + idType), interface=interface) + try: + table = data["table"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface) + if not isinstance(table, str): + tableType = str(type(table)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % ( + tableType), interface=interface) + try: + redaction = data["redaction"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.REDACTION_KEY_ERROR, interface=interface) + if not isinstance(redaction, RedactionType): + redactionType = str(type(redaction)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % ( + redactionType), interface=interface) + + columnName = None + if "columnName" in data: + columnName = data["columnName"] + if not isinstance(columnName, str): + columnNameType = str(type(columnName)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % ( + columnNameType), interface=interface) + + columnValues = None + if columnName is not None and "columnValues" in data: + columnValues = data["columnValues"] + if not isinstance(columnValues, list): + columnValuesType= str(type(columnValues)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % ( + columnValuesType), interface=interface) + + if(ids is None and (columnName is None or columnValues is None)): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface= interface) + return ids, table, redaction.value, columnName, columnValues + + +async def sendGetRequests(data, url, token): + tasks = [] + try: + records = data["records"] + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.RECORDS_KEY_ERROR, interface=interface) + if not isinstance(records, list): + recordsType = str(type(records)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % ( + recordsType), interface=interface) + + validatedRecords = [] + for record in records: + ids, table, redaction, columnName, columnValues = getGetRequestBody(record) + validatedRecords.append((ids, table, redaction, columnName, columnValues)) + async with ClientSession() as session: + for record in validatedRecords: + headers = { + "Authorization": "Bearer " + token + } + params = {"redaction": redaction} + if ids is not None: + params["skyflow_ids"] = ids + if columnName is not None: + params["column_name"] = columnName + params["column_values"] = columnValues + task = asyncio.ensure_future( + get(url, headers, params, session, record[1])) + tasks.append(task) + await asyncio.gather(*tasks) + await session.close() + return tasks \ No newline at end of file diff --git a/skyflow/vault/_get_by_id.py b/skyflow/vault/_get_by_id.py index e2c89d7..7cd3adc 100644 --- a/skyflow/vault/_get_by_id.py +++ b/skyflow/vault/_get_by_id.py @@ -12,18 +12,20 @@ def getGetByIdRequestBody(data): - ids = None - if "ids" in data: + try: ids = data["ids"] - if not isinstance(ids, list): - idsType = str(type(ids)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface) - for id in ids: - if not isinstance(id, str): - idType = str(type(id)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( - idType), interface=interface) + except KeyError: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.IDS_KEY_ERROR, interface=interface) + if not isinstance(ids, list): + idsType = str(type(ids)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, + SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface) + for id in ids: + if not isinstance(id, str): + idType = str(type(id)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % ( + idType), interface=interface) try: table = data["table"] except KeyError: @@ -42,27 +44,7 @@ def getGetByIdRequestBody(data): redactionType = str(type(redaction)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % ( redactionType), interface=interface) - - columnName = None - if "columnName" in data: - columnName = data["columnName"] - if not isinstance(columnName, str): - columnNameType = str(type(columnName)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % ( - columnNameType), interface=interface) - - columnValues = None - if columnName is not None and "columnValues" in data: - columnValues = data["columnValues"] - if not isinstance(columnValues, list): - columnValuesType= str(type(columnValues)) - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % ( - columnValuesType), interface=interface) - - if(ids is None and (columnName is None or columnValues is None)): - raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface= interface) - return ids, table, redaction.value, columnName, columnValues + return ids, table, redaction.value async def sendGetByIdRequests(data, url, token): @@ -79,19 +61,14 @@ async def sendGetByIdRequests(data, url, token): validatedRecords = [] for record in records: - ids, table, redaction, columnName, columnValues = getGetByIdRequestBody(record) - validatedRecords.append((ids, table, redaction, columnName, columnValues)) + ids, table, redaction = getGetByIdRequestBody(record) + validatedRecords.append((ids, table, redaction)) async with ClientSession() as session: for record in validatedRecords: headers = { "Authorization": "Bearer " + token } - params = {"redaction": redaction} - if ids is not None: - params["skyflow_ids"] = ids - if columnName is not None: - params["column_name"] = columnName - params["column_values"] = columnValues + params = {"skyflow_ids": record[0], "redaction": record[2]} task = asyncio.ensure_future( get(url, headers, params, session, record[1])) tasks.append(task) @@ -108,7 +85,7 @@ async def get(url, headers, params, session, table): return (await response.read(), response.status, table) -def createGetByIdResponseBody(responses): +def createGetResponseBody(responses): result = { "records": [], "errors": [] diff --git a/tests/vault/test_get.py b/tests/vault/test_get.py new file mode 100644 index 0000000..95171a0 --- /dev/null +++ b/tests/vault/test_get.py @@ -0,0 +1,188 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +import unittest +import os + +from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages +from skyflow.vault import Client, Configuration, RedactionType +from skyflow.service_account import generate_bearer_token +from dotenv import dotenv_values +import warnings +import asyncio +import json + + +class TestGet(unittest.TestCase): + + def setUp(self) -> None: + self.envValues = dotenv_values(".env") + self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/') + self.event_loop = asyncio.new_event_loop() + self.mocked_futures = [] + + def tokenProvider(): + token, type = generate_bearer_token( + self.envValues["CREDENTIALS_FILE_PATH"]) + return token + + config = Configuration( + self.envValues["VAULT_ID"], self.envValues["VAULT_URL"], tokenProvider) + self.client = Client(config) + warnings.filterwarnings( + action="ignore", message="unclosed", category=ResourceWarning) + return super().setUp() + + def add_mock_response(self, response, statusCode, table, encode=True): + future = asyncio.Future(loop=self.event_loop) + if encode: + future.set_result( + (json.dumps(response).encode(), statusCode, table)) + else: + future.set_result((response, statusCode, table)) + future.done() + self.mocked_futures.append(future) + + def getDataPath(self, file): + return self.dataPath + file + '.json' + + def testGetByIdNoRecords(self): + invalidData = {"invalidKey": "invalid"} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.RECORDS_KEY_ERROR.value) + + def testGetByIdRecordsInvalidType(self): + invalidData = {"records": "invalid"} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % (str)) + + def testGetByIdNoIds(self): + invalidData = {"records": [ + {"invalid": "invalid", "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidIdsType(self): + invalidData = {"records": [ + {"ids": "invalid", "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_IDS_TYPE.value % (str)) + + def testGetByIdInvalidIdsType2(self): + invalidData = {"records": [ + {"ids": ["123", 123], "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_ID_TYPE.value % (int)) + + def testGetByIdNoTable(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "invalid": "invalid", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.TABLE_KEY_ERROR.value) + + def testGetByIdInvalidTableType(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "table": ["invalid"], "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (list)) + + def testGetByIdNoRedaction(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "table": "pii_fields", "invalid": "invalid"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.REDACTION_KEY_ERROR.value) + + def testGetByIdInvalidRedactionType(self): + invalidData = {"records": [ + {"ids": ["id1", "id2"], "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (str)) + + def testGetByIdNoColumnName(self): + invalidData = {"records": [ + {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidColumnName(self): + invalidData = {"records": [ + {"ids": ["123", "456"],"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": ["invalid"]}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % (list)) + + def testGetByIdNoColumnValues(self): + invalidData = {"records": [ + {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + + def testGetByIdInvalidColumnValues(self): + invalidData = {"records": [ + {"ids": ["123", "456"], "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name", "columnValues": "invalid"}]} + try: + self.client.get(invalidData) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str)) diff --git a/tests/vault/test_get_by_id.py b/tests/vault/test_get_by_id.py index 9fdbc7a..d967623 100644 --- a/tests/vault/test_get_by_id.py +++ b/tests/vault/test_get_by_id.py @@ -4,10 +4,9 @@ import unittest import os -import aiohttp from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow.vault import Client, Configuration, RedactionType -from skyflow.vault._get_by_id import createGetByIdResponseBody +from skyflow.vault._get_by_id import createGetResponseBody from skyflow.service_account import generate_bearer_token from dotenv import dotenv_values import warnings @@ -70,14 +69,14 @@ def testGetByIdRecordsInvalidType(self): def testGetByIdNoIds(self): invalidData = {"records": [ - {"invalid": "invalid", "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} + {"invalid": "invalid", "table": "pii_fields", "redaction": "PLAIN_TEXT"}]} try: self.client.get_by_id(invalidData) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) self.assertEqual( - e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) + e.message, SkyflowErrorMessages.IDS_KEY_ERROR.value) def testGetByIdInvalidIdsType(self): invalidData = {"records": [ @@ -149,7 +148,7 @@ def testCreateResponseBodySuccess(self): response = {"records": [ {"fields": {"card_number": "4111-1111-1111-1111"}}]} self.add_mock_response(response, 200, "table") - result, partial = createGetByIdResponseBody(self.mocked_futures) + result, partial = createGetResponseBody(self.mocked_futures) self.assertFalse(partial) self.assertEqual(len(result["records"]), 1) @@ -168,7 +167,7 @@ def testCreateResponseBodyPartialSuccess(self): }} self.add_mock_response(failed_response, 404, "ok") - result, partial = createGetByIdResponseBody(self.mocked_futures) + result, partial = createGetResponseBody(self.mocked_futures) self.assertTrue(partial) self.assertEqual(len(result["records"]), 1) @@ -187,63 +186,8 @@ def testCreateResponseBodyInvalidJson(self): self.add_mock_response(response.encode(), 200, 'table', encode=False) try: - createGetByIdResponseBody(self.mocked_futures) + createGetResponseBody(self.mocked_futures) except SkyflowError as error: expectedError = SkyflowErrorMessages.RESPONSE_NOT_JSON self.assertEqual(error.code, 200) - self.assertEqual(error.message, expectedError.value % response) - - def testGetByIdNoColumnName(self): - invalidData = {"records": [ - {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT}]} - try: - self.client.get_by_id(invalidData) - self.fail('Should have thrown an error') - except SkyflowError as e: - self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) - self.assertEqual( - e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) - - def testGetByIdInvalidColumnName(self): - invalidData = {"records": [ - {"ids": ["123", "456"],"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": ["invalid"]}]} - try: - self.client.get_by_id(invalidData) - self.fail('Should have thrown an error') - except SkyflowError as e: - self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) - self.assertEqual( - e.message, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % (list)) - - def testGetByIdNoColumnValues(self): - invalidData = {"records": [ - {"table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name"}]} - try: - self.client.get_by_id(invalidData) - self.fail('Should have thrown an error') - except SkyflowError as e: - self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) - self.assertEqual( - e.message, SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value) - - def testGetByIdInvalidColumnValues(self): - invalidData = {"records": [ - {"ids": ["123", "456"], "table": "pii_fields", "redaction": RedactionType.PLAIN_TEXT, "columnName": "first_name", "columnValues": "invalid"}]} - try: - self.client.get_by_id(invalidData) - self.fail('Should have thrown an error') - except SkyflowError as e: - self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) - self.assertEqual( - e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str)) - - def testGet(self): - invalidData = {"records": [ - {"ids": ["id1", "id2"], "invalid": "invalid", "redaction": "PLAIN_TEXT"}]} - try: - self.client.get(invalidData) - self.fail('Should have thrown an error') - except SkyflowError as e: - self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) - self.assertEqual( - e.message, SkyflowErrorMessages.TABLE_KEY_ERROR.value) \ No newline at end of file + self.assertEqual(error.message, expectedError.value % response) \ No newline at end of file