Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SK-1731: Fix inconsistencies #137

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion samples/vault_api/invoke_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
invoke_connection_request = InvokeConnectionRequest(
method=Method.POST,
body=body,
request_headers=headers, # optional
headers=headers, # optional
path_params=path_params, # optional
query_params=query_params, # optional
)
Expand Down
2 changes: 1 addition & 1 deletion samples/vault_api/tokenize_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
# sample data
tokenize_values = [{"<VALUE_FIELD>": "<VALUE>", "<COLUMN_GROUP_FIELD>": "<VALUE>"}]

tokenize_request = TokenizeRequest(tokenize_parameters=tokenize_values)
tokenize_request = TokenizeRequest(values=tokenize_values)

response = skyflow_client.vault("VAULT_ID").tokenize(tokenize_request)

Expand Down
12 changes: 6 additions & 6 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep
url = parse_path_params(connection_url.rstrip('/'), request.path_params)

try:
if isinstance(request.request_headers, dict):
if isinstance(request.headers, dict):
header = to_lowercase_keys(json.loads(
json.dumps(request.request_headers)))
json.dumps(request.headers)))
else:
raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code)
except Exception:
Expand Down Expand Up @@ -224,7 +224,7 @@ def parse_insert_response(api_response, continue_on_error):
errors.append(error)

insert_response.inserted_fields = inserted_fields
insert_response.error_data = errors
insert_response.errors = errors

else:
for record in api_response.records:
Expand Down Expand Up @@ -255,20 +255,20 @@ def parse_delete_response(api_response: V1BulkDeleteRecordResponse):
delete_response = DeleteResponse()
deleted_ids = api_response.record_id_response
delete_response.deleted_ids = deleted_ids
delete_response.error = []
delete_response.errors = []
return delete_response


def parse_get_response(api_response: V1BulkGetRecordResponse):
get_response = GetResponse()
data = []
error = []
errors = []
for record in api_response.records:
field_data = {field: value for field, value in record.fields.items()}
data.append(field_data)

get_response.data = data
get_response.error = error
get_response.errors = errors

return get_response

Expand Down
4 changes: 2 additions & 2 deletions skyflow/vault/connection/_invoke_connection_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ def __init__(self,
body = None,
path_params = None,
query_params = None,
request_headers = None):
headers = None):
self.body = body if body is not None else {}
self.method = method
self.path_params = path_params if path_params is not None else {}
self.query_params = query_params if query_params is not None else {}
self.request_headers = request_headers if request_headers is not None else {}
self.headers = headers if headers is not None else {}
2 changes: 1 addition & 1 deletion skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def tokenize(self, request: TokenizeRequest):

records_list = [
V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"])
for item in request.tokenize_parameters
for item in request.values
]
payload = V1TokenizePayload(tokenization_parameters=records_list)
tokens_api = self.__vault_client.get_tokens_api()
Expand Down
6 changes: 3 additions & 3 deletions skyflow/vault/data/_delete_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class DeleteResponse:
def __init__(self, deleted_ids = None, error = None):
def __init__(self, deleted_ids = None, errors = None):
self.deleted_ids = deleted_ids
self.error = error
self.errors = errors

def __repr__(self):
return f"DeleteResponse(deleted_ids={self.deleted_ids}, error={self.error})"
return f"DeleteResponse(deleted_ids={self.deleted_ids}, errors={self.errors})"

def __str__(self):
return self.__repr__()
Expand Down
6 changes: 3 additions & 3 deletions skyflow/vault/data/_get_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class GetResponse:
def __init__(self, data=None, error = None):
def __init__(self, data=None, errors = None):
self.data = data if data else []
self.error = error
self.errors = errors

def __repr__(self):
return f"GetResponse(data={self.data}, error={self.error})"
return f"GetResponse(data={self.data}, errors={self.errors})"

def __str__(self):
return self.__repr__()
10 changes: 5 additions & 5 deletions skyflow/vault/data/_insert_response.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
class InsertResponse:
def __init__(self, inserted_fields = None, error_data=None):
if error_data is None:
error_data = list()
def __init__(self, inserted_fields = None, errors=None):
if errors is None:
errors = list()
self.inserted_fields = inserted_fields
self.error_data = error_data
self.errors = errors

