Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SK-1039/Release/23.9.1.2 #105

Merged
merged 5 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class SkyflowErrorMessages(Enum):
BATCH_INSERT_PARTIAL_SUCCESS = "Insert Operation is partially successful"
BATCH_INSERT_FAILURE = "Insert Operation is unsuccessful"

INVALID_BYOT_TYPE = "byot option has value of type %s, expected Skyflow.BYOT"
NO_TOKENS_IN_INSERT = "Tokens are not passed in records for byot as %s"
TOKENS_PASSED_FOR_BYOT_DISABLE = "Pass byot parameter with ENABLE for token insertion"
INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT = "For byot as ENABLE_STRICT, tokens should be passed for all fields"

class SkyflowError(Exception):
def __init__(self, code, message="An Error occured", data={}, interface: str = 'Unknown') -> None:
if type(code) is SkyflowErrorCodes:
Expand Down
1 change: 0 additions & 1 deletion skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def insert(self, records: dict, options: InsertOptions = InsertOptions()):
response = requests.post(requestURL, data=jsonBody, headers=headers)
processedResponse = processResponse(response)
result, partial = convertResponse(records, processedResponse, options)
# these statements will be covered in Integration Tests
if partial:
log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface)
elif 'records' not in result:
Expand Down
8 changes: 7 additions & 1 deletion skyflow/vault/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@ def __init__(self, vaultID: str=None, vaultURL: str=None, tokenProvider: Functio
self.vaultURL = vaultURL or ""
self.tokenProvider = tokenProvider

class BYOT(Enum):
DISABLE = "DISABLE"
ENABLE = "ENABLE"
ENABLE_STRICT = "ENABLE_STRICT"

class UpsertOption:
def __init__(self,table: str,column: str):
self.table = table
self.column = column

class InsertOptions:
def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None):
def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None, byot:BYOT=BYOT.DISABLE):
self.tokens = tokens
self.upsert = upsert
self.continueOnError = continueOnError
self.byot = byot

class UpdateOptions:
def __init__(self, tokens: bool=True):
Expand Down
25 changes: 22 additions & 3 deletions skyflow/vault/_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from requests.models import HTTPError
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import InterfaceName
from skyflow.vault._config import InsertOptions
from skyflow.vault._config import BYOT, InsertOptions

interface = InterfaceName.INSERT.value

Expand Down Expand Up @@ -38,7 +38,8 @@
"method": "POST",
"quorum": True,
}
if "tokens" in record:
validateTokensAndByotMode(record, options.byot)
if "tokens" in record:
tokens = getTokens(record)
postPayload["tokens"] = tokens

Expand All @@ -51,7 +52,8 @@
requestPayload.append(postPayload)
requestBody = {
"records": requestPayload,
"continueOnError": options.continueOnError
"continueOnError": options.continueOnError,
"byot": options.byot.value
}
if options.continueOnError == None:
requestBody.pop('continueOnError')
Expand Down Expand Up @@ -89,6 +91,23 @@

return (table, fields)

def validateTokensAndByotMode(record, byot:BYOT):

if not isinstance(byot, BYOT):
byotType = str(type(byot))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_BYOT_TYPE.value % (byotType), interface=interface)

Check warning on line 98 in skyflow/vault/_insert.py

View check run for this annotation

Codecov / codecov/patch

skyflow/vault/_insert.py#L97-L98

Added lines #L97 - L98 were not covered by tests

if byot == BYOT.DISABLE:
if "tokens" in record:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKENS_PASSED_FOR_BYOT_DISABLE, interface=interface)
elif "tokens" not in record:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % byot.value, interface=interface)
elif byot == BYOT.ENABLE_STRICT:
tokens = record["tokens"]
fields = record["fields"]
if len(tokens) != len(fields):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT, interface=interface)

def getTokens(record):
tokens = record["tokens"]
if not isinstance(tokens, dict):
Expand Down
53 changes: 42 additions & 11 deletions tests/vault/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.service_account import generate_bearer_token
from skyflow.vault._client import Client
from skyflow.vault._config import Configuration, InsertOptions, UpsertOption
from skyflow.vault._config import Configuration, InsertOptions, UpsertOption, BYOT


class TestInsert(unittest.TestCase):
Expand Down Expand Up @@ -92,14 +92,15 @@ def setUp(self) -> None:
}

self.insertOptions = InsertOptions(tokens=True)
self.insertOptions2 = InsertOptions(tokens=True, byot=BYOT.ENABLE)

