Skip to content

Commit

Permalink
Merge pull request #112 from skyflowapi/release/23.11.1
Browse files Browse the repository at this point in the history
SK-812 Release/23.11.1
  • Loading branch information
skyflow-bharti authored Oct 30, 2023
2 parents 5605cb3 + 1f229fa commit 98f4380
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 77 deletions.
34 changes: 34 additions & 0 deletions samples/get_with_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
'''
Copyright (c) 2022 Skyflow, Inc.
'''
from skyflow.errors import SkyflowError
from skyflow.service_account import generate_bearer_token, is_expired
from skyflow.vault import Client, Configuration, RedactionType, GetOptions

# cache token for reuse
bearerToken = ''

def token_provider():
global bearerToken
if is_expired(bearerToken):
bearerToken, _ = generate_bearer_token('<YOUR_CREDENTIALS_FILE_PATH>')
return bearerToken


try:
config = Configuration(
'<YOUR_VAULT_ID>', '<YOUR_VAULT_URL>', token_provider)
client = Client(config)
options = GetOptions(False)
data = {"records": [
{
"ids": ["<SKYFLOW_ID1>", "<SKYFLOW_ID2>", "<SKYFLOW_ID3>"],
"table": "<TABLE_NAME>",
"redaction": RedactionType.PLAIN_TEXT
}
]}

response = client.get(data,options=options)
print('Response:', response)
except SkyflowError as e:
print('Error Occurred:', e)
4 changes: 4 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class SkyflowErrorCodes(Enum):
INVALID_INDEX = 404
SERVER_ERROR = 500
PARTIAL_SUCCESS = 500
TOKENS_GET_COLUMN_NOT_SUPPORTED = 400
REDACTION_WITH_TOKENS_NOT_SUPPORTED = 400


class SkyflowErrorMessages(Enum):
Expand Down Expand Up @@ -68,6 +70,8 @@ class SkyflowErrorMessages(Enum):
INVALID_QUERY_PARAM_TYPE = "Query params (key, value) must be of type 'str' given type - (%s, %s)"

INVALID_TOKEN_TYPE = "Token key has value of type %s, expected string"
REDACTION_WITH_TOKENS_NOT_SUPPORTED = "Redaction cannot be used when tokens are true in options"
TOKENS_GET_COLUMN_NOT_SUPPORTED = "Column_name or column_values cannot be used with tokens in options"
PARTIAL_SUCCESS = "Server returned errors, check SkyflowError.data for more"

VAULT_ID_INVALID_TYPE = "Expected Vault ID to be str, got %s"
Expand Down
16 changes: 14 additions & 2 deletions skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@
import json
import types
import requests
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
from skyflow.vault._config import Configuration, GetOptions
from skyflow.vault._config import InsertOptions, ConnectionConfig, UpdateOptions
from skyflow.vault._connection import createRequest
from skyflow.vault._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from skyflow.vault._get_by_id import sendGetByIdRequests, createGetResponseBody
from skyflow.vault._get import sendGetRequests
import asyncio
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, InfoMessages, InterfaceName, getMetrics
from skyflow.vault._token import tokenProviderWrapper

from ._delete import deleteProcessResponse
from ._insert import getInsertRequestBody, processResponse, convertResponse
Expand Down Expand Up @@ -88,7 +100,7 @@ def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptio
log_info(InfoMessages.DETOKENIZE_SUCCESS.value, interface)
return result

def get(self, records):
def get(self, records, options: GetOptions = GetOptions()):
interface = InterfaceName.GET.value
log_info(InfoMessages.GET_TRIGGERED.value, interface)

Expand All @@ -97,7 +109,7 @@ def get(self, records):
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url()
responses = asyncio.run(sendGetRequests(
records, url, self.storedToken))
records, options,url, self.storedToken))
result, partial = createGetResponseBody(responses)
if partial:
raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS,
Expand Down
22 changes: 15 additions & 7 deletions skyflow/vault/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

