diff --git a/samples/vault_api/detokenize_records.py b/samples/vault_api/detokenize_records.py index 2c931c2..192c7e7 100644 --- a/samples/vault_api/detokenize_records.py +++ b/samples/vault_api/detokenize_records.py @@ -1,6 +1,7 @@ import json from skyflow import Env from skyflow import Skyflow, LogLevel +from skyflow.utils.enums import RedactionType from skyflow.vault.tokens import DetokenizeRequest # To generate Bearer Token from credentials string. @@ -43,6 +44,7 @@ detokenize_request = DetokenizeRequest( tokens=detokenize_data, + redaction_type = RedactionType.PLAIN_TEXT ) response = client.vault('VAULT_ID').detokenize(detokenize_request) diff --git a/samples/vault_api/invoke_connection.py b/samples/vault_api/invoke_connection.py index a2243fa..f1d7f50 100644 --- a/samples/vault_api/invoke_connection.py +++ b/samples/vault_api/invoke_connection.py @@ -55,7 +55,7 @@ invoke_connection_request = InvokeConnectionRequest( method=Method.POST, body=body, - request_headers=headers, # optional + headers=headers, # optional path_params=path_params, # optional query_params=query_params, # optional ) diff --git a/samples/vault_api/tokenize_records.py b/samples/vault_api/tokenize_records.py index 0ca8e97..3cf3f65 100644 --- a/samples/vault_api/tokenize_records.py +++ b/samples/vault_api/tokenize_records.py @@ -49,7 +49,7 @@ # sample data tokenize_values = [{"": "", "": ""}] -tokenize_request = TokenizeRequest(tokenize_parameters=tokenize_values) +tokenize_request = TokenizeRequest(values=tokenize_values) response = skyflow_client.vault("VAULT_ID").tokenize(tokenize_request) diff --git a/setup.py b/setup.py index 46ae383..08f8a6a 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ if sys.version_info < (3, 7): raise RuntimeError("skyflow requires Python 3.7+") -current_version = '1.15.1.dev0+ddb228b' +current_version = '1.15.1.dev0+0c98734' setup( name='skyflow', diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index bb841ca..fda915a 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -78,9 +78,9 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep url = parse_path_params(connection_url.rstrip('/'), request.path_params) try: - if isinstance(request.request_headers, dict): + if isinstance(request.headers, dict): header = to_lowercase_keys(json.loads( - json.dumps(request.request_headers))) + json.dumps(request.headers))) else: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) except Exception: @@ -224,7 +224,7 @@ def parse_insert_response(api_response, continue_on_error): errors.append(error) insert_response.inserted_fields = inserted_fields - insert_response.error_data = errors + insert_response.errors = errors else: for record in api_response.records: @@ -255,20 +255,20 @@ def parse_delete_response(api_response: V1BulkDeleteRecordResponse): delete_response = DeleteResponse() deleted_ids = api_response.record_id_response delete_response.deleted_ids = deleted_ids - delete_response.error = [] + delete_response.errors = [] return delete_response def parse_get_response(api_response: V1BulkGetRecordResponse): get_response = GetResponse() data = [] - error = [] + errors = [] for record in api_response.records: field_data = {field: value for field, value in record.fields.items()} data.append(field_data) get_response.data = data - get_response.error = error + get_response.errors = errors return get_response diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 9b42e1f..840572e 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '1.15.1.dev0+ddb228b' \ No newline at end of file +SDK_VERSION = '1.15.1.dev0+0c98734' \ No newline at end of file diff --git a/skyflow/utils/enums/__init__.py b/skyflow/utils/enums/__init__.py index 6a010ba..e30a799 100644 --- a/skyflow/utils/enums/__init__.py +++ b/skyflow/utils/enums/__init__.py @@ -3,4 +3,4 @@ from .content_types import ContentType from .token_strict import TokenStrict from .method import Method -from .redaction_type import Redaction \ No newline at end of file +from .redaction_type import RedactionType \ No newline at end of file diff --git a/skyflow/utils/enums/redaction_type.py b/skyflow/utils/enums/redaction_type.py index 13ac806..8531004 100644 --- a/skyflow/utils/enums/redaction_type.py +++ b/skyflow/utils/enums/redaction_type.py @@ -1,14 +1,8 @@ +from enum import Enum from skyflow.generated.rest import RedactionEnumREDACTION - -class Redaction: - @staticmethod - def to_redaction_enum(value): - if value == "plain-text": - return RedactionEnumREDACTION.PLAIN_TEXT - elif value == "masked": - return RedactionEnumREDACTION.MASKED - elif value == "default": - return RedactionEnumREDACTION.DEFAULT - elif value == "redacted": - return RedactionEnumREDACTION.REDACTED +class RedactionType(Enum): + PLAIN_TEXT = RedactionEnumREDACTION.PLAIN_TEXT + MASKED = RedactionEnumREDACTION.MASKED + DEFAULT = RedactionEnumREDACTION.DEFAULT + REDACTED = RedactionEnumREDACTION.REDACTED diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 738b549..c1dbe3f 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -1,9 +1,7 @@ import json import re - -from skyflow.generated.rest import RedactionEnumREDACTION from skyflow.service_account import is_expired -from skyflow.utils.enums import LogLevel, TokenStrict, Redaction, Env +from skyflow.utils.enums import LogLevel, TokenStrict, Env, RedactionType from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow.utils.logger import log_info, log_error_log @@ -391,7 +389,7 @@ def validate_get_request(logger, request): if not isinstance(request.return_tokens, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code) - if redaction_type is not None and not isinstance(redaction_type, RedactionEnumREDACTION): + if redaction_type is not None and not isinstance(redaction_type, RedactionType): raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code) if fields is not None and (not isinstance(fields, list) or not fields): @@ -505,7 +503,7 @@ def validate_update_request(logger, request): raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) def validate_detokenize_request(logger, request): - if not isinstance(request.redaction_type, RedactionEnumREDACTION): + if not isinstance(request.redaction_type, RedactionType): raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(request.redaction_type)), invalid_input_error_code) if not isinstance(request.continue_on_error, bool): diff --git a/skyflow/vault/connection/_invoke_connection_request.py b/skyflow/vault/connection/_invoke_connection_request.py index 25d9ec0..9634dfb 100644 --- a/skyflow/vault/connection/_invoke_connection_request.py +++ b/skyflow/vault/connection/_invoke_connection_request.py @@ -4,9 +4,9 @@ def __init__(self, body = None, path_params = None, query_params = None, - request_headers = None): + headers = None): self.body = body if body is not None else {} self.method = method self.path_params = path_params if path_params is not None else {} self.query_params = query_params if query_params is not None else {} - self.request_headers = request_headers if request_headers is not None else {} \ No newline at end of file + self.headers = headers if headers is not None else {} \ No newline at end of file diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 2361f06..ebdc924 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -83,7 +83,6 @@ def insert(self, request: InsertRequest): if request.continue_on_error: api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(), insert_body) - print("respomse: ", api_response) else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), @@ -163,7 +162,7 @@ def get(self, request: GetRequest): self.__vault_client.get_vault_id(), object_name=request.table, skyflow_ids=request.ids, - redaction=request.redaction_type, + redaction = request.redaction_type.value if request.redaction_type is not None else None, tokenization=request.return_tokens, fields=request.fields, offset=request.offset, @@ -211,7 +210,7 @@ def detokenize(self, request: DetokenizeRequest): log_info(SkyflowMessages.Info.DETOKENIZE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() tokens_list = [ - V1DetokenizeRecordRequest(token=token, redaction=request.redaction_type) + V1DetokenizeRecordRequest(token=token, redaction=request.redaction_type.value) for token in request.tokens ] payload = V1DetokenizePayload(detokenization_parameters=tokens_list, continue_on_error=request.continue_on_error) @@ -241,7 +240,7 @@ def tokenize(self, request: TokenizeRequest): records_list = [ V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"]) - for item in request.tokenize_parameters + for item in request.values ] payload = V1TokenizePayload(tokenization_parameters=records_list) tokens_api = self.__vault_client.get_tokens_api() diff --git a/skyflow/vault/data/_delete_response.py b/skyflow/vault/data/_delete_response.py index ca86624..0147c77 100644 --- a/skyflow/vault/data/_delete_response.py +++ b/skyflow/vault/data/_delete_response.py @@ -1,10 +1,10 @@ class DeleteResponse: - def __init__(self, deleted_ids = None, error = None): + def __init__(self, deleted_ids = None, errors = None): self.deleted_ids = deleted_ids - self.error = error + self.errors = errors def __repr__(self): - return f"DeleteResponse(deleted_ids={self.deleted_ids}, error={self.error})" + return f"DeleteResponse(deleted_ids={self.deleted_ids}, errors={self.errors})" def __str__(self): return self.__repr__() diff --git a/skyflow/vault/data/_get_request.py b/skyflow/vault/data/_get_request.py index f7238de..81cb21a 100644 --- a/skyflow/vault/data/_get_request.py +++ b/skyflow/vault/data/_get_request.py @@ -1,6 +1,3 @@ -from skyflow.utils.enums import Redaction - - class GetRequest: def __init__(self, table, @@ -15,7 +12,7 @@ def __init__(self, column_values = None): self.table = table self.ids = ids - self.redaction_type = Redaction.to_redaction_enum(redaction_type) + self.redaction_type = redaction_type self.return_tokens = return_tokens self.fields = fields self.offset = offset diff --git a/skyflow/vault/data/_get_response.py b/skyflow/vault/data/_get_response.py index b84c7c2..cf1b080 100644 --- a/skyflow/vault/data/_get_response.py +++ b/skyflow/vault/data/_get_response.py @@ -1,10 +1,10 @@ class GetResponse: - def __init__(self, data=None, error = None): + def __init__(self, data=None, errors = None): self.data = data if data else [] - self.error = error + self.errors = errors def __repr__(self): - return f"GetResponse(data={self.data}, error={self.error})" + return f"GetResponse(data={self.data}, errors={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/data/_insert_response.py b/skyflow/vault/data/_insert_response.py index fbddc0c..6407426 100644 --- a/skyflow/vault/data/_insert_response.py +++ b/skyflow/vault/data/_insert_response.py @@ -1,12 +1,12 @@ class InsertResponse: - def __init__(self, inserted_fields = None, error_data=None): - if error_data is None: - error_data = list() + def __init__(self, inserted_fields = None, errors=None): + if errors is None: + errors = list() self.inserted_fields = inserted_fields - self.error_data = error_data + self.errors = errors def __repr__(self): - return f"InsertResponse(inserted_fields={self.inserted_fields}, error={self.error_data})" + return f"InsertResponse(inserted_fields={self.inserted_fields}, errors={self.errors})" def __str__(self): return self.__repr__() diff --git a/skyflow/vault/data/_query_response.py b/skyflow/vault/data/_query_response.py index 1a98cf1..e203475 100644 --- a/skyflow/vault/data/_query_response.py +++ b/skyflow/vault/data/_query_response.py @@ -1,10 +1,10 @@ class QueryResponse: def __init__(self): self.fields = [] - self.error = [] + self.errors = [] def __repr__(self): - return f"QueryResponse(fields={self.fields}, error={self.error})" + return f"QueryResponse(fields={self.fields}, errors={self.errors})" def __str__(self): return self.__repr__() diff --git a/skyflow/vault/data/_update_response.py b/skyflow/vault/data/_update_response.py index 6cdedf1..dbbb9cc 100644 --- a/skyflow/vault/data/_update_response.py +++ b/skyflow/vault/data/_update_response.py @@ -1,10 +1,10 @@ class UpdateResponse: - def __init__(self, updated_field = None, error=None): + def __init__(self, updated_field = None, errors=None): self.updated_field = updated_field - self.error = error if error is not None else [] + self.errors = errors if errors is not None else [] def __repr__(self): - return f"UpdateResponse(updated_field={self.updated_field}, error={self.error})" + return f"UpdateResponse(updated_field={self.updated_field}, errors={self.errors})" def __str__(self): return self.__repr__() diff --git a/skyflow/vault/tokens/_detokenize_request.py b/skyflow/vault/tokens/_detokenize_request.py index 82b5630..5e3bc04 100644 --- a/skyflow/vault/tokens/_detokenize_request.py +++ b/skyflow/vault/tokens/_detokenize_request.py @@ -1,7 +1,7 @@ -from skyflow.utils.enums import Redaction +from skyflow.utils.enums.redaction_type import RedactionType class DetokenizeRequest: - def __init__(self, tokens, redaction_type = "plain-text", continue_on_error = False): + def __init__(self, tokens, redaction_type = RedactionType.PLAIN_TEXT, continue_on_error = False): self.tokens = tokens - self.redaction_type = Redaction.to_redaction_enum(redaction_type) + self.redaction_type = redaction_type self.continue_on_error = continue_on_error \ No newline at end of file diff --git a/skyflow/vault/tokens/_tokenize_request.py b/skyflow/vault/tokens/_tokenize_request.py index 6e92c6b..a1e7c2b 100644 --- a/skyflow/vault/tokens/_tokenize_request.py +++ b/skyflow/vault/tokens/_tokenize_request.py @@ -1,3 +1,3 @@ class TokenizeRequest: - def __init__(self, tokenize_parameters): - self.tokenize_parameters = tokenize_parameters + def __init__(self, values): + self.values = values diff --git a/skyflow/vault/tokens/_tokenize_response.py b/skyflow/vault/tokens/_tokenize_response.py index dd1e855..264b398 100644 --- a/skyflow/vault/tokens/_tokenize_response.py +++ b/skyflow/vault/tokens/_tokenize_response.py @@ -4,7 +4,7 @@ def __init__(self, tokenized_fields = None): def __repr__(self): - return f"InsertResponse(tokenized_fields={self.tokenized_fields})" + return f"TokenizeResponse(tokenized_fields={self.tokenized_fields})" def __str__(self): return self.__repr__() diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index ac5b1f2..de5643e 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -93,7 +93,7 @@ def test_get_metrics(self): def test_construct_invoke_connection_request_valid(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} - mock_connection_request.request_headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} mock_connection_request.body = {"key": "value"} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -115,7 +115,7 @@ def test_construct_invoke_connection_request_valid(self): def test_construct_invoke_connection_request_with_invalid_headers(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} - mock_connection_request.request_headers = [] + mock_connection_request.headers = [] mock_connection_request.body = {"key": "value"} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -130,7 +130,7 @@ def test_construct_invoke_connection_request_with_invalid_headers(self): def test_construct_invoke_connection_request_with_invalid_request_body(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} - mock_connection_request.request_headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} mock_connection_request.body = [] mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -144,7 +144,7 @@ def test_construct_invoke_connection_request_with_invalid_request_body(self): def test_construct_invoke_connection_request_with_url_encoded_content_type(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} - mock_connection_request.request_headers = {"Content-Type": ContentType.URLENCODED.value} + mock_connection_request.headers = {"Content-Type": ContentType.URLENCODED.value} mock_connection_request.body = {"key": "value"} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -158,7 +158,7 @@ def test_construct_invoke_connection_request_with_url_encoded_content_type(self) def test_construct_invoke_connection_request_with_form_date_content_type(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} - mock_connection_request.request_headers = {"Content-Type": ContentType.FORMDATA.value} + mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} mock_connection_request.body = { "name": (None, "John Doe") } @@ -179,7 +179,7 @@ def test_parse_insert_response(self): ] result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) - self.assertEqual(len(result.error_data), 1) + self.assertEqual(len(result.errors), 1) def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() @@ -198,7 +198,7 @@ def test_parse_insert_response_continue_on_error_false(self): ] self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.error_data, []) + self.assertEqual(result.errors, []) def test_parse_update_record_response(self): api_response = Mock() @@ -219,7 +219,7 @@ def test_parse_delete_response_successful(self): expected_deleted_ids = ["id_1", "id_2", "id_3"] self.assertEqual(result.deleted_ids, expected_deleted_ids) - self.assertEqual(result.error, []) + self.assertEqual(result.errors, []) def test_parse_get_response_successful(self): mock_api_response = Mock() @@ -238,7 +238,7 @@ def test_parse_get_response_successful(self): ] self.assertEqual(result.data, expected_data) - self.assertEqual(result.error, []) + self.assertEqual(result.errors, []) def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 0a3ddc9..5a845cf 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import Mock, patch, call +from unittest.mock import Mock, patch from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages @@ -31,7 +31,7 @@ def test_invoke_success(self, mock_send): method=Method.POST, body=VALID_BODY, path_params=VALID_PATH_PARAMS, - request_headers=VALID_HEADERS, + headers=VALID_HEADERS, query_params=VALID_QUERY_PARAMS ) @@ -48,7 +48,7 @@ def test_invoke_invalid_headers(self, mock_send): method="POST", body=VALID_BODY, path_params=VALID_PATH_PARAMS, - request_headers=INVALID_HEADERS, + headers=INVALID_HEADERS, query_params=VALID_QUERY_PARAMS ) @@ -62,7 +62,7 @@ def test_invoke_invalid_body(self, mock_send): method="POST", body=INVALID_BODY, path_params=VALID_PATH_PARAMS, - request_headers=VALID_HEADERS, + headers=VALID_HEADERS, query_params=VALID_QUERY_PARAMS ) @@ -81,7 +81,7 @@ def test_invoke_request_error(self, mock_send): method=Method.POST, body=VALID_BODY, path_params=VALID_PATH_PARAMS, - request_headers=VALID_HEADERS, + headers=VALID_HEADERS, query_params=VALID_QUERY_PARAMS ) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index c490eb9..cdb49ef 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -4,7 +4,7 @@ from skyflow.generated.rest import RecordServiceBatchOperationBody, V1BatchRecord, RecordServiceInsertRecordBody, \ V1FieldRecords, RecordServiceUpdateRecordBody, RecordServiceBulkDeleteRecordBody, QueryServiceExecuteQueryBody, \ V1DetokenizeRecordRequest, V1DetokenizePayload, V1TokenizePayload, V1TokenizeRecordRequest, RedactionEnumREDACTION -from skyflow.utils.enums import TokenStrict, Redaction +from skyflow.utils.enums import TokenStrict, RedactionType from skyflow.vault.controller import Vault from skyflow.vault.data import InsertRequest, InsertResponse, UpdateResponse, UpdateRequest, DeleteResponse, \ DeleteRequest, GetRequest, GetResponse, QueryRequest, QueryResponse @@ -66,7 +66,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate) expected_errors = [ {'request_index': 1, 'error': 'Insert error for record 2'} ] - expected_response = InsertResponse(inserted_fields=expected_inserted_fields, error_data=expected_errors) + expected_response = InsertResponse(inserted_fields=expected_inserted_fields, errors=expected_errors) # Set the return value for the parse response mock_parse_response.return_value = expected_response @@ -83,7 +83,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate) # Assert that the result matches the expected InsertResponse self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.error_data, expected_errors) + self.assertEqual(result.errors, expected_errors) @patch("skyflow.vault.controller._vault.validate_insert_request") @patch("skyflow.vault.controller._vault.parse_insert_response") @@ -135,7 +135,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val # Assert that the result matches the expected InsertResponse self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.error_data, []) # No errors expected + self.assertEqual(result.errors, []) # No errors expected @patch("skyflow.vault.controller._vault.validate_update_request") @patch("skyflow.vault.controller._vault.parse_update_record_response") @@ -190,7 +190,7 @@ def test_update_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected UpdateResponse self.assertEqual(result.updated_field, expected_updated_field) - self.assertEqual(result.error, []) # No errors expected + self.assertEqual(result.errors, []) # No errors expected @patch("skyflow.vault.controller._vault.validate_delete_request") @patch("skyflow.vault.controller._vault.parse_delete_response") @@ -212,7 +212,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate): # Expected parsed response expected_deleted_ids = ["12345", "67890"] - expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, error=[]) + expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=[]) # Set the return value for the parse response mock_parse_response.return_value = expected_response @@ -233,7 +233,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected DeleteResponse self.assertEqual(result.deleted_ids, expected_deleted_ids) - self.assertEqual(result.error, []) # No errors expected + self.assertEqual(result.errors, []) # No errors expected @patch("skyflow.vault.controller._vault.validate_get_request") @patch("skyflow.vault.controller._vault.parse_get_response") @@ -244,7 +244,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): request = GetRequest( table=TABLE_NAME, ids=["12345", "67890"], - redaction_type="PLAIN_TEXT", + redaction_type=RedactionType.PLAIN_TEXT, return_tokens=True, fields=["field1", "field2"], offset="0", @@ -256,7 +256,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): expected_payload = { "object_name": request.table, "skyflow_ids": request.ids, - "redaction": request.redaction_type, + "redaction": request.redaction_type.value, "tokenization": request.return_tokens, "fields": request.fields, "offset": request.offset, @@ -278,7 +278,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): {"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"} ] - expected_response = GetResponse(data=expected_data, error=[]) + expected_response = GetResponse(data=expected_data, errors=[]) # Set the return value for parse_get_response mock_parse_response.return_value = expected_response @@ -298,7 +298,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected GetResponse self.assertEqual(result.data, expected_data) - self.assertEqual(result.error, []) # No errors expected + self.assertEqual(result.errors, []) # No errors expected @patch("skyflow.vault.controller._vault.validate_query_request") @patch("skyflow.vault.controller._vault.parse_query_response") @@ -344,14 +344,14 @@ def test_query_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected QueryResponse self.assertEqual(result.fields, expected_fields) - self.assertEqual(result.error, []) # No errors expected + self.assertEqual(result.errors, []) # No errors expected @patch("skyflow.vault.controller._vault.validate_detokenize_request") @patch("skyflow.vault.controller._vault.parse_detokenize_response") def test_detokenize_successful(self, mock_parse_response, mock_validate): request = DetokenizeRequest( tokens=["token1", "token2"], - redaction_type="plain-text", + redaction_type=RedactionType.PLAIN_TEXT, continue_on_error=False ) @@ -406,7 +406,7 @@ def test_tokenize_successful(self, mock_parse_response, mock_validate): # Mock request with tokenization parameters request = TokenizeRequest( - tokenize_parameters=[ + values=[ {"value": "value1", "column_group": "group1"}, {"value": "value2", "column_group": "group2"} ]