Skip to content

Commit

Permalink
Merge pull request #66 from skyflowapi/SK-47/upsert-support
Browse files Browse the repository at this point in the history
SK-47 add upsert support in insert method options.
  • Loading branch information
Shaik-Luqmaan authored Dec 7, 2022
2 parents 0041e66 + 8a8e4c9 commit c567a8e
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 23 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

All notable changes to this project will be documented in this file.

## [1.16.2] - 2022-06-28
## [1.7.0] - 2022-12-06
### Added
- `upsert` support for insert method.

## [1.6.2] - 2022-06-28

### Added
- Copyright header to all files
Expand Down
39 changes: 39 additions & 0 deletions samples/insert_upsert_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
'''
Copyright (c) 2022 Skyflow, Inc.
'''
from skyflow.errors import SkyflowError
from skyflow.service_account import generate_bearer_token, is_expired
from skyflow.vault import Client, InsertOptions, Configuration, UpsertOption

# cache token for reuse
bearerToken = ''


def token_provider():
    """Return the cached bearer token, generating a new one when needed.

    The token is kept in the module-level ``bearerToken`` so repeated calls
    reuse it. A new token is requested from ``generate_bearer_token`` only
    when ``is_expired`` returns a truthy value for the cached one.
    """
    global bearerToken
    if not is_expired(bearerToken):
        return bearerToken
    # generate_bearer_token's result unpacks into two values; only the first
    # (the token itself) is needed here.
    bearerToken, _ = generate_bearer_token('<YOUR_CREDENTIALS_FILE_PATH>')
    return bearerToken


try:
    # Replace every '<...>' placeholder with real values before running.
    client = Client(Configuration(
        '<YOUR_VAULT_ID>', '<YOUR_VAULT_URL>', token_provider))

    # Each UpsertOption names a table and the unique column to upsert on.
    options = InsertOptions(
        tokens=True,
        upsert=[
            UpsertOption(table='<TABLE_NAME>', column='<UNIQUE_COLUMN_NAME>'),
        ],
    )

    record = {
        'table': '<TABLE_NAME>',
        'fields': {
            '<FIELDNAME>': '<VALUE>'
        },
    }
    response = client.insert({'records': [record]}, options=options)
    print('Response:', response)
except SkyflowError as e:
    print('Error Occurred:', e)
7 changes: 6 additions & 1 deletion skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ class SkyflowErrorMessages(Enum):
RESPONSE_NOT_JSON = "Response %s is not valid JSON"

TOKEN_PROVIDER_INVALID_TOKEN = "Invalid token from tokenProvider"

INVALID_UPSERT_OPTIONS_TYPE = "upsertOptions key has value of type %s, expected list"
EMPTY_UPSERT_OPTIONS_LIST = "upsert option cannot be an empty array, atleast one object of table and column is required"
INVALID_UPSERT_TABLE_TYPE = "upsert object table key has value of type %s, expected string"
INVALID_UPSERT_COLUMN_TYPE = "upsert object column key has value of type %s, expected string"
EMPTY_UPSERT_OPTION_TABLE = "upsert object table value is empty string at index %s, expected non-empty string"
EMPTY_UPSERT_OPTION_COLUMN = "upsert object column value is empty string at index %s, expected non-empty string"

class SkyflowError(Exception):
def __init__(self, code, message="An Error occured", data={}, interface: str = 'Unknown') -> None:
Expand Down
2 changes: 1 addition & 1 deletion skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def insert(self, records: dict, options: InsertOptions = InsertOptions()):

self._checkConfig(interface)