class Configuration:

def __init__(self, vaultID: str=None, vaultURL: str=None, tokenProvider: FunctionType=None):
def __init__(self, vaultID: str = None, vaultURL: str = None, tokenProvider: FunctionType = None):

self.vaultID = ''
self.vaultURL = ''

if tokenProvider == None and vaultURL == None and isinstance(vaultID, FunctionType):
self.tokenProvider = vaultID
elif tokenProvider == None and vaultID == None and isinstance(vaultURL, FunctionType):
Expand All @@ -30,19 +30,25 @@ class BYOT(Enum):
ENABLE_STRICT = "ENABLE_STRICT"

class UpsertOption:
def __init__(self,table: str,column: str):
def __init__(self, table: str, column: str):
self.table = table
self.column = column


class InsertOptions:
def __init__(self, tokens: bool=True, upsert :List[UpsertOption]=None, continueOnError:bool=None, byot:BYOT=BYOT.DISABLE):
self.tokens = tokens
self.upsert = upsert
self.continueOnError = continueOnError
self.byot = byot


class UpdateOptions:
def __init__(self, tokens: bool=True):
def __init__(self, tokens: bool = True):
self.tokens = tokens

class GetOptions:
def __init__(self, tokens: bool = False):
self.tokens = tokens

class DeleteOptions:
Expand All @@ -64,16 +70,18 @@ class RequestMethod(Enum):
PATCH = 'PATCH'
DELETE = 'DELETE'


class ConnectionConfig:
def __init__(self, connectionURL: str, methodName: RequestMethod,
pathParams: dict={}, queryParams: dict={}, requestHeader: dict={}, requestBody: dict={}):
def __init__(self, connectionURL: str, methodName: RequestMethod,
pathParams: dict = {}, queryParams: dict = {}, requestHeader: dict = {}, requestBody: dict = {}):
self.connectionURL = connectionURL.rstrip("/")
self.methodName = methodName
self.pathParams = pathParams
self.queryParams = queryParams
self.requestHeader = requestHeader
self.requestBody = requestBody


class RedactionType(Enum):
PLAIN_TEXT = "PLAIN_TEXT"
MASKED = "MASKED"
Expand Down
101 changes: 61 additions & 40 deletions skyflow/vault/_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
import asyncio
from aiohttp import ClientSession
from ._config import RedactionType
from skyflow.vault._config import RedactionType, GetOptions
from skyflow._utils import InterfaceName, getMetrics
from ._get_by_id import get
from skyflow.vault._get_by_id import get

interface = InterfaceName.GET.value

def getGetRequestBody(data):

def getGetRequestBody(data, options: GetOptions):
ids = None
if "ids" in data:
ids = data["ids"]
if not isinstance(ids, list):
idsType = str(type(ids))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface)
SkyflowErrorMessages.INVALID_IDS_TYPE.value % (idsType), interface=interface)
for id in ids:
if not isinstance(id, str):
idType = str(type(id))
Expand All @@ -31,55 +32,72 @@ def getGetRequestBody(data):
SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface)
if not isinstance(table, str):
tableType = str(type(table))

raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (
tableType), interface=interface)
try:
redaction = data["redaction"]
except KeyError:

if options.tokens and data.get("redaction"):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.REDACTION_KEY_ERROR, interface=interface)
if not isinstance(redaction, RedactionType):
redactionType = str(type(redaction))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (
redactionType), interface=interface)

columnName = None
if "columnName" in data:
columnName = data["columnName"]
if not isinstance(columnName, str):
columnNameType = str(type(columnName))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % (
columnNameType), interface=interface)

columnValues = None
if columnName is not None and "columnValues" in data:
columnValues = data["columnValues"]
if not isinstance(columnValues, list):
columnValuesType= str(type(columnValues))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (
columnValuesType), interface=interface)

