Skip to content

Commit

Permalink
Improve filter API validation
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-ssh16 authored and gs-ssh16 committed Aug 17, 2023
1 parent 1f88b45 commit 7c0912b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def calculate_columns(self) -> PyLegendSequence["TdsColumn"]:
def validate(self) -> bool:
tds_row = TdsRow.from_tds_frame("frame", self.__base_frame)

copy = self.__filter_function # For MyPy
if not isinstance(copy, type(lambda x: 0)) or (copy.__code__.co_argcount != 1):
raise TypeError("Filter function should be a lambda which takes one argument (TDSRow)")

try:
result = self.__filter_function(tds_row)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def test_extend_function_error_on_non_lambda_func(self) -> None:
assert r.value.args[0] == ("Error at extend function at index 0 (0-indexed). Each extend function "
"should be a lambda which takes one argument (TDSRow)")

def test_extend_function_error_on_incompatible_lambda_func(self) -> None:
columns = [
PrimitiveTdsColumn.integer_column("col1"),
PrimitiveTdsColumn.string_column("col2")
]
frame: LegendApiTdsFrame = LegendApiTableSpecInputFrame(['test_schema', 'test_table'], columns)
with pytest.raises(TypeError) as r:
frame.extend([lambda x, y: 1], ["col4"]) # type: ignore
assert r.value.args[0] == ("Error at extend function at index 0 (0-indexed). Each extend function "
"should be a lambda which takes one argument (TDSRow)")

def test_extend_function_error_on_non_string_name(self) -> None:
columns = [
PrimitiveTdsColumn.integer_column("col1"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ def test_filter_function_error_on_unknown_col(self) -> None:
"Column - 'col3' doesn't exist in the current frame. Current frame columns: "
"['col1', 'col2']")

def test_filter_function_error_non_lambda_arg(self) -> None:
columns = [
PrimitiveTdsColumn.integer_column("col1"),
PrimitiveTdsColumn.string_column("col2")
]
frame: LegendApiTdsFrame = LegendApiTableSpecInputFrame(['test_schema', 'test_table'], columns)

with pytest.raises(TypeError) as r:
frame.filter(1) # type: ignore
assert r.value.args[0] == "Filter function should be a lambda which takes one argument (TDSRow)"

def test_filter_function_error_multi_param_lambda_arg(self) -> None:
columns = [
PrimitiveTdsColumn.integer_column("col1"),
PrimitiveTdsColumn.string_column("col2")
]
frame: LegendApiTdsFrame = LegendApiTableSpecInputFrame(['test_schema', 'test_table'], columns)

with pytest.raises(TypeError) as r:
frame.filter(lambda x, y: 1) # type: ignore
assert r.value.args[0] == "Filter function should be a lambda which takes one argument (TDSRow)"

def test_filter_function_error_on_non_boolean_func(self) -> None:
columns = [
PrimitiveTdsColumn.integer_column("col1"),
Expand Down

0 comments on commit 7c0912b

Please sign in to comment.