diff --git a/skyflow/vault/_detokenize.py b/skyflow/vault/_detokenize.py index 8370de0..48f1d1b 100644 --- a/skyflow/vault/_detokenize.py +++ b/skyflow/vault/_detokenize.py @@ -5,6 +5,7 @@ import asyncio from aiohttp import ClientSession, request import json +from ._config import RedactionType from skyflow._utils import InterfaceName, getMetrics interface = InterfaceName.DETOKENIZE.value @@ -15,14 +16,27 @@ def getDetokenizeRequestBody(data): token = data["token"] except KeyError: raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, - SkyflowErrorMessages.TOKEN_KEY_ERROR, interface=interface) + SkyflowErrorMessages.TOKEN_KEY_ERROR, interface=interface) if not isinstance(token, str): tokenType = str(type(token)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TOKEN_TYPE.value % ( tokenType), interface=interface) + + if "redaction" in data: + if not isinstance(data["redaction"], RedactionType): + redactionType = str(type(data["redaction"])) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % ( + redactionType), interface=interface) + else: + redactionType = data["redaction"] + else: + redactionType = RedactionType.PLAIN_TEXT + requestBody = {"detokenizationParameters": []} requestBody["detokenizationParameters"].append({ - "token": token}) + "token": token, + "redaction": redactionType.value + }) return requestBody @@ -39,7 +53,6 @@ async def sendDetokenizeRequests(data, url, token): recordsType = str(type(records)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % ( recordsType), interface=interface) - validatedRecords = [] for record in records: requestBody = getDetokenizeRequestBody(record) diff --git a/tests/vault/test_detokenize.py b/tests/vault/test_detokenize.py index 609d954..8edbfb3 100644 --- a/tests/vault/test_detokenize.py +++ b/tests/vault/test_detokenize.py @@ -7,6 +7,7 @@ from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow.vault._client import Client, Configuration from skyflow.service_account import generate_bearer_token +from skyflow.vault._config import RedactionType from dotenv import dotenv_values import warnings @@ -55,7 +56,8 @@ def testGetDetokenizeRequestBodyWithValidBody(self): body = getDetokenizeRequestBody(self.tokenField) expectedOutput = { "detokenizationParameters": [{ - "token": self.testToken + "token": self.testToken, + "redaction": "PLAIN_TEXT" }] } @@ -101,6 +103,15 @@ def testDetokenizeTokenInvalidType(self): self.assertEqual( e.message, SkyflowErrorMessages.INVALID_TOKEN_TYPE.value % (list)) + def testDetokenizeRedactionInvalidType(self): + invalidData = {"records": [{"token": "valid", "redaction": 'demo'}]} + try: + self.client.detokenize(invalidData) + except SkyflowError as error: + self.assertTrue(error) + self.assertEqual(error.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual(error.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % str(type("demo"))) + def testResponseBodySuccess(self): response = {"records": [{"token": "abc", "value": "secret"}]} self.add_mock_response(response, 200) @@ -135,3 +146,39 @@ def testResponseNotJson(self): self.assertEqual(error.code, 200) self.assertEqual(error.message, expectedError.value % response.decode('utf-8')) + + def testRequestBodyNoRedactionKey(self): + expectedOutput = { + "detokenizationParameters": [{ + "token": self.testToken, + "redaction": "PLAIN_TEXT" + }] + } + requestBody = getDetokenizeRequestBody(self.tokenField) + self.assertEqual(requestBody, expectedOutput) + + def testRequestBodyWithValidRedaction(self): + expectedOutput = { + "detokenizationParameters": [{ + "token": self.testToken, + "redaction": "REDACTED" + }] + } + data = { + "token": self.testToken, + "redaction": RedactionType.REDACTED + } + requestBody = getDetokenizeRequestBody(data) + self.assertEqual(expectedOutput, requestBody) + + def testRequestBodyWithInValidRedaction(self): + data = { + "token": self.testToken, + "redaction": "123" + } + try: + getDetokenizeRequestBody(data) + except SkyflowError as error: + self.assertTrue(error) + self.assertEqual(error.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual(error.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % str(type(data["redaction"]))) \ No newline at end of file