Skip to content

Commit

Permalink
SK-1731: Fixed test cases for inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Nov 20, 2024
1 parent 6cc94d7 commit 43067e6
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 43 deletions.
12 changes: 6 additions & 6 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep
url = parse_path_params(connection_url.rstrip('/'), request.path_params)

try:
if isinstance(request.request_headers, dict):
if isinstance(request.headers, dict):
header = to_lowercase_keys(json.loads(
json.dumps(request.request_headers)))
json.dumps(request.headers)))
else:
raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code)
except Exception:
Expand Down Expand Up @@ -224,7 +224,7 @@ def parse_insert_response(api_response, continue_on_error):
errors.append(error)

insert_response.inserted_fields = inserted_fields
insert_response.error_data = errors
insert_response.errors = errors

else:
for record in api_response.records:
Expand Down Expand Up @@ -255,20 +255,20 @@ def parse_delete_response(api_response: V1BulkDeleteRecordResponse):
delete_response = DeleteResponse()
deleted_ids = api_response.record_id_response
delete_response.deleted_ids = deleted_ids
delete_response.error = []
delete_response.errors = []
return delete_response


def parse_get_response(api_response: V1BulkGetRecordResponse):
get_response = GetResponse()
data = []
error = []
errors = []
for record in api_response.records:
field_data = {field: value for field, value in record.fields.items()}
data.append(field_data)

get_response.data = data
get_response.error = error
get_response.errors = errors

return get_response

Expand Down
2 changes: 1 addition & 1 deletion skyflow/vault/connection/_invoke_connection_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def __init__(self,
self.method = method
self.path_params = path_params if path_params is not None else {}
self.query_params = query_params if query_params is not None else {}
self.request_headers = headers if headers is not None else {}
self.headers = headers if headers is not None else {}
2 changes: 1 addition & 1 deletion skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def tokenize(self, request: TokenizeRequest):

records_list = [
V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"])
for item in request.tokenize_parameters
for item in request.values
]
payload = V1TokenizePayload(tokenization_parameters=records_list)
tokens_api = self.__vault_client.get_tokens_api()
Expand Down
6 changes: 3 additions & 3 deletions skyflow/vault/data/_delete_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class DeleteResponse:
def __init__(self, deleted_ids = None, error = None):
def __init__(self, deleted_ids = None, errors = None):
self.deleted_ids = deleted_ids
self.error = error
self.errors = errors

def __repr__(self):
return f"DeleteResponse(deleted_ids={self.deleted_ids}, errors={self.error})"
return f"DeleteResponse(deleted_ids={self.deleted_ids}, errors={self.errors})"

def __str__(self):
return self.__repr__()
Expand Down
6 changes: 3 additions & 3 deletions skyflow/vault/data/_get_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class GetResponse:
def __init__(self, data=None, error = None):
def __init__(self, data=None, errors = None):
self.data = data if data else []
self.error = error
self.errors = errors

def __repr__(self):
return f"GetResponse(data={self.data}, errors={self.error})"
return f"GetResponse(data={self.data}, errors={self.errors})"

def __str__(self):
return self.__repr__()
10 changes: 5 additions & 5 deletions skyflow/vault/data/_insert_response.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
class InsertResponse:
def __init__(self, inserted_fields = None, error_data=None):
if error_data is None:
error_data = list()
def __init__(self, inserted_fields = None, errors=None):
if errors is None:
errors = list()
self.inserted_fields = inserted_fields
self.error_data = error_data
self.errors = errors

def __repr__(self):
return f"InsertResponse(inserted_fields={self.inserted_fields}, errors={self.error_data})"
return f"InsertResponse(inserted_fields={self.inserted_fields}, errors={self.errors})"

def __str__(self):
return self.__repr__()
4 changes: 2 additions & 2 deletions skyflow/vault/data/_query_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class QueryResponse:
def __init__(self):
self.fields = []
self.error = []
self.errors = []

def __repr__(self):
return f"QueryResponse(fields={self.fields}, errors={self.error})"
return f"QueryResponse(fields={self.fields}, errors={self.errors})"

def __str__(self):
return self.__repr__()
6 changes: 3 additions & 3 deletions skyflow/vault/data/_update_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class UpdateResponse:
def __init__(self, updated_field = None, error=None):
def __init__(self, updated_field = None, errors=None):
self.updated_field = updated_field
self.error = error if error is not None else []
self.errors = errors if errors is not None else []

def __repr__(self):
return f"UpdateResponse(updated_field={self.updated_field}, errors={self.error})"
return f"UpdateResponse(updated_field={self.updated_field}, errors={self.errors})"

def __str__(self):
return self.__repr__()
2 changes: 1 addition & 1 deletion skyflow/vault/tokens/_tokenize_request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class TokenizeRequest:
def __init__(self, values):
self.tokenize_parameters = values
self.values = values
18 changes: 9 additions & 9 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_get_metrics(self):
def test_construct_invoke_connection_request_valid(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.body = {"key": "value"}
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -115,7 +115,7 @@ def test_construct_invoke_connection_request_valid(self):
def test_construct_invoke_connection_request_with_invalid_headers(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = []
mock_connection_request.headers = []
mock_connection_request.body = {"key": "value"}
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -130,7 +130,7 @@ def test_construct_invoke_connection_request_with_invalid_headers(self):
def test_construct_invoke_connection_request_with_invalid_request_body(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
mock_connection_request.body = []
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -144,7 +144,7 @@ def test_construct_invoke_connection_request_with_invalid_request_body(self):
def test_construct_invoke_connection_request_with_url_encoded_content_type(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.URLENCODED.value}
mock_connection_request.headers = {"Content-Type": ContentType.URLENCODED.value}
mock_connection_request.body = {"key": "value"}
mock_connection_request.method.value = "POST"
mock_connection_request.query_params = {"query": "test"}
Expand All @@ -158,7 +158,7 @@ def test_construct_invoke_connection_request_with_url_encoded_content_type(self)
def test_construct_invoke_connection_request_with_form_date_content_type(self):
mock_connection_request = Mock()
mock_connection_request.path_params = {"param1": "value1"}
mock_connection_request.request_headers = {"Content-Type": ContentType.FORMDATA.value}
mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value}
mock_connection_request.body = {
"name": (None, "John Doe")
}
Expand All @@ -179,7 +179,7 @@ def test_parse_insert_response(self):
]
result = parse_insert_response(api_response, continue_on_error=True)
self.assertEqual(len(result.inserted_fields), 1)
self.assertEqual(len(result.error_data), 1)
self.assertEqual(len(result.errors), 1)

def test_parse_insert_response_continue_on_error_false(self):
mock_api_response = Mock()
Expand All @@ -198,7 +198,7 @@ def test_parse_insert_response_continue_on_error_false(self):
]
self.assertEqual(result.inserted_fields, expected_inserted_fields)

self.assertEqual(result.error_data, [])
self.assertEqual(result.errors, [])

def test_parse_update_record_response(self):
api_response = Mock()
Expand All @@ -219,7 +219,7 @@ def test_parse_delete_response_successful(self):
expected_deleted_ids = ["id_1", "id_2", "id_3"]
self.assertEqual(result.deleted_ids, expected_deleted_ids)

self.assertEqual(result.error, [])
self.assertEqual(result.errors, [])

def test_parse_get_response_successful(self):
mock_api_response = Mock()
Expand All @@ -238,7 +238,7 @@ def test_parse_get_response_successful(self):
]
self.assertEqual(result.data, expected_data)

