Skip to content

Commit

Permalink
Merge pull request #105 from skyflowapi/release/23.9.1.2
Browse files Browse the repository at this point in the history
SK-1039/Release/23.9.1.2
  • Loading branch information
skyflow-vivek authored Sep 29, 2023
2 parents e1a1203 + 3cf077b commit 2ef2f03
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 16 deletions.
5 changes: 5 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class SkyflowErrorMessages(Enum):
BATCH_INSERT_PARTIAL_SUCCESS = "Insert Operation is partially successful"
BATCH_INSERT_FAILURE = "Insert Operation is unsuccessful"

INVALID_BYOT_TYPE = "byot option has value of type %s, expected Skyflow.BYOT"
NO_TOKENS_IN_INSERT = "Tokens are not passed in records for byot as %s"
TOKENS_PASSED_FOR_BYOT_DISABLE = "Pass byot parameter with ENABLE for token insertion"
INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT = "For byot as ENABLE_STRICT, tokens should be passed for all fields"

class SkyflowError(Exception):
def __init__(self, code, message="An Error occured", data={}, interface: str = 'Unknown') -> None:
if type(code) is SkyflowErrorCodes:
Expand Down
1 change: 0 additions & 1 deletion skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def insert(self, records: dict, options: InsertOptions = InsertOptions()):
response = requests.post(requestURL, data=jsonBody, headers=headers)
processedResponse = processResponse(response)
result, partial = convertResponse(records, processedResponse, options)
# these statements will be covered in Integration Tests
if partial:
log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface)
elif 'records' not in result:
Expand Down
8 changes: 7 additions & 1 deletion skyflow/vault/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@ def __init__(self, vaultID: str=None, vaultURL: str=None, tokenProvider: Functio
self.vaultURL = vaultURL or ""
self.tokenProvider = tokenProvider

class BYOT(Enum):
DISABLE = "DISABLE"
ENABLE = "ENABLE"
ENABLE_STRICT = "ENABLE_STRICT"

class UpsertOption:
def __init__(self,table: str,column: str):
self.table = table
self.column = column

class InsertOptions:
def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None):
def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None, byot:BYOT=BYOT.DISABLE):
self.tokens = tokens
self.upsert = upsert
self.continueOnError = continueOnError
self.byot = byot

class UpdateOptions:
def __init__(self, tokens: bool=True):
Expand Down
25 changes: 22 additions & 3 deletions skyflow/vault/_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from requests.models import HTTPError
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import InterfaceName
from skyflow.vault._config import InsertOptions
from skyflow.vault._config import BYOT, InsertOptions

interface = InterfaceName.INSERT.value

Expand Down Expand Up @@ -38,7 +38,8 @@ def getInsertRequestBody(data, options: InsertOptions):
"method": "POST",
"quorum": True,
}
if "tokens" in record:
validateTokensAndByotMode(record, options.byot)
if "tokens" in record:
tokens = getTokens(record)
postPayload["tokens"] = tokens

Expand All @@ -51,7 +52,8 @@ def getInsertRequestBody(data, options: InsertOptions):
requestPayload.append(postPayload)
requestBody = {
"records": requestPayload,
"continueOnError": options.continueOnError
"continueOnError": options.continueOnError,
"byot": options.byot.value
}
if options.continueOnError == None:
requestBody.pop('continueOnError')
Expand Down Expand Up @@ -89,6 +91,23 @@ def getTableAndFields(record):

return (table, fields)

def validateTokensAndByotMode(record, byot:BYOT):

if not isinstance(byot, BYOT):
byotType = str(type(byot))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_BYOT_TYPE.value % (byotType), interface=interface)

if byot == BYOT.DISABLE:
if "tokens" in record:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKENS_PASSED_FOR_BYOT_DISABLE, interface=interface)
elif "tokens" not in record:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % byot.value, interface=interface)
elif byot == BYOT.ENABLE_STRICT:
tokens = record["tokens"]
fields = record["fields"]
if len(tokens) != len(fields):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT, interface=interface)

def getTokens(record):
tokens = record["tokens"]
if not isinstance(tokens, dict):
Expand Down
53 changes: 42 additions & 11 deletions tests/vault/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.service_account import generate_bearer_token
from skyflow.vault._client import Client
from skyflow.vault._config import Configuration, InsertOptions, UpsertOption
from skyflow.vault._config import Configuration, InsertOptions, UpsertOption, BYOT