def __repr__(self):
return f"InsertResponse(inserted_fields={self.inserted_fields}, error={self.error_data})"
return f"InsertResponse(inserted_fields={self.inserted_fields}, errors={self.errors})"

def __str__(self):
return self.__repr__()
4 changes: 2 additions & 2 deletions skyflow/vault/data/_query_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class QueryResponse:
def __init__(self):
self.fields = []
self.error = []
self.errors = []

def __repr__(self):
return f"QueryResponse(fields={self.fields}, error={self.error})"
return f"QueryResponse(fields={self.fields}, errors={self.errors})"

def __str__(self):
return self.__repr__()
6 changes: 3 additions & 3 deletions skyflow/vault/data/_update_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class UpdateResponse:
def __init__(self, updated_field = None, error=None):
def __init__(self, updated_field = None, errors=None):
self.updated_field = updated_field
self.error = error if error is not None else []
self.errors = errors if errors is not None else []

def __repr__(self):
return f"UpdateResponse(updated_field={self.updated_field}, error={self.error})"
return f"UpdateResponse(updated_field={self.updated_field}, errors={self.errors})"

def __str__(self):
return self.__repr__()
4 changes: 2 additions & 2 deletions skyflow/vault/tokens/_tokenize_request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class TokenizeRequest:
def __init__(self, tokenize_parameters):
self.tokenize_parameters = tokenize_parameters
def __init__(self, values):
self.values = values
18 changes: 9 additions & 9 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_get_metrics(self):
def test_construct_invoke_connection_request_valid(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.body = {"key": "value"}
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -115,7 +115,7 @@ def test_construct_invoke_connection_request_valid(self):
def test_construct_invoke_connection_request_with_invalid_headers(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = []
mock_connection_request.headers = []
mock_connection_request.body = {"key": "value"}
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -130,7 +130,7 @@ def test_construct_invoke_connection_request_with_invalid_headers(self):
def test_construct_invoke_connection_request_with_invalid_request_body(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.body = []
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -144,7 +144,7 @@ def test_construct_invoke_connection_request_with_invalid_request_body(self):
def test_construct_invoke_connection_request_with_url_encoded_content_type(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.URLENCODED.value}
mock_connection_request.headers = {"Content-Type": ContentType.URLENCODED.value}
mock_connection_request.body = {"key": "value"}
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -158,7 +158,7 @@ def test_construct_invoke_connection_request_with_url_encoded_content_type(self)
def test_construct_invoke_connection_request_with_form_date_content_type(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.FORMDATA.value}
mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value}
mock_connection_request.body = {
"name": (None, "John Doe")
}
Expand All @@ -179,7 +179,7 @@ def test_parse_insert_response(self):
]
result = parse_insert_response(api_response, continue_on_error=True)
self.assertEqual(len(result.inserted_fields), 1)
self.assertEqual(len(result.error_data), 1)
self.assertEqual(len(result.errors), 1)

def test_parse_insert_response_continue_on_error_false(self):
mock_api_response = Mock()
Expand All @@ -198,7 +198,7 @@ def test_parse_insert_response_continue_on_error_false(self):
]
self.assertEqual(result.inserted_fields, expected_inserted_fields)

self.assertEqual(result.error_data, [])
self.assertEqual(result.errors, [])

def test_parse_update_record_response(self):
api_response = Mock()
Expand All @@ -219,7 +219,7 @@ def test_parse_delete_response_successful(self):
expected_deleted_ids = ["id_1", "id_2", "id_3"]
self.assertEqual(result.deleted_ids, expected_deleted_ids)

self.assertEqual(result.error, [])
self.assertEqual(result.errors, [])

def test_parse_get_response_successful(self):
mock_api_response = Mock()
Expand All @@ -238,7 +238,7 @@ def test_parse_get_response_successful(self):
]
self.assertEqual(result.data, expected_data)

self.assertEqual(result.error, [])
self.assertEqual(result.errors, [])

def test_parse_detokenize_response_with_mixed_records(self):
mock_api_response = Mock()
Expand Down
10 changes: 5 additions & 5 deletions tests/vault/controller/test__connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import Mock, patch, call
from unittest.mock import Mock, patch

