Skip to content

Commit

Permalink
SK-1736: Added encoding for column values in get interface (#145)
Browse files: browse the repository at this point in the history
* SK-1736: Added encoding for column values in get interface
  • Loading branch information
saileshwar-skyflow authored Dec 3, 2024
1 parent a27f123 commit 69630e5
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 6 deletions.
2 changes: 1 addition & 1 deletion skyflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from ._skyflow_messages import SkyflowMessages
from ._version import SDK_VERSION
from ._helpers import get_base_url, format_scope
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values

9 changes: 9 additions & 0 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import platform
import sys
import re
from urllib.parse import quote
from skyflow.error import SkyflowError
from skyflow.generated.rest import V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \
V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse
Expand Down Expand Up @@ -377,3 +378,11 @@ def handle_text_error(err, data, request_id, logger):
def handle_generic_error(err, request_id, logger):
    """Report an error using a fixed, generic description.

    Delegates to ``log_and_reject_error`` with the message
    "An error occurred.", the status read from ``err.status``, the given
    ``request_id`` and the given ``logger``.
    """
    log_and_reject_error("An error occurred.", err.status, request_id, logger=logger)


def encode_column_values(get_request):
    """Return the request's column values percent-encoded, in order.

    Every entry of ``get_request.column_values`` is passed through
    ``urllib.parse.quote`` with its default arguments (so ``/`` is left
    unescaped). A new list is returned; the request is not modified.
    """
    return [quote(raw_value) for raw_value in get_request.column_values]
16 changes: 14 additions & 2 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from skyflow.generated.rest.exceptions import BadRequestException, UnauthorizedException
from skyflow.utils import SkyflowMessages, parse_insert_response, \
handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \
parse_tokenize_response, parse_query_response, parse_get_response
parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values
from skyflow.utils.logger import log_info, log_error_log
from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \
validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request
def __build_bulk_field_records(self, values, tokens=None):
    """Build one ``V1FieldRecords`` per entry of ``values`` for a bulk insert.

    Args:
        values: sequence of field mappings, one per record.
        tokens: optional sequence of token mappings paired with ``values``
            by position. When it is shorter than ``values``, the records
            without a matching entry are built with ``tokens=None``.

    Returns:
        A list of ``V1FieldRecords`` in the same order as ``values``.
    """
    if tokens is None:
        return [V1FieldRecords(fields=record) for record in values]

    # Pair by index rather than zip() so that records beyond the end of
    # ``tokens`` are still emitted (zip would silently drop them).
    token_count = len(tokens)
    return [
        V1FieldRecords(
            fields=value,
            tokens=tokens[index] if index < token_count else None,
        )
        for index, value in enumerate(values)
    ]

def __build_batch_field_records(self, values, tokens, table_name, return_tokens, upsert):
batch_record_list = []
Expand Down Expand Up @@ -151,6 +161,8 @@ def delete(self, request: DeleteRequest):
def get(self, request: GetRequest):
log_info(SkyflowMessages.Info.VALIDATE_GET_REQUEST.value, self.__vault_client.get_logger())
validate_get_request(self.__vault_client.get_logger(), request)
if request.column_values:
request.column_values = encode_column_values(request)
log_info(SkyflowMessages.Info.GET_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
self.__initialize()
records_api = self.__vault_client.get_records_api()
Expand Down
18 changes: 17 additions & 1 deletion tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from unittest.mock import patch, Mock
import os
import json
from unittest.mock import MagicMock
from urllib.parse import quote
from requests import PreparedRequest
from requests.models import HTTPError
from skyflow.error import SkyflowError
from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \
parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \
parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \
handle_exception, validate_api_key
handle_exception, validate_api_key, encode_column_values
from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics
from skyflow.utils.enums import EnvUrls, Env, ContentType
from skyflow.vault.connection import InvokeConnectionResponse
Expand Down Expand Up @@ -384,3 +386,17 @@ def test_validate_api_key_invalid_length(self):
def test_validate_api_key_invalid_pattern(self):
    """validate_api_key must return a falsy result for this malformed key."""
    is_valid = validate_api_key("sky-ABCDE-1234567890GHIJKL7890abcdef")
    self.assertFalse(is_valid)

def test_encode_column_values(self):
    """encode_column_values should apply urllib's quote() to every column value."""
    raw_values = ["Hello World!", "foo/bar", "key=value", "[email protected]"]

    # Only the ``column_values`` attribute is read by the function under test.
    request_stub = MagicMock()
    request_stub.column_values = raw_values

    encoded = encode_column_values(request_stub)

    self.assertEqual(encoded, [quote(value) for value in raw_values])
108 changes: 106 additions & 2 deletions tests/vault/controller/test__vault.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import unittest
from unittest.mock import Mock, patch

from skyflow.generated.rest import RecordServiceBatchOperationBody, V1BatchRecord, RecordServiceInsertRecordBody, \
V1FieldRecords, RecordServiceUpdateRecordBody, RecordServiceBulkDeleteRecordBody, QueryServiceExecuteQueryBody, \
V1DetokenizeRecordRequest, V1DetokenizePayload, V1TokenizePayload, V1TokenizeRecordRequest, RedactionEnumREDACTION
Expand Down Expand Up @@ -138,6 +137,58 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val
self.assertEqual(result.inserted_fields, expected_inserted_fields)
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_insert_request")
@patch("skyflow.vault.controller._vault.parse_insert_response")
def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock_parse_response, mock_validate):
    """Bulk insert (continue_on_error=False) with caller-supplied tokens issues a single insert call."""

    # --- Arrange: stub the records API and the response parser ---
    api_response = Mock(records=[{"skyflow_id": "id1", "tokens": {"token_field": "token_val1"}}])
    records_api = self.vault_client.get_records_api.return_value
    records_api.record_service_insert_record.return_value = api_response

    inserted_fields = [{'skyflow_id': 'id1', 'token_field': 'token_val1'}]
    mock_parse_response.return_value = InsertResponse(inserted_fields=inserted_fields)

    # Request carrying one record together with its token mapping.
    insert_request = InsertRequest(
        table_name=TABLE_NAME,
        values=[{"field": "value"}],
        tokens=[{"token_field": "token_val1"}],
        return_tokens=True,
        upsert=None,
        homogeneous=True,
        continue_on_error=False,
    )

    # Body the controller is expected to hand to the records API.
    body_with_tokens = RecordServiceInsertRecordBody(
        records=[
            V1FieldRecords(fields={"field": "value"}, tokens={"token_field": "token_val1"}),
        ],
        tokenization=True,
        upsert=None,
        homogeneous=True,
    )

    # --- Act ---
    outcome = self.vault.insert(insert_request)

    # --- Assert: validation, API call and parsing each happened exactly once ---
    mock_validate.assert_called_once_with(self.vault_client.get_logger(), insert_request)
    records_api.record_service_insert_record.assert_called_once_with(
        VAULT_ID, TABLE_NAME, body_with_tokens
    )
    mock_parse_response.assert_called_once_with(api_response, False)

    # The controller returns what the parser produced, with no errors.
    self.assertEqual(outcome.inserted_fields, inserted_fields)
    self.assertEqual(outcome.errors, [])

@patch("skyflow.vault.controller._vault.validate_update_request")
@patch("skyflow.vault.controller._vault.parse_update_record_response")
def test_update_successful(self, mock_parse_response, mock_validate):
Expand Down Expand Up @@ -250,7 +301,8 @@ def test_get_successful(self, mock_parse_response, mock_validate):
fields=["field1", "field2"],
offset="0",
limit="10",
download_url=True
download_url=True,
column_values=None
)

