diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index 786c4d9..97642b3 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -2,5 +2,5 @@ from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION from ._helpers import get_base_url, format_scope -from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key +from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index c88a49b..1db30a8 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -8,6 +8,7 @@ import platform import sys import re +from urllib.parse import quote from skyflow.error import SkyflowError from skyflow.generated.rest import V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \ V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse @@ -377,3 +378,11 @@ def handle_text_error(err, data, request_id, logger): def handle_generic_error(err, request_id, logger): description = "An error occurred." log_and_reject_error(description, err.status, request_id, logger = logger) + + +def encode_column_values(get_request): + encoded_column_values = list() + for column in get_request.column_values: + encoded_column_values.append(quote(column)) + + return encoded_column_values diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index d656dae..d4fd8ac 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -5,7 +5,7 @@ from skyflow.generated.rest.exceptions import BadRequestException, UnauthorizedException from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ - parse_tokenize_response, parse_query_response, parse_get_response + parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \ validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request @@ -23,7 +23,17 @@ def __build_bulk_field_records(self, values, tokens=None): if tokens is None: return [V1FieldRecords(fields=record) for record in values] else: - return [V1FieldRecords(fields=record, tokens=token) for record, token in zip(values, tokens)] + bulk_record_list = [] + for i, value in enumerate(values): + token = tokens[i] if tokens is not None and i < len(tokens) else None + bulk_record = V1FieldRecords( + fields=value, + tokens=token + ) + if token is not None: + bulk_record.tokens = token + bulk_record_list.append(bulk_record) + return bulk_record_list def __build_batch_field_records(self, values, tokens, table_name, return_tokens, upsert): batch_record_list = [] @@ -151,6 +161,7 @@ def delete(self, request: DeleteRequest): def get(self, request: GetRequest): log_info(SkyflowMessages.Info.VALIDATE_GET_REQUEST.value, self.__vault_client.get_logger()) validate_get_request(self.__vault_client.get_logger(), request) + request.column_values = encode_column_values(request) log_info(SkyflowMessages.Info.GET_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() records_api = self.__vault_client.get_records_api() diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 368d636..d5d4c72 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -2,13 +2,15 @@ from unittest.mock import patch, Mock import os import json +from unittest.mock import MagicMock +from urllib.parse import quote from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ - handle_exception, validate_api_key + handle_exception, validate_api_key, encode_column_values from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse @@ -384,3 +386,17 @@ def test_validate_api_key_invalid_length(self): def test_validate_api_key_invalid_pattern(self): invalid_key = "sky-ABCDE-1234567890GHIJKL7890abcdef" self.assertFalse(validate_api_key(invalid_key)) + + def test_encode_column_values(self): + get_request = MagicMock() + get_request.column_values = ["Hello World!", "foo/bar", "key=value", "email@example.com"] + + expected_encoded_values = [ + quote("Hello World!"), + quote("foo/bar"), + quote("key=value"), + quote("email@example.com"), + ] + + result = encode_column_values(get_request) + self.assertEqual(result, expected_encoded_values)