diff --git a/CHANGELOG.md b/CHANGELOG.md index 328bdef..d00b897 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,11 @@ All notable changes to this project will be documented in this file. -## [1.16.2] - 2022-06-28 +## [1.7.0] - 2022-12-06 +### Added +- `upsert` support for insert method. + +## [1.6.2] - 2022-06-28 ### Added - Copyright header to all files diff --git a/samples/insert_upsert_sample.py b/samples/insert_upsert_sample.py new file mode 100644 index 0000000..aec06c4 --- /dev/null +++ b/samples/insert_upsert_sample.py @@ -0,0 +1,39 @@ +''' + Copyright (c) 2022 Skyflow, Inc. +''' +from skyflow.errors import SkyflowError +from skyflow.service_account import generate_bearer_token, is_expired +from skyflow.vault import Client, InsertOptions, Configuration, UpsertOption + +# cache token for reuse +bearerToken = '' + +def token_provider(): + global bearerToken + if is_expired(bearerToken): + bearerToken, _ = generate_bearer_token('') + return bearerToken + + +try: + config = Configuration( + '', '', token_provider) + client = Client(config) + + upsertOption = UpsertOption(table='',column='') + options = InsertOptions(tokens=True,upsert=[upsertOption]) + + data = { + 'records': [ + { + 'table': '', + 'fields': { + '': '' + } + } + ] + } + response = client.insert(data, options=options) + print('Response:', response) +except SkyflowError as e: + print('Error Occurred:', e) diff --git a/skyflow/errors/_skyflow_errors.py b/skyflow/errors/_skyflow_errors.py index ddbd617..d617898 100644 --- a/skyflow/errors/_skyflow_errors.py +++ b/skyflow/errors/_skyflow_errors.py @@ -65,7 +65,12 @@ class SkyflowErrorMessages(Enum): RESPONSE_NOT_JSON = "Response %s is not valid JSON" TOKEN_PROVIDER_INVALID_TOKEN = "Invalid token from tokenProvider" - + INVALID_UPSERT_OPTIONS_TYPE = "upsertOptions key has value of type %s, expected list" + EMPTY_UPSERT_OPTIONS_LIST = "upsert option cannot be an empty array, atleast one object of table and column is required" + INVALID_UPSERT_TABLE_TYPE = "upsert object table key has value of type %s, expected string" + INVALID_UPSERT_COLUMN_TYPE = "upsert object column key has value of type %s, expected string" + EMPTY_UPSERT_OPTION_TABLE = "upsert object table value is empty string at index %s, expected non-empty string" + EMPTY_UPSERT_OPTION_COLUMN = "upsert object column value is empty string at index %s, expected non-empty string" class SkyflowError(Exception): def __init__(self, code, message="An Error occured", data={}, interface: str = 'Unknown') -> None: diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index bc19547..8aeeb8f 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -45,7 +45,7 @@ def insert(self, records: dict, options: InsertOptions = InsertOptions()): self._checkConfig(interface) - jsonBody = getInsertRequestBody(records, options.tokens) + jsonBody = getInsertRequestBody(records, options) requestURL = self._get_complete_vault_url() self.storedToken = tokenProviderWrapper( self.storedToken, self.tokenProvider, interface) diff --git a/skyflow/vault/_config.py b/skyflow/vault/_config.py index 270aff5..741a99a 100644 --- a/skyflow/vault/_config.py +++ b/skyflow/vault/_config.py @@ -3,7 +3,7 @@ ''' from enum import Enum from types import FunctionType -from typing import OrderedDict +from typing import List class Configuration: @@ -24,9 +24,15 @@ def __init__(self, vaultID: str=None, vaultURL: str=None, tokenProvider: Functio self.vaultURL = vaultURL or "" self.tokenProvider = tokenProvider +class UpsertOption: + def __init__(self,table: str,column: str): + self.table = table + self.column = column + class InsertOptions: - def __init__(self, tokens: bool=True): + def __init__(self, tokens: bool=True,upsert :List[UpsertOption]=None): self.tokens = tokens + self.upsert = upsert class RequestMethod(Enum): GET = 'GET' diff --git a/skyflow/vault/_insert.py b/skyflow/vault/_insert.py index cf4f1d0..94660f7 100644 --- a/skyflow/vault/_insert.py +++ b/skyflow/vault/_insert.py @@ -11,7 +11,7 @@ interface = InterfaceName.INSERT.value -def getInsertRequestBody(data, tokens: bool): +def getInsertRequestBody(data, options): try: records = data["records"] except KeyError: @@ -22,17 +22,23 @@ def getInsertRequestBody(data, tokens: bool): recordsType = str(type(records)) raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % ( recordsType), interface=interface) - + + upsertOptions = options.upsert + + if upsertOptions: + validateUpsertOptions(upsertOptions=upsertOptions) + requestPayload = [] insertTokenPayload = [] for index, record in enumerate(records): tableName, fields = getTableAndFields(record) - requestPayload.append({ - "tableName": tableName, - "fields": fields, - "method": "POST", - "quorum": True}) - if tokens: + postPayload = {"tableName": tableName, "fields": fields,"method": "POST","quorum": True} + + if upsertOptions: + postPayload["upsert"] = getUpsertColumn(tableName,upsertOptions) + + requestPayload.append(postPayload) + if options.tokens: insertTokenPayload.append({ "method": "GET", "tableName": tableName, @@ -114,3 +120,32 @@ def convertResponse(request: dict, response: dict, tokens: bool): else: result.append({'table': table, 'skyflow_id': skyflow_id}) return {'records': result} + +def getUpsertColumn(tableName, upsertOptions): + uniqueColumn:str = '' + for upsertOption in upsertOptions: + if tableName == upsertOption.table: + uniqueColumn = upsertOption.column + return uniqueColumn + +def validateUpsertOptions(upsertOptions): + if not isinstance(upsertOptions,list): + upsertOptionsType = str(type(upsertOptions)) + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_UPSERT_OPTIONS_TYPE.value %( + upsertOptionsType),interface=interface) + if len(upsertOptions) == 0: + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.EMPTY_UPSERT_OPTIONS_LIST.value, interface=interface) + + for index, upsertOption in enumerate(upsertOptions): + if upsertOption.table == None or not isinstance(upsertOption.table,str): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_UPSERT_TABLE_TYPE.value %( + index),interface=interface) + if upsertOption.table == '': + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_TABLE.value %( + index),interface=interface) + if upsertOption.column == None or not isinstance(upsertOption.column,str): + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_UPSERT_COLUMN_TYPE.value %( + index),interface=interface) + if upsertOption.column == '': + raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN.value %( + index),interface=interface) \ No newline at end of file diff --git a/tests/vault/test_insert.py b/tests/vault/test_insert.py index 81e4745..75b2cff 100644 --- a/tests/vault/test_insert.py +++ b/tests/vault/test_insert.py @@ -6,11 +6,11 @@ import os from requests.models import Response from dotenv import dotenv_values -from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse +from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse, getUpsertColumn, validateUpsertOptions from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages from skyflow.service_account import generate_bearer_token from skyflow.vault._client import Client -from skyflow.vault._config import Configuration, InsertOptions +from skyflow.vault._config import Configuration, InsertOptions, UpsertOption class TestInsert(unittest.TestCase): @@ -39,6 +39,8 @@ def setUp(self) -> None: } } ]} + + self.insertOptions = InsertOptions(tokens=True) return super().setUp() @@ -46,7 +48,7 @@ def getDataPath(self, file): return self.dataPath + file + '.json' def testGetInsertRequestBodyWithValidBody(self): - body = json.loads(getInsertRequestBody(self.data, True)) + body = json.loads(getInsertRequestBody(self.data, self.insertOptions)) expectedOutput = { "tableName": "pii_fields", "fields": { @@ -57,11 +59,25 @@ def testGetInsertRequestBodyWithValidBody(self): "quorum": True } self.assertEqual(body["records"][0], expectedOutput) + + def testGetInsertRequestBodyWithValidUpsertOptions(self): + body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')]))) + expectedOutput = { + "tableName": "pii_fields", + "fields": { + "cardNumber": "4111-1111-1111-1111", + "cvv": "234" + }, + "method": 'POST', + "quorum": True, + "upsert": 'column1', + } + self.assertEqual(body["records"][0], expectedOutput) def testGetInsertRequestBodyNoRecords(self): invalidData = {"invalidKey": self.data["records"]} try: - getInsertRequestBody(invalidData, True) + getInsertRequestBody(invalidData, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -71,7 +87,7 @@ def testGetInsertRequestBodyNoRecords(self): def testGetInsertRequestBodyRecordsInvalidType(self): invalidData = {"records": 'records'} try: - getInsertRequestBody(invalidData, True) + getInsertRequestBody(invalidData, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -91,7 +107,7 @@ def testGetInsertRequestBodyNoFields(self): } ]} try: - getInsertRequestBody(invalidData, True) + getInsertRequestBody(invalidData, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -105,7 +121,7 @@ def testGetInsertRequestBodyInvalidFieldsType(self): } ]} try: - getInsertRequestBody(invalidData, True) + getInsertRequestBody(invalidData, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -125,7 +141,7 @@ def testGetInsertRequestBodyNoTable(self): } ]} try: - getInsertRequestBody(invalidData, True) + getInsertRequestBody(invalidData, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -141,7 +157,7 @@ def testGetInsertRequestBodyInvalidTableType(self): } ]} try: - getInsertRequestBody(invalidData, True) + getInsertRequestBody(invalidData, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -159,7 +175,7 @@ def testInsertInvalidJson(self): } try: - getInsertRequestBody(invalidjson, True) + getInsertRequestBody(invalidjson, self.insertOptions) self.fail('Should have thrown an error') except SkyflowError as e: self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) @@ -254,3 +270,51 @@ def testInsertInvalidToken(self): self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) self.assertEqual( e.message, SkyflowErrorMessages.TOKEN_PROVIDER_INVALID_TOKEN.value) + + def testGetUpsertColumn(self): + testUpsertOptions = [UpsertOption(table='table1',column='column1'), + UpsertOption(table='table2',column='column2')] + upsertValid = getUpsertColumn('table1',upsertOptions=testUpsertOptions) + upsertInvalid = getUpsertColumn('table3',upsertOptions=testUpsertOptions) + self.assertEqual(upsertValid,'column1') + self.assertEqual(upsertInvalid,'') + + def testValidUpsertOptions(self): + testUpsertOptions = 'upsert_string' + try: + validateUpsertOptions(testUpsertOptions) + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_UPSERT_OPTIONS_TYPE.value % type(testUpsertOptions) ) + try: + validateUpsertOptions(upsertOptions=[]) + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.EMPTY_UPSERT_OPTIONS_LIST.value) + try: + validateUpsertOptions(upsertOptions=[UpsertOption(table=123,column='')]) + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_UPSERT_TABLE_TYPE.value % 0) + try: + validateUpsertOptions(upsertOptions=[UpsertOption(table='',column='')]) + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_TABLE.value % 0) + try: + validateUpsertOptions(upsertOptions=[UpsertOption(table='table1',column=1343)]) + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.INVALID_UPSERT_COLUMN_TYPE.value % 0) + try: + validateUpsertOptions(upsertOptions=[UpsertOption(table='table2',column='')]) + except SkyflowError as e: + self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value) + self.assertEqual( + e.message, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN.value % 0) +