# Expected payload
Expand Down Expand Up @@ -301,6 +353,58 @@ def test_get_successful(self, mock_parse_response, mock_validate):
self.assertEqual(result.data, expected_data)
self.assertEqual(result.errors, []) # No errors expected

@patch("skyflow.vault.controller._vault.validate_get_request")
@patch("skyflow.vault.controller._vault.parse_get_response")
def test_get_successful_with_column_values(self, mock_parse_response, mock_validate):
    """Test a successful get request that selects records by column name and column values."""

    get_request = GetRequest(
        table=TABLE_NAME,
        redaction_type=RedactionType.PLAIN_TEXT,
        column_values=['[email protected]'],
        column_name='email',
    )

    # Built from the request before the call; NOTE: nothing below asserts
    # the API call's arguments against this dictionary.
    expected_payload = {
        "object_name": get_request.table,
        "tokenization": get_request.return_tokens,
        "column_name": get_request.column_name,
        "column_values": get_request.column_values,
    }

    # Stub the records API to return two records.
    api_response = Mock(
        records=[
            Mock(fields={"field1": "value1", "field2": "value2"}),
            Mock(fields={"field1": "value3", "field2": "value4"}),
        ]
    )
    records_api = self.vault_client.get_records_api.return_value
    records_api.record_service_bulk_get_record.return_value = api_response

    # Stub the parser to return the flattened records.
    parsed_rows = [
        {"field1": "value1", "field2": "value2"},
        {"field1": "value3", "field2": "value4"},
    ]
    mock_parse_response.return_value = GetResponse(data=parsed_rows, errors=[])

    outcome = self.vault.get(get_request)

    # Validation, the API call and parsing must each happen exactly once.
    mock_validate.assert_called_once_with(self.vault_client.get_logger(), get_request)
    records_api.record_service_bulk_get_record.assert_called_once()
    mock_parse_response.assert_called_once_with(api_response)

    # The controller returns what the parser produced, with no errors.
    self.assertEqual(outcome.data, parsed_rows)
    self.assertEqual(outcome.errors, [])

@patch("skyflow.vault.controller._vault.validate_query_request")
@patch("skyflow.vault.controller._vault.parse_query_response")
def test_query_successful(self, mock_parse_response, mock_validate):
Expand Down

0 comments on commit 69630e5

Please sign in to comment.