diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index efd4fc4..979c9e6 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -94,6 +94,11 @@ class SkyflowErrorMessages(Enum): BATCH_INSERT_PARTIAL_SUCCESS = "Insert Operation is partially successful" BATCH_INSERT_FAILURE = "Insert Operation is unsuccessful" + INVALID_BYOT_TYPE = "byot option has value of type %s, expected Skyflow.BYOT" + NO_TOKENS_IN_INSERT = "Tokens are not passed in records for byot as %s" + TOKENS_PASSED_FOR_BYOT_DISABLE = "Pass byot parameter with ENABLE for token insertion" + INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT = "For byot as ENABLE_STRICT, tokens should be passed for all fields" + class SkyflowError(Exception): def __init__(self, code, message="An Error occured", data={}, interface: str = 'Unknown') -> None: if type(code) is SkyflowErrorCodes: diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 57a97e8..f54f2ab 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -61,7 +61,6 @@ def insert(self, records: dict, options: InsertOptions = InsertOptions()): response = requests.post(requestURL, data=jsonBody, headers=headers) processedResponse = processResponse(response) result, partial = convertResponse(records, processedResponse, options) - # these statements will be covered in Integration Tests if partial: log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface) elif 'records' not in result: diff --git a/skyflow/vault/_config.py b/skyflow/vault/_config.py index d7bc2f4..c0799b8 100644 --- a/skyflow/vault/_config.py +++ b/skyflow/vault/_config.py @@ -24,16 +24,22 @@ def __init__(self, vaultID: str=None, vaultURL: str=None, tokenProvider: Functio self.vaultURL = vaultURL or "" self.tokenProvider = tokenProvider +class BYOT(Enum): + DISABLE = "DISABLE" + ENABLE = "ENABLE" + ENABLE_STRICT = "ENABLE_STRICT" + class UpsertOption: def __init__(self,table: str,column: str): self.table = table self.column = column class InsertOptions: - def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None): + def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None, byot:BYOT=BYOT.DISABLE): self.tokens = tokens self.upsert = upsert self.continueOnError = continueOnError + self.byot = byot class UpdateOptions: def __init__(self, tokens: bool=True): diff --git a/skyflow/vault/_insert.py b/skyflow/vault/_insert.py index a8557f1..8de342a 100644 --- a/skyflow/vault/_insert.py +++ b/skyflow/vault/_insert.py @@ -7,7 +7,7 @@ from requests.models import HTTPError from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow._utils import InterfaceName -from skyflow.vault._config import InsertOptions +from skyflow.vault._config import BYOT, InsertOptions interface = InterfaceName.INSERT.value @@ -38,7 +38,8 @@ def getInsertRequestBody(data, options: InsertOptions): "method": "POST", "quorum": True, } - if "tokens" in record: + validateTokensAndByotMode(record, options.byot) + if "tokens" in record: tokens = getTokens(record) postPayload["tokens"] = tokens @@ -51,7 +52,8 @@ def getInsertRequestBody(data, options: InsertOptions): requestPayload.append(postPayload) requestBody = { "records": requestPayload, - "continueOnError": options.continueOnError + "continueOnError": options.continueOnError, + "byot": options.byot.value } if options.continueOnError == None: requestBody.pop('continueOnError') @@ -89,6 +91,23 @@ def getTableAndFields(record): return (table, fields) +def validateTokensAndByotMode(record, byot:BYOT): + + if not isinstance(byot, BYOT): + byotType = str(type(byot)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_BYOT_TYPE.value % (byotType), interface=interface) + + if byot == BYOT.DISABLE: + if "tokens" in record: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKENS_PASSED_FOR_BYOT_DISABLE, interface=interface) + elif "tokens" not in record: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % byot.value, interface=interface) + elif byot == BYOT.ENABLE_STRICT: + tokens = record["tokens"] + fields = record["fields"] + if len(tokens) != len(fields): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT, interface=interface) + def getTokens(record): tokens = record["tokens"] if not isinstance(tokens, dict): diff --git a/tests/vault/test_insert.py b/tests/vault/test_insert.py index 189df9f..c39e8e3 100644 --- a/tests/vault/test_insert.py +++ b/tests/vault/test_insert.py @@ -10,7 +10,7 @@ from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow.service_account import generate_bearer_token from skyflow.vault._client import Client -from skyflow.vault._config import Configuration, InsertOptions, UpsertOption +from skyflow.vault._config import Configuration, InsertOptions, UpsertOption, BYOT class TestInsert(unittest.TestCase): @@ -92,6 +92,7 @@ def setUp(self) -> None: } self.insertOptions = InsertOptions(tokens=True) + self.insertOptions2 = InsertOptions(tokens=True, byot=BYOT.ENABLE) return super().setUp() @@ -99,7 +100,7 @@ def getDataPath(self, file): return self.dataPath + file + '.json' def testGetInsertRequestBodyWithValidBody(self): - body = json.loads(getInsertRequestBody(self.data, self.insertOptions)) + body = json.loads(getInsertRequestBody(self.data, self.insertOptions2)) expectedOutput = { "tableName": "pii_fields", "fields": { @@ -130,7 +131,7 @@ def testGetInsertRequestBodyWithValidBodyWithoutTokens(self): self.assertEqual(body["records"][0], expectedOutput) def testGetInsertRequestBodyWithValidUpsertOptions(self): - body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')]))) + body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')], byot=BYOT.ENABLE))) expectedOutput = { "tableName": "pii_fields", "fields": { @@ -226,7 +227,7 @@ def testInvalidTokensInRecord(self): } ]} try: - getInsertRequestBody(invalidData, self.insertOptions) + getInsertRequestBody(invalidData, self.insertOptions2) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -244,7 +245,7 @@ def testEmptyTokensInRecord(self): } ]} try: - getInsertRequestBody(invalidData, self.insertOptions) + getInsertRequestBody(invalidData, self.insertOptions2) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -263,7 +264,7 @@ def testMismatchTokensInRecord(self): } ]} try: - getInsertRequestBody(invalidData, self.insertOptions) + getInsertRequestBody(invalidData, self.insertOptions2) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -290,7 +291,7 @@ def testMismatchTokensInRecord(self): # e.message, SkyflowErrorMessages.MISMATCH_OF_FIELDS_AND_TOKENS.value) def testGetInsertRequestBodyWithTokensValidBody(self): - body = json.loads(getInsertRequestBody(self.data, self.insertOptions)) + body = json.loads(getInsertRequestBody(self.data, self.insertOptions2)) expectedOutput = { "tableName": "pii_fields", "fields": { @@ -345,7 +346,7 @@ def testGetInsertRequestBodyInvalidTableType(self): def testGetInsertRequestBodyWithContinueOnErrorAsTrue(self): try: - options = InsertOptions(tokens=True, continueOnError=True) + options = InsertOptions(tokens=True, continueOnError=True, byot=BYOT.ENABLE) request = getInsertRequestBody(self.data, options) self.assertIn('continueOnError', request) request = json.loads(request) @@ -355,7 +356,7 @@ def testGetInsertRequestBodyWithContinueOnErrorAsTrue(self): def testGetInsertRequestBodyWithContinueOnErrorAsFalse(self): try: - options = InsertOptions(tokens=True, continueOnError=False) + options = InsertOptions(tokens=True, continueOnError=False, byot=BYOT.ENABLE) request = getInsertRequestBody(self.data, options) # assert 'continueOnError' in request self.assertIn('continueOnError', request) @@ -366,7 +367,7 @@ def testGetInsertRequestBodyWithContinueOnErrorAsFalse(self): def testGetInsertRequestBodyWithoutContinueOnError(self): try: - request = getInsertRequestBody(self.data, self.insertOptions) + request = getInsertRequestBody(self.data, self.insertOptions2) # assert 'continueOnError' not in request self.assertNotIn('continueOnError', request) except SkyflowError as e: @@ -615,4 +616,34 @@ def testValidUpsertOptions(self): self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) self.assertEqual( e.message, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN.value % 0) - + + def testTokensPassedWithByotModeDisable(self): + try: + options = InsertOptions(byot=BYOT.DISABLE) + getInsertRequestBody(self.data, options) + self.fail("Should have thrown an error") + except SkyflowError as e: + self.assertEqual(e.message, SkyflowErrorMessages.TOKENS_PASSED_FOR_BYOT_DISABLE.value) + + def testTokensNotPassedWithByotModeEnable(self): + try: + getInsertRequestBody(self.data2, self.insertOptions2) + self.fail("Should have thrown an error") + except SkyflowError as e: + self.assertEqual(e.message, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % "ENABLE") + + def testTokensNotPassedWithByotModeEnableStrict(self): + try: + options = InsertOptions(byot=BYOT.ENABLE_STRICT) + getInsertRequestBody(self.data2, options) + self.fail("Should have thrown an error") + except SkyflowError as e: + self.assertEqual(e.message, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % "ENABLE_STRICT") + + def testTokensPassedWithByotModeEnableStrict(self): + try: + options = InsertOptions(byot=BYOT.ENABLE_STRICT) + getInsertRequestBody(self.data, options) + self.fail("Should have thrown an error") + except SkyflowError as e: + self.assertEqual(e.message, SkyflowErrorMessages.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value)