# Review notes. These lines sit ahead of the first "diff --git" header and belong to no hunk.
#
# NOTE(review): the patch's line breaks appear to have been collapsed into spaces, so the
# hunks below run together on a few very long lines. The patch text itself is left untouched.
#
# What the patch does, file by file:
#
# skyflow/utils/__init__.py
#   - Appends encode_column_values to the names re-exported from ._utils.
#
# skyflow/utils/_utils.py
#   - Adds "from urllib.parse import quote".
#   - Adds encode_column_values(get_request): builds a new list by calling quote(column) for
#     each item of get_request.column_values, in order, and returns it. The function itself
#     does not assign anything back onto get_request.
#   - NOTE(review): quote() is called with its defaults; per the urllib.parse documentation
#     the default safe="/" leaves "/" unescaped - confirm that is wanted for values such as
#     "foo/bar".
#   - NOTE(review): presumably validate_get_request guarantees every column value is a str
#     before this runs - verify, since quote() is applied to each item unconditionally.
#
# skyflow/vault/controller/_vault.py
#   - Imports encode_column_values from skyflow.utils.
#   - __build_bulk_field_records: in the tokens-present branch the zip() comprehension is
#     replaced by an enumerate() loop. A record receives tokens[i] while i < len(tokens) and
#     None afterwards, so one record is produced per entry of values; zip() stopped at the
#     shorter of the two sequences.
#   - The "tokens is not None" test inside that loop sits in the else-branch of
#     "if tokens is None", so it is always true there.
#   - "bulk_record.tokens = token" assigns the same value already passed as tokens=token to
#     the constructor. NOTE(review): looks redundant unless V1FieldRecords treats constructor
#     arguments differently from attribute assignment - confirm.
#   - get(): after validate_get_request, a truthy request.column_values is replaced on the
#     request object with the encoded list, i.e. the caller's GetRequest is modified.
#     NOTE(review): calling get() again with the same request object would run the already
#     encoded values through quote() a second time ("%" becomes "%25") - confirm requests are
#     never reused, or keep the encoded list in a local instead.
#
# tests/utils/test__utils.py
#   - Adds imports of MagicMock and quote, and test_encode_column_values.
#   - MagicMock is imported on its own line although the hunk context shows an existing
#     "from unittest.mock import patch, Mock" line in the same file.
#   - The expected list in test_encode_column_values is computed with quote() itself, so the
#     test pins per-item application and ordering, not concrete encoded strings.
#
# tests/vault/controller/test__vault.py
#   - Removes the blank line after the unittest imports.
#   - Adds test_insert_with_continue_on_error_false_when_tokens_are_not_none.
#   - Passes column_values=None to the GetRequest built in test_get_successful.
#   - Adds test_get_successful_with_column_values. Its expected_payload dict is built but is
#     not referenced by any assertion in the test; the API call is checked only with
#     assert_called_once(), so the encoded column value is not verified by this test.
#
diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index 786c4d9..97642b3 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -2,5 +2,5 @@ from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION from ._helpers import get_base_url, format_scope -from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key +from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index c88a49b..1db30a8 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -8,6 +8,7 @@ import platform import sys import re +from urllib.parse import quote from skyflow.error import SkyflowError from skyflow.generated.rest import V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \ V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse @@ -377,3 +378,11 @@ def handle_text_error(err, data, request_id, logger): def handle_generic_error(err, request_id, logger): description = "An error occurred." 
log_and_reject_error(description, err.status, request_id, logger = logger) + + +def encode_column_values(get_request): + encoded_column_values = list() + for column in get_request.column_values: + encoded_column_values.append(quote(column)) + + return encoded_column_values diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index d656dae..770e535 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -5,7 +5,7 @@ from skyflow.generated.rest.exceptions import BadRequestException, UnauthorizedException from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ - parse_tokenize_response, parse_query_response, parse_get_response + parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \ validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request @@ -23,7 +23,17 @@ def __build_bulk_field_records(self, values, tokens=None): if tokens is None: return [V1FieldRecords(fields=record) for record in values] else: - return [V1FieldRecords(fields=record, tokens=token) for record, token in zip(values, tokens)] + bulk_record_list = [] + for i, value in enumerate(values): + token = tokens[i] if tokens is not None and i < len(tokens) else None + bulk_record = V1FieldRecords( + fields=value, + tokens=token + ) + if token is not None: + bulk_record.tokens = token + bulk_record_list.append(bulk_record) + return bulk_record_list def __build_batch_field_records(self, values, tokens, table_name, return_tokens, upsert): batch_record_list = [] @@ -151,6 +161,8 @@ def delete(self, request: DeleteRequest): def get(self, request: GetRequest): 
log_info(SkyflowMessages.Info.VALIDATE_GET_REQUEST.value, self.__vault_client.get_logger()) validate_get_request(self.__vault_client.get_logger(), request) + if request.column_values: + request.column_values = encode_column_values(request) log_info(SkyflowMessages.Info.GET_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() records_api = self.__vault_client.get_records_api() diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 368d636..d5d4c72 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -2,13 +2,15 @@ from unittest.mock import patch, Mock import os import json +from unittest.mock import MagicMock +from urllib.parse import quote from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ - handle_exception, validate_api_key + handle_exception, validate_api_key, encode_column_values from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse @@ -384,3 +386,17 @@ def test_validate_api_key_invalid_length(self): def test_validate_api_key_invalid_pattern(self): invalid_key = "sky-ABCDE-1234567890GHIJKL7890abcdef" self.assertFalse(validate_api_key(invalid_key)) + + def test_encode_column_values(self): + get_request = MagicMock() + get_request.column_values = ["Hello World!", "foo/bar", "key=value", "email@example.com"] + + expected_encoded_values = [ + quote("Hello World!"), + quote("foo/bar"), + quote("key=value"), + quote("email@example.com"), + ] + + result = 
encode_column_values(get_request) + self.assertEqual(result, expected_encoded_values) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 933e713..0d2ea3d 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -1,6 +1,5 @@ import unittest from unittest.mock import Mock, patch - from skyflow.generated.rest import RecordServiceBatchOperationBody, V1BatchRecord, RecordServiceInsertRecordBody, \ V1FieldRecords, RecordServiceUpdateRecordBody, RecordServiceBulkDeleteRecordBody, QueryServiceExecuteQueryBody, \ V1DetokenizeRecordRequest, V1DetokenizePayload, V1TokenizePayload, V1TokenizeRecordRequest, RedactionEnumREDACTION @@ -138,6 +137,58 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val self.assertEqual(result.inserted_fields, expected_inserted_fields) self.assertEqual(result.errors, []) # No errors expected + @patch("skyflow.vault.controller._vault.validate_insert_request") + @patch("skyflow.vault.controller._vault.parse_insert_response") + def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock_parse_response, mock_validate): + """Test insert functionality when continue_on_error is False, ensuring a single bulk insert.""" + + # Mock request with continue_on_error set to False + request = InsertRequest( + table_name=TABLE_NAME, + values=[{"field": "value"}], + tokens=[{"token_field": "token_val1"}], + return_tokens=True, + upsert=None, + homogeneous=True, + continue_on_error=False + ) + + # Expected API request body based on InsertRequest parameters + expected_body = RecordServiceInsertRecordBody( + records=[ + V1FieldRecords(fields={"field": "value"}, tokens={"token_field": "token_val1"}) + ], + tokenization=True, + upsert=None, + homogeneous=True + ) + + # Mock API response for a successful insert + mock_api_response = Mock() + mock_api_response.records = [{"skyflow_id": "id1", "tokens": {"token_field": "token_val1"}}] 
+ + # Expected parsed response + expected_inserted_fields = [{'skyflow_id': 'id1', 'token_field': 'token_val1'}] + expected_response = InsertResponse(inserted_fields=expected_inserted_fields) + + # Set the return value for the parse response + mock_parse_response.return_value = expected_response + records_api = self.vault_client.get_records_api.return_value + records_api.record_service_insert_record.return_value = mock_api_response + + # Call the insert function + result = self.vault.insert(request) + + # Assertions + mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + records_api.record_service_insert_record.assert_called_once_with(VAULT_ID, TABLE_NAME, + expected_body) + mock_parse_response.assert_called_once_with(mock_api_response, False) + + # Assert that the result matches the expected InsertResponse + self.assertEqual(result.inserted_fields, expected_inserted_fields) + self.assertEqual(result.errors, []) # No errors expected + @patch("skyflow.vault.controller._vault.validate_update_request") @patch("skyflow.vault.controller._vault.parse_update_record_response") def test_update_successful(self, mock_parse_response, mock_validate): @@ -250,7 +301,8 @@ def test_get_successful(self, mock_parse_response, mock_validate): fields=["field1", "field2"], offset="0", limit="10", - download_url=True + download_url=True, + column_values=None ) # Expected payload @@ -301,6 +353,58 @@ def test_get_successful(self, mock_parse_response, mock_validate): self.assertEqual(result.data, expected_data) self.assertEqual(result.errors, []) # No errors expected + @patch("skyflow.vault.controller._vault.validate_get_request") + @patch("skyflow.vault.controller._vault.parse_get_response") + def test_get_successful_with_column_values(self, mock_parse_response, mock_validate): + """Test get functionality for a successful get request.""" + + # Mock request + request = GetRequest( + table=TABLE_NAME, + redaction_type=RedactionType.PLAIN_TEXT, 
+ column_values=['customer+15@gmail.com'], + column_name='email' + ) + + # Expected payload + expected_payload = { + "object_name": request.table, + "tokenization": request.return_tokens, + "column_name": request.column_name, + "column_values": request.column_values + } + + # Mock API response + mock_api_response = Mock() + mock_api_response.records = [ + Mock(fields={"field1": "value1", "field2": "value2"}), + Mock(fields={"field1": "value3", "field2": "value4"}) + ] + + # Expected parsed response + expected_data = [ + {"field1": "value1", "field2": "value2"}, + {"field1": "value3", "field2": "value4"} + ] + expected_response = GetResponse(data=expected_data, errors=[]) + + # Set the return value for parse_get_response + mock_parse_response.return_value = expected_response + records_api = self.vault_client.get_records_api.return_value + records_api.record_service_bulk_get_record.return_value = mock_api_response + + # Call the get function + result = self.vault.get(request) + + # Assertions + mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + records_api.record_service_bulk_get_record.assert_called_once() + mock_parse_response.assert_called_once_with(mock_api_response) + + # Check that the result matches the expected GetResponse + self.assertEqual(result.data, expected_data) + self.assertEqual(result.errors, []) # No errors expected + @patch("skyflow.vault.controller._vault.validate_query_request") @patch("skyflow.vault.controller._vault.parse_query_response") def test_query_successful(self, mock_parse_response, mock_validate):