Skip to content

Commit

Permalink
[BUG] Disable Numeric and String comparison (Eventual-Inc#2019)
Browse files (browse the repository at this point in the history)
* Disables comparisons between Utf8 and numeric types
  • Loading branch information
samster25 authored Mar 15, 2024
1 parent 578944f commit 76540fc
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 18 deletions.
4 changes: 4 additions & 0 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ impl DataType {
use DataType::*;
match (self, other) {
(s, o) if s == o => Ok((Boolean, None, s.to_physical())),
(Utf8, o) | (o, Utf8) if o.is_numeric() => Err(DaftError::TypeError(format!(
"Cannot perform comparison on Utf8 and numeric type.\ntypes: {}, {}",
self, other
))),
(s, o) if s.is_physical() && o.is_physical() => {
Ok((Boolean, None, try_physical_supertype(s, o)?))
}
Expand Down
5 changes: 1 addition & 4 deletions src/daft-core/src/series/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ impl Series {
}

let (output_type, intermediate, comp_type) =
match self.data_type().membership_op(items.data_type()) {
Ok(types) => types,
Err(_) => return default(self.name(), self.len()),
};
self.data_type().membership_op(items.data_type())?;

let (lhs, rhs) = if let Some(ref it) = intermediate {
(self.cast(it)?, items.cast(it)?)
Expand Down
8 changes: 7 additions & 1 deletion src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,13 @@ impl Expr {
}
IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
IsIn(expr, ..) => Ok(Field::new(expr.name()?, DataType::Boolean)),
IsIn(left, right) => {
let left_field = left.to_field(schema)?;
let right_field = right.to_field(schema)?;
let (result_type, _intermediate, _comp_type) =
left_field.dtype.membership_op(&right_field.dtype)?;
Ok(Field::new(left_field.name.as_str(), result_type))
}
Literal(value) => Ok(Field::new("literal", value.get_type())),
Function { func, inputs } => func.to_field(inputs.as_slice(), schema, self),
BinaryOp { op, left, right } => {
Expand Down
8 changes: 7 additions & 1 deletion tests/expressions/typing/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
assert_typing_resolve_vs_runtime_behavior,
has_supertype,
is_comparable,
is_numeric,
)


def comparable_type_validation(lhs: DataType, rhs: DataType) -> bool:
return is_comparable(lhs) and is_comparable(rhs) and has_supertype(lhs, rhs)
return (
is_comparable(lhs)
and is_comparable(rhs)
and has_supertype(lhs, rhs)
and not ((is_numeric(lhs) and rhs == DataType.string()) or (is_numeric(rhs) and lhs == DataType.string()))
)


@pytest.mark.parametrize("op", [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge])
Expand Down
26 changes: 21 additions & 5 deletions tests/series/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
arrow_binary_types = [pa.binary(), pa.large_binary()]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
VALID_INT_STRING_COMPARISONS = list(itertools.product(arrow_int_types, repeat=2)) + list(
itertools.product(arrow_string_types, repeat=2)
)


@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([1, 3, 1, 5, None, None])
Expand All @@ -43,7 +48,7 @@ def test_comparisons_int_and_str(l_dtype, r_dtype) -> None:
assert gt == [False, False, True, None, None, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([2])
r_arrow = pa.array([1, 2, 3, None])
Expand Down Expand Up @@ -71,7 +76,7 @@ def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None:
assert gt == [True, False, False, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([2])
Expand All @@ -98,7 +103,7 @@ def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None:
assert gt == [False, False, True, None, True, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([None], type=r_dtype)
Expand Down Expand Up @@ -578,7 +583,7 @@ def test_comparisons_binary_right_scalar(l_dtype, r_dtype) -> None:
assert gt == [False, False, True, None, True, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([None], type=r_dtype)
Expand Down Expand Up @@ -744,3 +749,14 @@ def test_compare_timestamps_diff_tz(tu1, tu2):
tz1 = Series.from_pylist([utc]).cast(DataType.timestamp(tu1, "UTC"))
tz2 = Series.from_pylist([eastern]).cast(DataType.timestamp(tu1, "US/Eastern"))
assert (tz1 == tz2).to_pylist() == [True]


@pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.lt, operator.gt, operator.le, operator.ge])
def test_numeric_and_string_compare_raises_error(op):
    """Check that comparing a numeric Series with a string Series raises.

    Each comparison operator is applied with the numeric Series on the left
    and then on the right; both orders are expected to raise ``ValueError``
    with the comparison type-error message.
    """
    numeric = Series.from_pylist([1, 2, 3])
    text = Series.from_pylist(["1", "2", "3"])
    expected_message = "Cannot perform comparison on types:"

    # Numeric-vs-string first, then string-vs-numeric, matching the original order.
    for lhs, rhs in ((numeric, text), (text, numeric)):
        with pytest.raises(ValueError, match=expected_message):
            op(lhs, rhs)
18 changes: 11 additions & 7 deletions tests/table/test_is_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ def test_table_expr_is_in_same_types(input, items, expected) -> None:
"input,items,expected",
[
# Int
pytest.param([-1, 2, 3, 4], ["-1", "2"], [True, True, False, False], id="IntWithString"),
pytest.param([-1, 2, 3, 4], ["-1", "2"], None, id="IntWithString"),
pytest.param([1, 2, 3, 4], [1.0, 2.0], [True, True, False, False], id="IntWithFloat"),
pytest.param([0, 1, 2, 3], [True], [False, True, False, False], id="IntWithBool"),
# Float
pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], [True, True, False, False], id="FloatWithString"),
pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], None, id="FloatWithString"),
pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2], [True, True, False, False], id="FloatWithInt"),
pytest.param([0.0, 1.0, 2.0, 3.0], [True], [False, True, False, False], id="FloatWithBool"),
# String
pytest.param(["1", "2", "3", "4"], [1, 2], [True, True, False, False], id="StringWithInt"),
pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], [True, True, False, False], id="StringWithFloat"),
pytest.param(["1", "2", "3", "4"], [1, 2], None, id="StringWithInt"),
pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], None, id="StringWithFloat"),
# Bool
pytest.param([True, False, None], [1, 0], [True, True, None], id="BoolWithInt"),
pytest.param([True, False, None], [1.0], [True, False, None], id="BoolWithFloat"),
Expand All @@ -104,10 +104,14 @@ def test_table_expr_is_in_same_types(input, items, expected) -> None:
)
def test_table_expr_is_in_different_types_castable(input, items, expected) -> None:
daft_table = MicroPartition.from_pydict({"input": input})
daft_table = daft_table.eval_expression_list([col("input").is_in(items)])
pydict = daft_table.to_pydict()

assert pydict["input"] == expected
if expected is None:
with pytest.raises(ValueError, match="Cannot perform comparison on types:"):
daft_table = daft_table.eval_expression_list([col("input").is_in(items)])
else:
daft_table = daft_table.eval_expression_list([col("input").is_in(items)])
pydict = daft_table.to_pydict()
assert pydict["input"] == expected


@pytest.mark.parametrize(
Expand Down

0 comments on commit 76540fc

Please sign in to comment.