diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 8322d6bd6fb..9835e6f8461 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1188,6 +1188,10 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) + if not plc.unary.is_supported_cast(self.dtype, value.dtype): + raise NotImplementedError( + f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" + ) def do_evaluate( self, diff --git a/python/cudf_polars/tests/expressions/test_casting.py b/python/cudf_polars/tests/expressions/test_casting.py new file mode 100644 index 00000000000..3e003054338 --- /dev/null +++ b/python/cudf_polars/tests/expressions/test_casting.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) + +_supported_dtypes = [(pl.Int8(), pl.Int64())] + +_unsupported_dtypes = [ + (pl.String(), pl.Int64()), +] + + +@pytest.fixture +def dtypes(request): + return request.param + + +@pytest.fixture +def tests(dtypes): + fromtype, totype = dtypes + if fromtype == pl.String(): + data = ["a", "b", "c"] + else: + data = [1, 2, 3] + return pl.DataFrame( + { + "a": pl.Series(data, dtype=fromtype), + } + ).lazy(), totype + + +@pytest.mark.parametrize("dtypes", _supported_dtypes, indirect=True) +def test_cast_supported(tests): + df, totype = tests + q = df.select(pl.col("a").cast(totype)) + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("dtypes", _unsupported_dtypes, indirect=True) +def test_cast_unsupported(tests): + df, totype = tests + assert_ir_translation_raises( + df.select(pl.col("a").cast(totype)), NotImplementedError + )