jsonBody = getInsertRequestBody(records, options.tokens)
jsonBody = getInsertRequestBody(records, options)
requestURL = self._get_complete_vault_url()
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
Expand Down
10 changes: 8 additions & 2 deletions skyflow/vault/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
'''
from enum import Enum
from types import FunctionType
from typing import OrderedDict
from typing import List


class Configuration:
Expand All @@ -24,9 +24,15 @@ def __init__(self, vaultID: str=None, vaultURL: str=None, tokenProvider: Functio
self.vaultURL = vaultURL or ""
self.tokenProvider = tokenProvider

class UpsertOption:
    """Holds one upsert entry: a table name and a column name.

    The values are stored as given; no validation happens in the constructor.
    """

    def __init__(self, table: str, column: str):
        self.table, self.column = table, column

class InsertOptions:
def __init__(self, tokens: bool=True):
def __init__(self, tokens: bool=True,upsert :List[UpsertOption]=None):
self.tokens = tokens
self.upsert = upsert

class RequestMethod(Enum):
GET = 'GET'
Expand Down
51 changes: 43 additions & 8 deletions skyflow/vault/_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
interface = InterfaceName.INSERT.value


def getInsertRequestBody(data, tokens: bool):
def getInsertRequestBody(data, options):
try:
records = data["records"]
except KeyError:
Expand All @@ -22,17 +22,23 @@ def getInsertRequestBody(data, tokens: bool):
recordsType = str(type(records))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % (
recordsType), interface=interface)


upsertOptions = options.upsert

if upsertOptions:
validateUpsertOptions(upsertOptions=upsertOptions)

requestPayload = []
insertTokenPayload = []
for index, record in enumerate(records):
tableName, fields = getTableAndFields(record)
requestPayload.append({
"tableName": tableName,
"fields": fields,
"method": "POST",
"quorum": True})
if tokens:
postPayload = {"tableName": tableName, "fields": fields,"method": "POST","quorum": True}

if upsertOptions:
postPayload["upsert"] = getUpsertColumn(tableName,upsertOptions)

requestPayload.append(postPayload)
if options.tokens:
insertTokenPayload.append({
"method": "GET",
"tableName": tableName,
Expand Down Expand Up @@ -114,3 +120,32 @@ def convertResponse(request: dict, response: dict, tokens: bool):
else:
result.append({'table': table, 'skyflow_id': skyflow_id})
return {'records': result}

def getUpsertColumn(tableName, upsertOptions):
    """Return the upsert column configured for ``tableName``.

    Every entry in ``upsertOptions`` is examined. If several entries name the
    same table, the column of the last one wins. When no entry matches, an
    empty string is returned.
    """
    matchingColumns = [
        option.column for option in upsertOptions if tableName == option.table
    ]
    return matchingColumns[-1] if matchingColumns else ''

def validateUpsertOptions(upsertOptions):
    """Validate the list of upsert options passed with an insert request.

    Checks are applied in this order and stop at the first failure:

    1. ``upsertOptions`` must be a ``list``.
    2. The list must not be empty.
    3. For each entry, in list order: ``table`` must be a non-empty string,
       then ``column`` must be a non-empty string.

    Args:
        upsertOptions: the value to validate; each entry is expected to
            expose ``table`` and ``column`` attributes.

    Raises:
        SkyflowError: with code ``SkyflowErrorCodes.INVALID_INPUT`` and the
            matching ``SkyflowErrorMessages`` text when any check fails.
    """
    if not isinstance(upsertOptions, list):
        raise SkyflowError(
            SkyflowErrorCodes.INVALID_INPUT,
            SkyflowErrorMessages.INVALID_UPSERT_OPTIONS_TYPE.value % str(type(upsertOptions)),
            interface=interface)
    if not upsertOptions:
        raise SkyflowError(
            SkyflowErrorCodes.INVALID_INPUT,
            SkyflowErrorMessages.EMPTY_UPSERT_OPTIONS_LIST.value,
            interface=interface)

    # (attribute, message for a wrong type, message for an empty string);
    # table is checked before column for every entry.
    attributeChecks = (
        ('table',
         SkyflowErrorMessages.INVALID_UPSERT_TABLE_TYPE,
         SkyflowErrorMessages.EMPTY_UPSERT_OPTION_TABLE),
        ('column',
         SkyflowErrorMessages.INVALID_UPSERT_COLUMN_TYPE,
         SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN),
    )
    for index, upsertOption in enumerate(upsertOptions):
        for attribute, typeMessage, emptyMessage in attributeChecks:
            value = getattr(upsertOption, attribute)
            # None is not a str, so this single check also rejects None.
            if not isinstance(value, str):
                # NOTE(review): the *_TYPE messages read "type %s" but have
                # always been formatted with the entry index; kept as-is
                # because existing tests pin this text.
                raise SkyflowError(
                    SkyflowErrorCodes.INVALID_INPUT,
                    typeMessage.value % index,
                    interface=interface)
            if not value:
                raise SkyflowError(
                    SkyflowErrorCodes.INVALID_INPUT,
                    emptyMessage.value % index,
                    interface=interface)
84 changes: 74 additions & 10 deletions tests/vault/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import os
from requests.models import Response
from dotenv import dotenv_values
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse, getUpsertColumn, validateUpsertOptions
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.service_account import generate_bearer_token
from skyflow.vault._client import Client
from skyflow.vault._config import Configuration, InsertOptions
from skyflow.vault._config import Configuration, InsertOptions, UpsertOption


class TestInsert(unittest.TestCase):
Expand Down Expand Up @@ -39,14 +39,16 @@ def setUp(self) -> None:
}
}
]}

self.insertOptions = InsertOptions(tokens=True)

return super().setUp()

def getDataPath(self, file):
return self.dataPath + file + '.json'

def testGetInsertRequestBodyWithValidBody(self):
body = json.loads(getInsertRequestBody(self.data, True))
body = json.loads(getInsertRequestBody(self.data, self.insertOptions))
expectedOutput = {
"tableName": "pii_fields",
"fields": {
Expand All @@ -57,11 +59,25 @@ def testGetInsertRequestBodyWithValidBody(self):
"quorum": True
}
self.assertEqual(body["records"][0], expectedOutput)

def testGetInsertRequestBodyWithValidUpsertOptions(self):
    """An upsert option for the record's table adds ``upsert`` to the record."""
    upsertOptions = [UpsertOption(table='pii_fields', column='column1')]
    requestBody = getInsertRequestBody(
        self.data, InsertOptions(True, upsertOptions))
    firstRecord = json.loads(requestBody)["records"][0]
    self.assertEqual(firstRecord, {
        "tableName": "pii_fields",
        "fields": {
            "cardNumber": "4111-1111-1111-1111",
            "cvv": "234"
        },
        "method": 'POST',
        "quorum": True,
        "upsert": 'column1',
    })

def testGetInsertRequestBodyNoRecords(self):
invalidData = {"invalidKey": self.data["records"]}
try:
getInsertRequestBody(invalidData, True)
getInsertRequestBody(invalidData, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -71,7 +87,7 @@ def testGetInsertRequestBodyNoRecords(self):
def testGetInsertRequestBodyRecordsInvalidType(self):
invalidData = {"records": 'records'}
try:
getInsertRequestBody(invalidData, True)
getInsertRequestBody(invalidData, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -91,7 +107,7 @@ def testGetInsertRequestBodyNoFields(self):
}
]}
try:
getInsertRequestBody(invalidData, True)
getInsertRequestBody(invalidData, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -105,7 +121,7 @@ def testGetInsertRequestBodyInvalidFieldsType(self):
}
]}
try:
getInsertRequestBody(invalidData, True)
getInsertRequestBody(invalidData, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -125,7 +141,7 @@ def testGetInsertRequestBodyNoTable(self):
}
]}
try:
getInsertRequestBody(invalidData, True)
getInsertRequestBody(invalidData, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -141,7 +157,7 @@ def testGetInsertRequestBodyInvalidTableType(self):
}
]}
try:
getInsertRequestBody(invalidData, True)
getInsertRequestBody(invalidData, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand All @@ -159,7 +175,7 @@ def testInsertInvalidJson(self):
}

try:
getInsertRequestBody(invalidjson, True)
getInsertRequestBody(invalidjson, self.insertOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
Expand Down Expand Up @@ -254,3 +270,51 @@ def testInsertInvalidToken(self):
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.TOKEN_PROVIDER_INVALID_TOKEN.value)

def testGetUpsertColumn(self):
    """getUpsertColumn returns the configured column, or '' for an unknown table."""
    configured = [
        UpsertOption(table='table1', column='column1'),
        UpsertOption(table='table2', column='column2'),
    ]
    self.assertEqual(
        getUpsertColumn('table1', upsertOptions=configured), 'column1')
    self.assertEqual(
        getUpsertColumn('table3', upsertOptions=configured), '')

def testValidUpsertOptions(self):
    """validateUpsertOptions rejects each kind of invalid input.

    Every case must raise SkyflowError with code INVALID_INPUT and the
    expected message. assertRaises makes a case fail when no error is
    raised, instead of passing silently.
    """
    invalidType = 'upsert_string'
    cases = [
        (invalidType,
         SkyflowErrorMessages.INVALID_UPSERT_OPTIONS_TYPE.value % type(invalidType)),
        ([],
         SkyflowErrorMessages.EMPTY_UPSERT_OPTIONS_LIST.value),
        ([UpsertOption(table=123, column='')],
         SkyflowErrorMessages.INVALID_UPSERT_TABLE_TYPE.value % 0),
        ([UpsertOption(table='', column='')],
         SkyflowErrorMessages.EMPTY_UPSERT_OPTION_TABLE.value % 0),
        ([UpsertOption(table='table1', column=1343)],
         SkyflowErrorMessages.INVALID_UPSERT_COLUMN_TYPE.value % 0),
        ([UpsertOption(table='table2', column='')],
         SkyflowErrorMessages.EMPTY_UPSERT_OPTION_COLUMN.value % 0),
    ]
    for upsertOptions, expectedMessage in cases:
        # subTest reports each failing case separately and keeps going.
        with self.subTest(expectedMessage=expectedMessage):
            with self.assertRaises(SkyflowError) as raised:
                validateUpsertOptions(upsertOptions=upsertOptions)
            self.assertEqual(
                raised.exception.code, SkyflowErrorCodes.INVALID_INPUT.value)
            self.assertEqual(raised.exception.message, expectedMessage)

0 comments on commit c567a8e

Please sign in to comment.