self.assertEqual(result.error, [])
self.assertEqual(result.errors, [])

def test_parse_detokenize_response_with_mixed_records(self):
mock_api_response = Mock()
Expand Down
18 changes: 9 additions & 9 deletions tests/vault/controller/test__vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)
expected_errors = [
{'request_index': 1, 'error': 'Insert error for record 2'}
]
expected_response = InsertResponse(inserted_fields=expected_inserted_fields, error_data=expected_errors)
expected_response = InsertResponse(inserted_fields=expected_inserted_fields, errors=expected_errors)

# Set the return value for the parse response
mock_parse_response.return_value = expected_response
Expand All @@ -83,7 +83,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)

# Assert that the result matches the expected InsertResponse
self.assertEqual(result.inserted_fields, expected_inserted_fields)
self.assertEqual(result.error_data, expected_errors)
self.assertEqual(result.errors, expected_errors)

@patch("skyflow.vault.controller._vault.validate_insert_request")
@patch("skyflow.vault.controller._vault.parse_insert_response")
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val

# Assert that the result matches the expected InsertResponse
self.assertEqual(result.inserted_fields, expected_inserted_fields)
self.assertEqual(result.error_data, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_update_request")
@patch("skyflow.vault.controller._vault.parse_update_record_response")
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_update_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected UpdateResponse
self.assertEqual(result.updated_field, expected_updated_field)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_delete_request")
@patch("skyflow.vault.controller._vault.parse_delete_response")
Expand All @@ -212,7 +212,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate):

# Expected parsed response
expected_deleted_ids = ["12345", "67890"]
expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, error=[])
expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=[])

# Set the return value for the parse response
mock_parse_response.return_value = expected_response
Expand All @@ -233,7 +233,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected DeleteResponse
self.assertEqual(result.deleted_ids, expected_deleted_ids)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_get_request")
@patch("skyflow.vault.controller._vault.parse_get_response")
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_get_successful(self, mock_parse_response, mock_validate):
{"field1": "value1", "field2": "value2"},
{"field1": "value3", "field2": "value4"}
]
expected_response = GetResponse(data=expected_data, error=[])
expected_response = GetResponse(data=expected_data, errors=[])

# Set the return value for parse_get_response
mock_parse_response.return_value = expected_response
Expand All @@ -298,7 +298,7 @@ def test_get_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected GetResponse
self.assertEqual(result.data, expected_data)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_query_request")
@patch("skyflow.vault.controller._vault.parse_query_response")
Expand Down Expand Up @@ -344,7 +344,7 @@ def test_query_successful(self, mock_parse_response, mock_validate):

# Check that the result matches the expected QueryResponse
self.assertEqual(result.fields, expected_fields)
self.assertEqual(result.error, []) # No errors expected
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_detokenize_request")
@patch("skyflow.vault.controller._vault.parse_detokenize_response")
Expand Down

0 comments on commit 43067e6

Please sign in to comment.