class TestInsert(unittest.TestCase):
Expand Down Expand Up @@ -92,14 +92,15 @@ def setUp(self) -> None:
}

self.insertOptions = InsertOptions(tokens=True)
self.insertOptions2 = InsertOptions(tokens=True, byot=BYOT.ENABLE)

return super().setUp()

def getDataPath(self, file):
return self.dataPath + file + '.json'

def testGetInsertRequestBodyWithValidBody(self):
body = json.loads(getInsertRequestBody(self.data, self.insertOptions))
body = json.loads(getInsertRequestBody(self.data, self.insertOptions2))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand Down Expand Up @@ -130,7 +131,7 @@ def testGetInsertRequestBodyWithValidBodyWithoutTokens(self):
self.assertEqual(body["records"][0], expectedOutput)

def testGetInsertRequestBodyWithValidUpsertOptions(self):
body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')])))
body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')], byot=BYOT.ENABLE)))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand Down Expand Up @@ -226,7 +227,7 @@ def testInvalidTokensInRecord(self):
}
]}
try:
getInsertRequestBody(invalidData, self.insertOptions)
getInsertRequestBody(invalidData, self.insertOptions2)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -244,7 +245,7 @@ def testEmptyTokensInRecord(self):
}
]}
try:
getInsertRequestBody(invalidData, self.insertOptions)
getInsertRequestBody(invalidData, self.insertOptions2)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -263,7 +264,7 @@ def testMismatchTokensInRecord(self):
}
]}
try:
getInsertRequestBody(invalidData, self.insertOptions)
getInsertRequestBody(invalidData, self.insertOptions2)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -290,7 +291,7 @@ def testMismatchTokensInRecord(self):
# e.message, SkyflowErrorMessages.MISMATCH_OF_FIELDS_AND_TOKENS.value)

def testGetInsertRequestBodyWithTokensValidBody(self):
body = json.loads(getInsertRequestBody(self.data, self.insertOptions))
body = json.loads(getInsertRequestBody(self.data, self.insertOptions2))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand Down Expand Up @@ -345,7 +346,7 @@ def testGetInsertRequestBodyInvalidTableType(self):

def testGetInsertRequestBodyWithContinueOnErrorAsTrue(self):
try:
options = InsertOptions(tokens=True, continueOnError=True)
options = InsertOptions(tokens=True, continueOnError=True, byot=BYOT.ENABLE)
request = getInsertRequestBody(self.data, options)
self.assertIn('continueOnError', request)
request = json.loads(request)
Expand All @@ -355,7 +356,7 @@ def testGetInsertRequestBodyWithContinueOnErrorAsTrue(self):

def testGetInsertRequestBodyWithContinueOnErrorAsFalse(self):
try:
options = InsertOptions(tokens=True, continueOnError=False)
options = InsertOptions(tokens=True, continueOnError=False, byot=BYOT.ENABLE)
request = getInsertRequestBody(self.data, options)
# assert 'continueOnError' in request
self.assertIn('continueOnError', request)
Expand All @@ -366,7 +367,7 @@ def testGetInsertRequestBodyWithContinueOnErrorAsFalse(self):

def testGetInsertRequestBodyWithoutContinueOnError(self):
try:
request = getInsertRequestBody(self.data, self.insertOptions)
request = getInsertRequestBody(self.data, self.insertOptions2)
# assert 'continueOnError' not in request
self.assertNotIn('continueOnError', request)
except SkyflowError as e:
Expand Down Expand Up @@ -615,4 +616,34 @@ def testValidUpsertOptions(self):
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN.value % 0)


def testTokensPassedWithByotModeDisable(self):
try:
options = InsertOptions(byot=BYOT.DISABLE)
getInsertRequestBody(self.data, options)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.TOKENS_PASSED_FOR_BYOT_DISABLE.value)

def testTokensNotPassedWithByotModeEnable(self):
try:
getInsertRequestBody(self.data2, self.insertOptions2)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % "ENABLE")

def testTokensNotPassedWithByotModeEnableStrict(self):
try:
options = InsertOptions(byot=BYOT.ENABLE_STRICT)
getInsertRequestBody(self.data2, options)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % "ENABLE_STRICT")

def testTokensPassedWithByotModeEnableStrict(self):
try:
options = InsertOptions(byot=BYOT.ENABLE_STRICT)
getInsertRequestBody(self.data, options)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value)

0 comments on commit 2ef2f03

Please sign in to comment.