diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index f20f9fe662..6d08a6312c 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -37,6 +37,10 @@ impl DataType { use DataType::*; match (self, other) { (s, o) if s == o => Ok((Boolean, None, s.to_physical())), + (Utf8, o) | (o, Utf8) if o.is_numeric() => Err(DaftError::TypeError(format!( + "Cannot perform comparison on Utf8 and numeric type.\ntypes: {}, {}", + self, other + ))), (s, o) if s.is_physical() && o.is_physical() => { Ok((Boolean, None, try_physical_supertype(s, o)?)) } diff --git a/src/daft-core/src/series/ops/is_in.rs b/src/daft-core/src/series/ops/is_in.rs index 06e2f64937..080e7411ad 100644 --- a/src/daft-core/src/series/ops/is_in.rs +++ b/src/daft-core/src/series/ops/is_in.rs @@ -19,10 +19,7 @@ impl Series { } let (output_type, intermediate, comp_type) = - match self.data_type().membership_op(items.data_type()) { - Ok(types) => types, - Err(_) => return default(self.name(), self.len()), - }; + self.data_type().membership_op(items.data_type())?; let (lhs, rhs) = if let Some(ref it) = intermediate { (self.cast(it)?, items.cast(it)?) diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 6ad61c9434..3b4e8716df 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -455,7 +455,13 @@ impl Expr { } IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)), NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)), - IsIn(expr, ..) => Ok(Field::new(expr.name()?, DataType::Boolean)), + IsIn(left, right) => { + let left_field = left.to_field(schema)?; + let right_field = right.to_field(schema)?; + let (result_type, _intermediate, _comp_type) = + left_field.dtype.membership_op(&right_field.dtype)?; + Ok(Field::new(left_field.name.as_str(), result_type)) + } Literal(value) => Ok(Field::new("literal", value.get_type())), Function { func, inputs } => func.to_field(inputs.as_slice(), schema, self), BinaryOp { op, left, right } => { diff --git a/tests/expressions/typing/test_compare.py b/tests/expressions/typing/test_compare.py index 9ab988e40c..e814b7d236 100644 --- a/tests/expressions/typing/test_compare.py +++ b/tests/expressions/typing/test_compare.py @@ -10,11 +10,17 @@ assert_typing_resolve_vs_runtime_behavior, has_supertype, is_comparable, + is_numeric, ) def comparable_type_validation(lhs: DataType, rhs: DataType) -> bool: - return is_comparable(lhs) and is_comparable(rhs) and has_supertype(lhs, rhs) + return ( + is_comparable(lhs) + and is_comparable(rhs) + and has_supertype(lhs, rhs) + and not ((is_numeric(lhs) and rhs == DataType.string()) or (is_numeric(rhs) and lhs == DataType.string())) + ) @pytest.mark.parametrize("op", [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge]) diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index bcb64d3723..8a499dff8d 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -16,7 +16,12 @@ arrow_binary_types = [pa.binary(), pa.large_binary()] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2)) +VALID_INT_STRING_COMPARISONS = list(itertools.product(arrow_int_types, repeat=2)) + list( + itertools.product(arrow_string_types, repeat=2) +) + + +@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str(l_dtype, r_dtype) -> None: l_arrow = pa.array([1, 2, 3, None, 5, None]) r_arrow = pa.array([1, 3, 1, 5, None, None]) @@ -43,7 +48,7 @@ def test_comparisons_int_and_str(l_dtype, r_dtype) -> None: assert gt == [False, False, True, None, None, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None: l_arrow = pa.array([2]) r_arrow = pa.array([1, 2, 3, None]) @@ -71,7 +76,7 @@ def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None: assert gt == [True, False, False, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None: l_arrow = pa.array([1, 2, 3, None, 5, None]) r_arrow = pa.array([2]) @@ -98,7 +103,7 @@ def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None: assert gt == [False, False, True, None, True, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None: l_arrow = pa.array([1, 2, 3, None, 5, None]) r_arrow = pa.array([None], type=r_dtype) @@ -578,7 +583,7 @@ def test_comparisons_binary_right_scalar(l_dtype, r_dtype) -> None: assert gt == [False, False, True, None, True, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None: l_arrow = pa.array([1, 2, 3, None, 5, None]) r_arrow = pa.array([None], type=r_dtype) @@ -744,3 +749,14 @@ def test_compare_timestamps_diff_tz(tu1, tu2): tz1 = Series.from_pylist([utc]).cast(DataType.timestamp(tu1, "UTC")) tz2 = Series.from_pylist([eastern]).cast(DataType.timestamp(tu1, "US/Eastern")) assert (tz1 == tz2).to_pylist() == [True] + + +@pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.lt, operator.gt, operator.le, operator.ge]) +def test_numeric_and_string_compare_raises_error(op): + left = Series.from_pylist([1, 2, 3]) + right = Series.from_pylist(["1", "2", "3"]) + with pytest.raises(ValueError, match="Cannot perform comparison on types:"): + op(left, right) + + with pytest.raises(ValueError, match="Cannot perform comparison on types:"): + op(right, left) diff --git a/tests/table/test_is_in.py b/tests/table/test_is_in.py index a351937505..a3ed42b850 100644 --- a/tests/table/test_is_in.py +++ b/tests/table/test_is_in.py @@ -73,16 +73,16 @@ def test_table_expr_is_in_same_types(input, items, expected) -> None: "input,items,expected", [ # Int - pytest.param([-1, 2, 3, 4], ["-1", "2"], [True, True, False, False], id="IntWithString"), + pytest.param([-1, 2, 3, 4], ["-1", "2"], None, id="IntWithString"), pytest.param([1, 2, 3, 4], [1.0, 2.0], [True, True, False, False], id="IntWithFloat"), pytest.param([0, 1, 2, 3], [True], [False, True, False, False], id="IntWithBool"), # Float - pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], [True, True, False, False], id="FloatWithString"), + pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], None, id="FloatWithString"), pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2], [True, True, False, False], id="FloatWithInt"), pytest.param([0.0, 1.0, 2.0, 3.0], [True], [False, True, False, False], id="FloatWithBool"), # String - pytest.param(["1", "2", "3", "4"], [1, 2], [True, True, False, False], id="StringWithInt"), - pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], [True, True, False, False], id="StringWithFloat"), + pytest.param(["1", "2", "3", "4"], [1, 2], None, id="StringWithInt"), + pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], None, id="StringWithFloat"), # Bool pytest.param([True, False, None], [1, 0], [True, True, None], id="BoolWithInt"), pytest.param([True, False, None], [1.0], [True, False, None], id="BoolWithFloat"), @@ -104,10 +104,14 @@ def test_table_expr_is_in_same_types(input, items, expected) -> None: ) def test_table_expr_is_in_different_types_castable(input, items, expected) -> None: daft_table = MicroPartition.from_pydict({"input": input}) - daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) - pydict = daft_table.to_pydict() - assert pydict["input"] == expected + if expected is None: + with pytest.raises(ValueError, match="Cannot perform comparison on types:"): + daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) + else: + daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) + pydict = daft_table.to_pydict() + assert pydict["input"] == expected @pytest.mark.parametrize(