diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index 60f6c74..f31a1c2 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -43,6 +43,9 @@ class SkyflowErrorMessages(Enum): INVALID_JSON = "Given %s is invalid JSON" INVALID_RECORDS_TYPE = "Records key has value of type %s, expected list" INVALID_FIELDS_TYPE = "Fields key has value of type %s, expected dict" + INVALID_TOKENS_TYPE = "Tokens key has value of type %s, expected dict" + EMPTY_TOKENS_IN_INSERT = "Tokens is empty in records" + MISMATCH_OF_FIELDS_AND_TOKENS = "Fields and Tokens object are not matching" INVALID_TABLE_TYPE = "Table key has value of type %s, expected string" INVALID_TABLE_TYPE_DELETE = "Table of type string is required at index %s in records array" INVALID_IDS_TYPE = "Ids key has value of type %s, expected list" diff --git a/skyflow/vault/_insert.py b/skyflow/vault/_insert.py index e0150b4..12bc68c 100644 --- a/skyflow/vault/_insert.py +++ b/skyflow/vault/_insert.py @@ -29,23 +29,26 @@ def getInsertRequestBody(data, options): validateUpsertOptions(upsertOptions=upsertOptions) requestPayload = [] - insertTokenPayload = [] for index, record in enumerate(records): tableName, fields = getTableAndFields(record) - postPayload = {"tableName": tableName, "fields": fields,"method": "POST","quorum": True} + postPayload = { + "tableName": tableName, + "fields": fields, + "method": "POST", + "quorum": True, + } + if "tokens" in record: + tokens = getTokens(record) + postPayload["tokens"] = tokens if upsertOptions: postPayload["upsert"] = getUpsertColumn(tableName,upsertOptions) - requestPayload.append(postPayload) if options.tokens: - insertTokenPayload.append({ - "method": "GET", - "tableName": tableName, - "ID": "$responses." + str(index) + ".records.0.skyflow_id", - "tokenization": True - }) - requestBody = {"records": requestPayload + insertTokenPayload} + postPayload['tokenization'] = True + + requestPayload.append(postPayload) + requestBody = {"records": requestPayload } try: jsonBody = json.dumps(requestBody) except Exception as e: @@ -80,6 +83,21 @@ def getTableAndFields(record): return (table, fields) +def getTokens(record): + tokens = record["tokens"] + if not isinstance(tokens, dict): + tokensType = str(type(tokens)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TOKENS_TYPE.value % ( + tokensType), interface=interface) + + if len(tokens) == 0 : + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.EMPTY_TOKENS_IN_INSERT, interface= interface) + + fields = record["fields"] + for tokenKey in tokens: + if tokenKey not in fields: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.MISMATCH_OF_FIELDS_AND_TOKENS, interface= interface) + return tokens def processResponse(response: requests.Response, interface=interface): statusCode = response.status_code @@ -110,11 +128,11 @@ def convertResponse(request: dict, response: dict, tokens: bool): records = request['records'] recordsSize = len(records) result = [] - for idx, _ in enumerate(records): + for idx, _ in enumerate(responseArray): table = records[idx]['table'] skyflow_id = responseArray[idx]['records'][0]['skyflow_id'] if tokens: - fieldsDict = responseArray[recordsSize + idx]['fields'] + fieldsDict = responseArray[idx]['records'][0]['tokens'] fieldsDict['skyflow_id'] = skyflow_id result.append({'table': table, 'fields': fieldsDict}) else: diff --git a/tests/vault/test_insert.py b/tests/vault/test_insert.py index 75b2cff..4f94ec1 100644 --- a/tests/vault/test_insert.py +++ b/tests/vault/test_insert.py @@ -22,23 +22,39 @@ def setUp(self) -> None: "fields": { "cardNumber": "4111-1111-1111-1111", "cvv": "234" + }, + "tokens":{ + "cardNumber": "4111-1111-1111-1111", } } self.data = {"records": [record]} self.mockRequest = {"records": [record]} + record2 = { + "table": "pii_fields", + "fields": { + "cardNumber": "4111-1111-1111-1111", + "cvv": "234" + } + } + self.data2 = {"records": [record2]} + self.mockRequest2 = {"records": [record2]} self.mockResponse = {"responses": [ - { - "records": [{"skyflow_id": 123}], - "table": "pii_fields" - }, - { - "fields": { - "cardNumber": "card_number_token", - "cvv": "cvv_token" + { + "records": [ + { + "skyflow_id": 123, + "tokens": { + "first_name": "4db12c22-758e-4fc9-b41d-e8e48b876776", + "cardNumber": "card_number_token", + "cvv": "cvv_token", + "expiry_date": "6b45daa3-0e81-42a8-a911-23929f1cf9da" + + } } - } - ]} + ] + } + ]} self.insertOptions = InsertOptions(tokens=True) @@ -55,11 +71,29 @@ def testGetInsertRequestBodyWithValidBody(self): "cardNumber": "4111-1111-1111-1111", "cvv": "234" }, + "tokens":{ + "cardNumber": "4111-1111-1111-1111", + }, "method": 'POST', - "quorum": True + "quorum": True, + "tokenization": True } self.assertEqual(body["records"][0], expectedOutput) + def testGetInsertRequestBodyWithValidBodyWithoutTokens(self): + body = json.loads(getInsertRequestBody(self.data2, self.insertOptions)) + expectedOutput = { + "tableName": "pii_fields", + "fields": { + "cardNumber": "4111-1111-1111-1111", + "cvv": "234" + }, + "method": 'POST', + "quorum": True, + "tokenization": True + } + self.assertEqual(body["records"][0], expectedOutput) + def testGetInsertRequestBodyWithValidUpsertOptions(self): body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')]))) expectedOutput = { @@ -67,9 +101,28 @@ def testGetInsertRequestBodyWithValidUpsertOptions(self): "fields": { "cardNumber": "4111-1111-1111-1111", "cvv": "234" + }, + "tokens": { + "cardNumber": "4111-1111-1111-1111", }, "method": 'POST', "quorum": True, + "tokenization": True, + "upsert": 'column1', + } + self.assertEqual(body["records"][0], expectedOutput) + + def testGetInsertRequestBodyWithValidUpsertOptionsWithOutTokens(self): + body = json.loads(getInsertRequestBody(self.data2, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')]))) + expectedOutput = { + "tableName": "pii_fields", + "fields": { + "cardNumber": "4111-1111-1111-1111", + "cvv": "234" + }, + "method": 'POST', + "quorum": True, + "tokenization": True, "upsert": 'column1', } self.assertEqual(body["records"][0], expectedOutput) @@ -128,6 +181,97 @@ def testGetInsertRequestBodyInvalidFieldsType(self): self.assertEqual( e.message, SkyflowErrorMessages.INVALID_FIELDS_TYPE.value % (str(type('str')))) + def testInvalidTokensInRecord(self): + invalidData = {"records": [{ + "table": "table", + "fields": { + "card_number": "4111-1111" + }, + "tokens": "tokens" + } + ]} + try: + getInsertRequestBody(invalidData, self.insertOptions) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_TOKENS_TYPE.value % (str(type('str')))) + + def testEmptyTokensInRecord(self): + invalidData = {"records": [{ + "table": "table", + "fields": { + "card_number": "4111-1111" + }, + "tokens": { + } + } + ]} + try: + getInsertRequestBody(invalidData, self.insertOptions) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.EMPTY_TOKENS_IN_INSERT.value) + + def testMismatchTokensInRecord(self): + invalidData = {"records": [{ + "table": "table", + "fields": { + "card_number": "4111-1111" + }, + "tokens": { + "cvv": "123" + } + } + ]} + try: + getInsertRequestBody(invalidData, self.insertOptions) + self.fail('Should have thrown an error') + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.MISMATCH_OF_FIELDS_AND_TOKENS.value) + + # def testTokensInRecord(self): + # invalidData = {"records": [{ + # "table": "table", + # "fields": { + # "card_number": "4111-1111" + # }, + # "tokens": { + # "cvv": "123" + # } + # } + # ]} + # try: + # getInsertRequestBody(invalidData, self.insertOptions) + # self.fail('Should have thrown an error') + # except SkyflowError as e: + # self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + # self.assertEqual( + # e.message, SkyflowErrorMessages.MISMATCH_OF_FIELDS_AND_TOKENS.value) + + def testGetInsertRequestBodyWithTokensValidBody(self): + body = json.loads(getInsertRequestBody(self.data, self.insertOptions)) + expectedOutput = { + "tableName": "pii_fields", + "fields": { + "cardNumber": "4111-1111-1111-1111", + "cvv": "234" + }, + "tokens": { + "cardNumber": "4111-1111-1111-1111", + + }, + "method": 'POST', + "quorum": True, + "tokenization": True + } + self.assertEqual(body["records"][0], expectedOutput) + def testGetInsertRequestBodyNoTable(self): invalidData = {"records": [{ "noTable": "tableshouldbehere", @@ -239,11 +383,10 @@ def testProcessResponseFail(self): def testConvertResponseNoTokens(self): tokens = False result = convertResponse(self.mockRequest, self.mockResponse, tokens) - self.assertEqual(len(result["records"]), 1) self.assertEqual(result["records"][0]["skyflow_id"], 123) self.assertEqual(result["records"][0]["table"], "pii_fields") - self.assertNotIn("fields", result["records"][0]) + self.assertNotIn("tokens", result["records"][0]) def testConvertResponseWithTokens(self): tokens = True