Skip to content

Commit

Permalink
SK-1736: Added encoding for column values in get interface
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshwar-skyflow committed Dec 3, 2024
1 parent b4ac30f commit 9827825
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 4 deletions.
2 changes: 1 addition & 1 deletion skyflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from ._skyflow_messages import SkyflowMessages
from ._version import SDK_VERSION
from ._helpers import get_base_url, format_scope
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values

9 changes: 9 additions & 0 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import platform
import sys
import re
from urllib.parse import quote
from skyflow.error import SkyflowError
from skyflow.generated.rest import V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \
V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse
Expand Down Expand Up @@ -377,3 +378,11 @@ def handle_text_error(err, data, request_id, logger):
def handle_generic_error(err, request_id, logger):
    """Report ``err`` using a fixed, generic description.

    Passes the constant message, ``err.status`` and ``request_id`` to
    ``log_and_reject_error``, forwarding ``logger`` as a keyword argument.
    """
    log_and_reject_error("An error occurred.", err.status, request_id, logger=logger)


def encode_column_values(get_request):
    """Percent-encode every entry of ``get_request.column_values``.

    Each value is passed through ``urllib.parse.quote`` with its default
    settings, so ``/`` is left untouched while characters such as spaces,
    ``=`` and ``@`` are escaped.

    Args:
        get_request: Object exposing a ``column_values`` attribute that is
            either an iterable of strings or ``None``.

    Returns:
        A new list holding the encoded values in their original order, or
        ``None`` when ``column_values`` is ``None``. The request object
        itself is not modified.
    """
    column_values = get_request.column_values
    if column_values is None:
        # Nothing to encode. Returning None (instead of raising TypeError on
        # iteration) lets a caller that assigns the result back keep the
        # field unset.
        return None
    return [quote(value) for value in column_values]
15 changes: 13 additions & 2 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from skyflow.generated.rest.exceptions import BadRequestException, UnauthorizedException
from skyflow.utils import SkyflowMessages, parse_insert_response, \
handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \
parse_tokenize_response, parse_query_response, parse_get_response
parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values
from skyflow.utils.logger import log_info, log_error_log
from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \
validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request
Expand All @@ -23,7 +23,17 @@ def __build_bulk_field_records(self, values, tokens=None):
if tokens is None:
return [V1FieldRecords(fields=record) for record in values]
else:
return [V1FieldRecords(fields=record, tokens=token) for record, token in zip(values, tokens)]
bulk_record_list = []
for i, value in enumerate(values):
token = tokens[i] if tokens is not None and i < len(tokens) else None
bulk_record = V1FieldRecords(
fields=value,
tokens=token
)
if token is not None:
bulk_record.tokens = token
bulk_record_list.append(bulk_record)
return bulk_record_list

def __build_batch_field_records(self, values, tokens, table_name, return_tokens, upsert):
batch_record_list = []
Expand Down Expand Up @@ -151,6 +161,7 @@ def delete(self, request: DeleteRequest):
def get(self, request: GetRequest):
log_info(SkyflowMessages.Info.VALIDATE_GET_REQUEST.value, self.__vault_client.get_logger())
validate_get_request(self.__vault_client.get_logger(), request)
request.column_values = encode_column_values(request)
log_info(SkyflowMessages.Info.GET_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
self.__initialize()
records_api = self.__vault_client.get_records_api()
Expand Down
18 changes: 17 additions & 1 deletion tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from unittest.mock import patch, Mock
import os
import json
from unittest.mock import MagicMock
from urllib.parse import quote
from requests import PreparedRequest
from requests.models import HTTPError
from skyflow.error import SkyflowError
from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \
parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \
parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \
handle_exception, validate_api_key
handle_exception, validate_api_key, encode_column_values
from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics
from skyflow.utils.enums import EnvUrls, Env, ContentType
from skyflow.vault.connection import InvokeConnectionResponse
Expand Down Expand Up @@ -384,3 +386,17 @@ def test_validate_api_key_invalid_length(self):
def test_validate_api_key_invalid_pattern(self):
    # validate_api_key is expected to give a falsy result for this key.
    self.assertFalse(validate_api_key("sky-ABCDE-1234567890GHIJKL7890abcdef"))

def test_encode_column_values(self):
    # Expected strings are spelled out literally instead of being recomputed
    # with quote(), so the test pins the actual encoding rather than
    # mirroring the implementation.
    get_request = MagicMock()
    get_request.column_values = ["Hello World!", "foo/bar", "key=value", "user@example.com"]

    expected_encoded_values = [
        "Hello%20World%21",
        "foo/bar",  # quote() treats "/" as safe by default
        "key%3Dvalue",
        "user%40example.com",
    ]

    result = encode_column_values(get_request)
    self.assertEqual(result, expected_encoded_values)

0 comments on commit 9827825

Please sign in to comment.