Skip to content

Commit

Permalink
add support for decimal types, clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Dec 12, 2024
1 parent 0b219fd commit cef040c
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 17 deletions.
7 changes: 7 additions & 0 deletions cpp/src/unary/math_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ struct fixed_point_abs {
__device__ T operator()(T data) { return numeric::detail::abs(data); }
};

template <typename T>
struct fixed_point_negate {
T n;
__device__ T operator()(T data) { return -data; }
};

template <typename T, template <typename> typename FixedPointFunctor>
std::unique_ptr<column> unary_op_with(column_view const& input,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -578,6 +584,7 @@ struct FixedPointOpDispatcher {
case cudf::unary_operator::CEIL: return unary_op_with<T, fixed_point_ceil>(input, stream, mr);
case cudf::unary_operator::FLOOR: return unary_op_with<T, fixed_point_floor>(input, stream, mr);
case cudf::unary_operator::ABS: return unary_op_with<T, fixed_point_abs>(input, stream, mr);
case cudf::unary_operator::NEGATE: return unary_op_with<T, fixed_point_negate>(input, stream, mr);
default: CUDF_FAIL("Unsupported fixed_point unary operation");
}
// clang-format on
Expand Down
26 changes: 9 additions & 17 deletions cpp/tests/unary/math_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@

#include <vector>

using SignedNumericTypesNotBool =
cudf::test::Types<int8_t, int16_t, int32_t, int64_t, float, double>;
using TypesToNegate =
cudf::test::Types<int8_t, int16_t, int32_t, int64_t, float, double, cudf::duration_D>;

template <typename T>
struct UnaryMathOpsSignedTest : public cudf::test::BaseFixture {};
struct UnaryNegateTests : public cudf::test::BaseFixture {};

TYPED_TEST_SUITE(UnaryMathOpsSignedTest, SignedNumericTypesNotBool);
TYPED_TEST_SUITE(UnaryNegateTests, TypesToNegate);

TYPED_TEST(UnaryMathOpsSignedTest, SimpleNEGATE)
TYPED_TEST(UnaryNegateTests, SimpleNEGATE)
{
cudf::test::fixed_width_column_wrapper<TypeParam> input{{1, 2, 3}};
auto const v = cudf::test::make_type_param_vector<TypeParam>({-1, -2, -3});
cudf::test::fixed_width_column_wrapper<TypeParam> expected(v.begin(), v.end());
using T = TypeParam;
cudf::test::fixed_width_column_wrapper<T> input{{1, 2, 3}};
auto const v = cudf::test::make_type_param_vector<T>({-1, -2, -3});
cudf::test::fixed_width_column_wrapper<T> expected(v.begin(), v.end());
auto output = cudf::unary_operation(input, cudf::unary_operator::NEGATE);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, output->view());
}
Expand Down Expand Up @@ -251,15 +252,6 @@ using floating_point_type_list = ::testing::Types<float, double>;

TYPED_TEST_SUITE(UnaryMathFloatOpsTest, floating_point_type_list);

TYPED_TEST(UnaryMathFloatOpsTest, SimpleNEGATE)
{
cudf::test::fixed_width_column_wrapper<TypeParam> input{{1.0, 2.0}};
auto const v = cudf::test::make_type_param_vector<TypeParam>({-1.0, -2.0});
cudf::test::fixed_width_column_wrapper<TypeParam> expected(v.begin(), v.end());
auto output = cudf::unary_operation(input, cudf::unary_operator::NEGATE);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, output->view());
}

TYPED_TEST(UnaryMathFloatOpsTest, SimpleSIN)
{
cudf::test::fixed_width_column_wrapper<TypeParam> input{{0.0}};
Expand Down
14 changes: 14 additions & 0 deletions cpp/tests/unary/unary_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,20 @@ struct FixedPointUnaryTests : public cudf::test::BaseFixture {};

TYPED_TEST_SUITE(FixedPointUnaryTests, cudf::test::FixedPointTypes);

TYPED_TEST(FixedPointUnaryTests, FixedPointUnaryNegate)
{
using namespace numeric;
using decimalXX = TypeParam;
using RepType = cudf::device_storage_type_t<decimalXX>;
using fp_wrapper = cudf::test::fixed_point_column_wrapper<RepType>;

auto const input = fp_wrapper{{-1234, -3456, -6789, 1234, 3456, 6789}, scale_type{-3}};
auto const expected = fp_wrapper{{1234, 3456, 6789, -1234, -3456, -6789}, scale_type{-3}};
auto const result = cudf::unary_operation(input, cudf::unary_operator::NEGATE);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view());
}

TYPED_TEST(FixedPointUnaryTests, FixedPointUnaryAbs)
{
using namespace numeric;
Expand Down
13 changes: 13 additions & 0 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.core.buffer import Buffer

_SUPPORTED_UNARY_OPERATIONS: set = {"ABS", "CEIL", "FLOOR", "NEGATE"}


class DecimalBaseColumn(NumericalBaseColumn):
"""Base column for decimal32, decimal64 or decimal128 columns"""
Expand Down Expand Up @@ -226,6 +228,17 @@ def as_numerical_column(
) -> "cudf.core.column.NumericalColumn":
return unary.cast(self, dtype) # type: ignore[return-value]

def unary_operator(self, unaryop: str) -> ColumnBase:
# TODO: Support Callable unary operations via numba
unaryop = unaryop.upper()
if unaryop in _SUPPORTED_UNARY_OPERATIONS:
unaryop = plc.unary.UnaryOperator[unaryop]
return unary.unary_operation(self, unaryop)
else:
raise TypeError(
f"Operation {unaryop} not supported for dtype {self.dtype}."
)


class Decimal32Column(DecimalBaseColumn):
def __init__(
Expand Down
7 changes: 7 additions & 0 deletions python/cudf/cudf/tests/test_unaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import operator
import re
from decimal import Decimal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -135,3 +136,9 @@ def test_series_bool_neg():
sr = Series([True, False, True, None, False, None, True, True])
psr = sr.to_pandas(nullable=True)
assert_eq((-sr).to_pandas(nullable=True), -psr, check_dtype=True)


def test_series_decimal_neg():
sr = Series([Decimal("1.23"), Decimal("4.567")])
psr = sr.to_pandas()
assert_eq((-sr).to_pandas(), -psr, check_dtype=True)

0 comments on commit cef040c

Please sign in to comment.