Skip to content

Commit

Permalink
SK-1648: Fixed client initialization and optimized vault controller
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Oct 9, 2024
1 parent a04ef22 commit e51cd92
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 132 deletions.
22 changes: 8 additions & 14 deletions skyflow/client/skyflow.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from collections import OrderedDict

from skyflow import LogLevel
from skyflow.error import SkyflowError
from skyflow.utils.validations import validate_vault_config, validate_connection_config
from skyflow.vault.client.client import VaultClient
from skyflow.vault.controller import Vault
from skyflow.vault.controller import Connection
from skyflow.vault.manager.vault import VaultManager


class Skyflow:
def __init__(self, builder):
Expand Down Expand Up @@ -109,13 +106,10 @@ def update_vault_config(self, config):

def get_vault_config(self, vault_id):
if vault_id in self.__vault_configs.keys():
vault_config = self.__vault_configs[vault_id]
return vault_config.get("vault_client").get_config()
vault_config = self.__vault_configs.get(vault_id)
return vault_config
raise SkyflowError(f"Vault config with id {vault_id} not found")

def get_vault_configs(self):
return self.__vault_configs

def add_connection_config(self, config):
if validate_connection_config(config) and config["connection_id"] not in self.__connection_configs.keys():
connection_id = config.get("connection_id")
Expand Down Expand Up @@ -147,23 +141,23 @@ def update_connection_config(self, config):

def get_connection_config(self, connection_id):
if connection_id in self.__connection_configs.keys():
vault_config = self.__connection_configs[connection_id]
return vault_config.get("vault_client").get_config()
connection_config = self.__connection_configs[connection_id]
return connection_config
raise SkyflowError(f"Connection config with id {connection_id} not found")

def add_skyflow_credentials(self, credentials):
for vault_config in self.__vault_configs:
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_common_skyflow_credentials(credentials)

for connection_config in self.__connection_configs:
for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_common_skyflow_credentials(credentials)
return self

def set_log_level(self, log_level):
for vault_config in self.__vault_configs:
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_log_level(log_level)

for connection_config in self.__connection_configs:
for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_log_level(log_level)
return self

Expand Down
14 changes: 10 additions & 4 deletions skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import datetime
from math import expm1

import time
import jwt
from skyflow.error import SkyflowError
from skyflow.generated.rest.models import V1GetAuthTokenRequest
Expand Down Expand Up @@ -140,7 +139,14 @@ def generate_signed_data_tokens(credentials_file_path, options):
except Exception:
raise SkyflowError("Invalid file path")

return get_signed_tokens(credentials, options)
return get_signed_tokens(credentials_file_path, options)

def generate_signed_data_tokens_from_creds(credentials, options):
return get_signed_tokens(credentials, options)
return get_signed_tokens(credentials, options)

def get_signed_data_token_response_object(signed_token, actual_token):
response_object = {
"token": actual_token,
"signed_token": signed_token
}
return response_object
2 changes: 1 addition & 1 deletion skyflow/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ..utils.enums import LogLevel, Env
from ._utils import get_credentials, get_vault_url, get_client_configuration, get_base_url, format_scope, get_redaction_type, construct_invoke_connection_request
from ._utils import get_credentials, get_vault_url, get_client_configuration, get_base_url, format_scope, get_redaction_type, construct_invoke_connection_request, build_field_records
98 changes: 94 additions & 4 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@

import os
import json
from urllib.parse import urlparse

from skyflow.generated.rest import RedactionEnumREDACTION
from skyflow.utils.enums import Env
import urllib.parse
import requests
from skyflow.error import SkyflowError
from skyflow.generated.rest import RedactionEnumREDACTION, V1FieldRecords
from skyflow.utils.enums import Env, ContentType
import skyflow.generated.rest as vault_client

def get_credentials(config_level_creds = None, common_skyflow_creds = None):
Expand Down Expand Up @@ -51,3 +53,91 @@ def get_redaction_type(redaction_type):
if redaction_type == "plain-text":
return RedactionEnumREDACTION.PLAIN_TEXT

def parse_path_params(url, path_params):
result = url
for param, value in path_params.items():
result = result.replace('{' + param + '}', value)

return result

def construct_invoke_connection_request(request, connection_url):
url = parse_path_params(connection_url.rstrip('/'), connection_url.pathParams)
header = dict()
header['content-type'] = ContentType.JSON

try:
if isinstance(request.body, dict):
json_data, files = get_data_from_content_type(
request.body, header["content-type"]
)
else:
raise SyntaxError("Given response body is not valid")
except Exception as e:
raise SyntaxError("Given request body is not valid")

try:
return requests.Request(
method = request.method,
url = url,
data = json_data,
headers = header,
params = request.params,
files = files
).prepare()
except requests.exceptions.InvalidURL:
raise SkyflowError("Invalid URL")


def http_build_query(data):
'''
Creates a form urlencoded string from python dictionary
urllib.urlencode() doesn't encode it in a php-esque way, this function helps in that
'''

return urllib.parse.urlencode(r_urlencode(list(), dict(), data))

def r_urlencode(parents, pairs, data):
'''
convert the python dict recursively into a php style associative dictionary
'''
if isinstance(data, list) or isinstance(data, tuple):
for i in range(len(data)):
parents.append(i)
r_urlencode(parents, pairs, data[i])
parents.pop()
elif isinstance(data, dict):
for key, value in data.items():
parents.append(key)
r_urlencode(parents, pairs, value)
parents.pop()
else:
pairs[render_key(parents)] = str(data)

return pairs

def render_key(parents):
'''
renders the nested dictionary key as an associative array (php style dict)
'''
depth, out_str = 0, ''
for x in parents:
s = "[%s]" if depth > 0 or isinstance(x, int) else "%s"
out_str += s % str(x)
depth += 1
return out_str

def get_data_from_content_type(data, content_type):
'''
Get request data according to content type
'''
converted_data = data
files = {}
if content_type == ContentType.URLENCODED:
converted_data = http_build_query(data)
elif content_type == ContentType.FORMDATA:
converted_data = r_urlencode(list(), dict(), data)
files = {(None, None)}
elif content_type == ContentType.JSON:
converted_data = json.dumps(data)

return converted_data, files
4 changes: 2 additions & 2 deletions skyflow/vault/connection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from _invoke_connection_request import InvokeConnectionRequest
from _invoke_connection_response import InvokeConnectionResponse
from ._invoke_connection_request import InvokeConnectionRequest
from ._invoke_connection_response import InvokeConnectionResponse
4 changes: 0 additions & 4 deletions skyflow/vault/controller/_connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import requests

from samples.v2sample import response
from skyflow.utils import construct_invoke_connection_request
from skyflow.vault.connection import InvokeConnectionRequest, InvokeConnectionResponse

Expand All @@ -11,8 +9,6 @@ def __init__(self, vault_client):
self.__vault_client = vault_client

def invoke(self, request: InvokeConnectionRequest):
#generate token

session = requests.Session()
config = self.__vault_client.get_config()
bearer_token = self.__vault_client.get_bearer_token(config.get("credentials"))
Expand Down
Loading

0 comments on commit e51cd92

Please sign in to comment.