if(ids is None and (columnName is None or columnValues is None)):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface= interface)
return ids, table, redaction.value, columnName, columnValues
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
if options.tokens and (data.get('columnName') or data.get('columnValues')):
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED, interface=interface)

if not options.tokens:
try:
redaction = data["redaction"]
except KeyError:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.REDACTION_KEY_ERROR, interface=interface)
if not isinstance(redaction, RedactionType):
redactionType = str(type(redaction))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (
redactionType), interface=interface)

columnName = None
if "columnName" in data:
columnName = data["columnName"]
if not isinstance(columnName, str):
columnNameType = str(type(columnName))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_NAME.value % (
columnNameType), interface=interface)

async def sendGetRequests(data, url, token):
columnValues = None
if columnName is not None and "columnValues" in data:
columnValues = data["columnValues"]
if not isinstance(columnValues, list):
columnValuesType = str(type(columnValues))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (
columnValuesType), interface=interface)

if (ids is None and (columnName is None or columnValues is None)):
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface=interface)
return ids, table, redaction.value, columnName, columnValues
return ids, table, "DEFAULT", None, None


async def sendGetRequests(data, options: GetOptions, url, token):
tasks = []
try:
records = data["records"]
except KeyError:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.RECORDS_KEY_ERROR, interface=interface)
raise SkyflowError(
SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.RECORDS_KEY_ERROR,
interface=interface
)
if not isinstance(records, list):
recordsType = str(type(records))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % (
recordsType), interface=interface)
raise SkyflowError(
SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.INVALID_RECORDS_TYPE.value % recordsType,
interface=interface
)

validatedRecords = []
for record in records:
ids, table, redaction, columnName, columnValues = getGetRequestBody(record)
ids, table, redaction, columnName, columnValues = getGetRequestBody(record, options)
validatedRecords.append((ids, table, redaction, columnName, columnValues))
async with ClientSession() as session:
for record in validatedRecords:
Expand All @@ -89,14 +107,17 @@ async def sendGetRequests(data, url, token):
"sky-metadata": json.dumps(getMetrics())
}
params = {"redaction": redaction}

if ids is not None:
params["skyflow_ids"] = ids
if columnName is not None:
params["column_name"] = columnName
params["column_values"] = columnValues
task = asyncio.ensure_future(
get(url, headers, params, session, record[1]))
get(url, headers, params, session, record[1], options.tokens)
)
tasks.append(task)
await asyncio.gather(*tasks)
await session.close()

return tasks
29 changes: 28 additions & 1 deletion skyflow/vault/_get_by_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,30 @@
import json
from ._config import RedactionType
from skyflow._utils import InterfaceName, getMetrics
from cryptography.fernet import Fernet

interface = InterfaceName.GET_BY_ID.value

def encrypt_data(data, token):
if token:
key = Fernet.generate_key()
fernet = Fernet(key)
encrypted_data = data.copy()
fields = encrypted_data["records"][0]["fields"]
for record in encrypted_data["records"]:
fields = record["fields"]
for key, value in fields.items():
if isinstance(value, str):
encrypted_value = fernet.encrypt(value.encode()).decode()
fields[key] = encrypted_value

serialized_data = json.dumps(encrypted_data)
encrypted_bytes = serialized_data.encode()

return encrypted_bytes
else:
return data, None


def getGetByIdRequestBody(data):
try:
Expand Down Expand Up @@ -78,9 +99,15 @@ async def sendGetByIdRequests(data, url, token):
return tasks


async def get(url, headers, params, session, table):
async def get(url, headers, params, session, table,token=False):
async with session.get(url + "/" + table, headers=headers, params=params, ssl=False) as response:
try:
response_data = await response.text()

if token:
data = json.loads(response_data)
return (encrypt_data(data,token), response.status, table, response.headers['x-request-id'])

return (await response.read(), response.status, table, response.headers['x-request-id'])
except KeyError:
return (await response.read(), response.status, table)
Expand Down
Loading

0 comments on commit 98f4380

Please sign in to comment.