return super().setUp()

def getDataPath(self, file):
return self.dataPath + file + '.json'

def testGetInsertRequestBodyWithValidBody(self):
body = json.loads(getInsertRequestBody(self.data, self.insertOptions))
body = json.loads(getInsertRequestBody(self.data, self.insertOptions2))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand Down Expand Up @@ -130,7 +131,7 @@ def testGetInsertRequestBodyWithValidBodyWithoutTokens(self):
self.assertEqual(body["records"][0], expectedOutput)

def testGetInsertRequestBodyWithValidUpsertOptions(self):
body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')])))
body = json.loads(getInsertRequestBody(self.data, InsertOptions(True,[UpsertOption(table='pii_fields',column='column1')], byot=BYOT.ENABLE)))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand Down Expand Up @@ -226,7 +227,7 @@ def testInvalidTokensInRecord(self):
}
]}
try:
getInsertRequestBody(invalidData, self.insertOptions)
getInsertRequestBody(invalidData, self.insertOptions2)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -244,7 +245,7 @@ def testEmptyTokensInRecord(self):
}
]}
try:
getInsertRequestBody(invalidData, self.insertOptions)
getInsertRequestBody(invalidData, self.insertOptions2)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -263,7 +264,7 @@ def testMismatchTokensInRecord(self):
}
]}
try:
getInsertRequestBody(invalidData, self.insertOptions)
getInsertRequestBody(invalidData, self.insertOptions2)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -290,7 +291,7 @@ def testMismatchTokensInRecord(self):
# e.message, SkyflowErrorMessages.MISMATCH_OF_FIELDS_AND_TOKENS.value)

def testGetInsertRequestBodyWithTokensValidBody(self):
body = json.loads(getInsertRequestBody(self.data, self.insertOptions))
body = json.loads(getInsertRequestBody(self.data, self.insertOptions2))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand Down Expand Up @@ -345,7 +346,7 @@ def testGetInsertRequestBodyInvalidTableType(self):

def testGetInsertRequestBodyWithContinueOnErrorAsTrue(self):
try:
options = InsertOptions(tokens=True, continueOnError=True)
options = InsertOptions(tokens=True, continueOnError=True, byot=BYOT.ENABLE)
request = getInsertRequestBody(self.data, options)
self.assertIn('continueOnError', request)
request = json.loads(request)
Expand All @@ -355,7 +356,7 @@ def testGetInsertRequestBodyWithContinueOnErrorAsTrue(self):

def testGetInsertRequestBodyWithContinueOnErrorAsFalse(self):
try:
options = InsertOptions(tokens=True, continueOnError=False)
options = InsertOptions(tokens=True, continueOnError=False, byot=BYOT.ENABLE)
request = getInsertRequestBody(self.data, options)
# assert 'continueOnError' in request
self.assertIn('continueOnError', request)
Expand All @@ -366,7 +367,7 @@ def testGetInsertRequestBodyWithContinueOnErrorAsFalse(self):

def testGetInsertRequestBodyWithoutContinueOnError(self):
try:
request = getInsertRequestBody(self.data, self.insertOptions)
request = getInsertRequestBody(self.data, self.insertOptions2)
# assert 'continueOnError' not in request
self.assertNotIn('continueOnError', request)
except SkyflowError as e:
Expand Down Expand Up @@ -582,4 +583,34 @@ def testValidUpsertOptions(self):
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN.value % 0)


def testTokensPassedWithByotModeDisable(self):
try:
options = InsertOptions(byot=BYOT.DISABLE)
getInsertRequestBody(self.data, options)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.TOKENS_PASSED_FOR_BYOT_DISABLE.value)

def testTokensNotPassedWithByotModeEnable(self):
try:
getInsertRequestBody(self.data2, self.insertOptions2)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % "ENABLE")

def testTokensNotPassedWithByotModeEnableStrict(self):
try:
options = InsertOptions(byot=BYOT.ENABLE_STRICT)
getInsertRequestBody(self.data2, options)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.NO_TOKENS_IN_INSERT.value % "ENABLE_STRICT")

def testTokensPassedWithByotModeEnableStrict(self):
try:
options = InsertOptions(byot=BYOT.ENABLE_STRICT)
getInsertRequestBody(self.data, options)
self.fail("Should have thrown an error")
except SkyflowError as e:
self.assertEqual(e.message, SkyflowErrorMessages.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value)
Loading