from skyflow.error import SkyflowError
from skyflow.utils import SkyflowMessages
Expand Down Expand Up @@ -31,7 +31,7 @@ def test_invoke_success(self, mock_send):
method=Method.POST,
body=VALID_BODY,
path_params=VALID_PATH_PARAMS,
request_headers=VALID_HEADERS,
headers=VALID_HEADERS,
query_params=VALID_QUERY_PARAMS
)

Expand All @@ -48,7 +48,7 @@ def test_invoke_invalid_headers(self, mock_send):
method="POST",
body=VALID_BODY,
path_params=VALID_PATH_PARAMS,
request_headers=INVALID_HEADERS,
headers=INVALID_HEADERS,
query_params=VALID_QUERY_PARAMS
)

Expand All @@ -62,7 +62,7 @@ def test_invoke_invalid_body(self, mock_send):
method="POST",
body=INVALID_BODY,
path_params=VALID_PATH_PARAMS,
request_headers=VALID_HEADERS,
headers=VALID_HEADERS,
query_params=VALID_QUERY_PARAMS
)

Expand All @@ -81,7 +81,7 @@ def test_invoke_request_error(self, mock_send):
method=Method.POST,
body=VALID_BODY,
path_params=VALID_PATH_PARAMS,
request_headers=VALID_HEADERS,
headers=VALID_HEADERS,
query_params=VALID_QUERY_PARAMS
)

Expand Down
20 changes: 10 additions & 10 deletions tests/vault/controller/test__vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)
expected_errors = [
{'request_index': 1, 'error': 'Insert error for record 2'}
]
expected_response = InsertResponse(inserted_fields=expected_inserted_fields, error_data=expected_errors)
expected_response = InsertResponse(inserted_fields=expected_inserted_fields, errors=expected_errors)

# Set the return value for the parse response
mock_parse_response.return_value = expected_response
Expand All @@ -83,7 +83,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)

# Assert that the result matches the expected InsertResponse
self.assertEqual(result.inserted_fields, expected_inserted_fields)
self.assertEqual(result.error_data, expected_errors)
self.assertEqual(result.errors, expected_errors)

@patch("skyflow.vault.controller._vault.validate_insert_request")
@patch("skyflow.vault.controller._vault.parse_insert_response")
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val

# Assert that the result matches the expected InsertResponse
self.assertEqual(result.inserted_fields, expected_inserted_fields)
self.assertEqual(result.error_data, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_update_request")
@patch("skyflow.vault.controller._vault.parse_update_record_response")
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_update_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected UpdateResponse
self.assertEqual(result.updated_field, expected_updated_field)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_delete_request")
@patch("skyflow.vault.controller._vault.parse_delete_response")
Expand All @@ -212,7 +212,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate):

# Expected parsed response
expected_deleted_ids = ["12345", "67890"]
expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, error=[])
expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=[])

# Set the return value for the parse response
mock_parse_response.return_value = expected_response
Expand All @@ -233,7 +233,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected DeleteResponse
self.assertEqual(result.deleted_ids, expected_deleted_ids)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_get_request")
@patch("skyflow.vault.controller._vault.parse_get_response")
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_get_successful(self, mock_parse_response, mock_validate):
{"field1": "value1", "field2": "value2"},
{"field1": "value3", "field2": "value4"}
]
expected_response = GetResponse(data=expected_data, error=[])
expected_response = GetResponse(data=expected_data, errors=[])

# Set the return value for parse_get_response
mock_parse_response.return_value = expected_response
Expand All @@ -298,7 +298,7 @@ def test_get_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected GetResponse
self.assertEqual(result.data, expected_data)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_query_request")
@patch("skyflow.vault.controller._vault.parse_query_response")
Expand Down Expand Up @@ -344,7 +344,7 @@ def test_query_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected QueryResponse
self.assertEqual(result.fields, expected_fields)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_detokenize_request")
@patch("skyflow.vault.controller._vault.parse_detokenize_response")
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_tokenize_successful(self, mock_parse_response, mock_validate):

# Mock request with tokenization parameters
request = TokenizeRequest(
tokenize_parameters=[
values=[
{"value": "value1", "column_group": "group1"},
{"value": "value2", "column_group": "group2"}
]
Expand Down
Loading