Skip to content

Commit

Permalink
fix(udf): fix decimal values (#11839)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Dec 1, 2023
1 parent da79ff5 commit 0bd10c4
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 84 deletions.
1 change: 1 addition & 0 deletions ci/scripts/build-other.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cd java
mvn -B package -Dmaven.test.skip=true
mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am
mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test
mvn -B test --pl udf
cd ..

echo "--- Build rust binary for java binding integration test"
Expand Down
6 changes: 6 additions & 0 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]:
return dec


@udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL")
def decimal_add(a: Decimal, b: Decimal) -> Decimal:
    """Return the sum of two DECIMAL arguments as a DECIMAL."""
    total = a + b
    return total


@udf(input_types=["VARCHAR[]", "INT"], result_type="VARCHAR")
def array_access(list: List[str], idx: int) -> Optional[str]:
if idx == 0 or idx > len(list):
Expand Down Expand Up @@ -212,6 +217,7 @@ def return_all_arrays(
server.add_function(split)
server.add_function(extract_tcp_info)
server.add_function(hex_to_dec)
server.add_function(decimal_add)
server.add_function(array_access)
server.add_function(jsonb_access)
server.add_function(jsonb_concat)
Expand Down
20 changes: 16 additions & 4 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ create function split(varchar) returns table (word varchar, length int) as split
statement ok
create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815';

statement ok
create function decimal_add(decimal, decimal) returns decimal as decimal_add using link 'http://localhost:8815';

statement ok
create function array_access(varchar[], int) returns varchar as array_access using link 'http://localhost:8815';

Expand Down Expand Up @@ -72,6 +75,7 @@ query TTTTT rowsort
show functions
----
array_access character varying[], integer character varying (empty) http://localhost:8815
decimal_add numeric, numeric numeric (empty) http://localhost:8815
extract_tcp_info bytea struct<src_ip character varying, dst_ip character varying, src_port smallint, dst_port smallint> (empty) http://localhost:8815
gcd integer, integer integer (empty) http://localhost:8815
gcd integer, integer, integer integer (empty) http://localhost:8815
Expand Down Expand Up @@ -106,6 +110,11 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90
----
233276425899864771438119478

query R
select decimal_add(1.11, 2.22);
----
3.33

query T
select array_access(ARRAY['a', 'b', 'c'], 2);
----
Expand Down Expand Up @@ -142,7 +151,7 @@ select (return_all(
1 ::bigint,
1 ::float4,
1 ::float8,
1234567890123456789012345678 ::decimal,
12345678901234567890.12345678 ::decimal,
date '2023-06-01',
time '01:02:03.456789',
timestamp '2023-06-01 01:02:03.456789',
Expand All @@ -153,7 +162,7 @@ select (return_all(
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)
t 1 1 1 1 1 12345678901234567890.12345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)

query T
select (return_all_arrays(
Expand All @@ -163,7 +172,7 @@ select (return_all_arrays(
array[null, 1 ::bigint],
array[null, 1 ::float4],
array[null, 1 ::float8],
array[null, 1234567890123456789012345678 ::decimal],
array[null, 12345678901234567890.12345678 ::decimal],
array[null, date '2023-06-01'],
array[null, time '01:02:03.456789'],
array[null, timestamp '2023-06-01 01:02:03.456789'],
Expand All @@ -174,7 +183,7 @@ select (return_all_arrays(
array[null, row(1, 2)::struct<f1 int, f2 int>]
)).*;
----
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,12345678901234567890.12345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}

query I
select series(5);
Expand Down Expand Up @@ -315,6 +324,9 @@ drop function split;
statement ok
drop function hex_to_dec;

statement ok
drop function decimal_add;

statement ok
drop function array_access;

Expand Down
7 changes: 7 additions & 0 deletions java/udf-example/src/main/java/com/example/UdfExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static void main(String[] args) throws IOException {
server.addFunction("gcd3", new Gcd3());
server.addFunction("extract_tcp_info", new ExtractTcpInfo());
server.addFunction("hex_to_dec", new HexToDec());
server.addFunction("decimal_add", new DecimalAdd());
server.addFunction("array_access", new ArrayAccess());
server.addFunction("jsonb_access", new JsonbAccess());
server.addFunction("jsonb_concat", new JsonbConcat());
Expand Down Expand Up @@ -126,6 +127,12 @@ public BigDecimal eval(String hex) {
}
}

/** Scalar UDF registered as {@code decimal_add}: adds two decimal values. */
public static class DecimalAdd implements ScalarFunction {
    /**
     * @param a left operand
     * @param b right operand
     * @return the sum {@code a + b}
     */
    public BigDecimal eval(BigDecimal a, BigDecimal b) {
        final BigDecimal sum = a.add(b);
        return sum;
    }
}

public static class ArrayAccess implements ScalarFunction {
public String eval(String[] array, int index) {
return array[index - 1];
Expand Down
6 changes: 5 additions & 1 deletion java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.1] - 2023-11-28
## [0.1.1] - 2023-12-01

### Added

Expand All @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Bump Arrow version to 14.

### Fixed

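- Fix the unconstrained decimal type: `DECIMAL`/`NUMERIC` values are now exchanged as decimal strings in an Arrow `LargeBinary` column instead of `Decimal(38, 0)`, whose scale of 0 could not carry fractional digits.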

## [0.1.0] - 2023-09-01

- Initial release.
24 changes: 14 additions & 10 deletions java/udf/src/main/java/com/risingwave/functions/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ static Field stringToField(String typeStr, String name) {
return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
} else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) {
return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
} else if (typeStr.startsWith("DECIMAL") || typeStr.startsWith("NUMERIC")) {
return Field.nullable(name, new ArrowType.Decimal(38, 0, 128));
} else if (typeStr.equals("DECIMAL") || typeStr.equals("NUMERIC")) {
return Field.nullable(name, new ArrowType.LargeBinary());
} else if (typeStr.equals("DATE")) {
return Field.nullable(name, new ArrowType.Date(DateUnit.DAY));
} else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) {
Expand Down Expand Up @@ -110,7 +110,7 @@ static Field classToField(Class<?> param, DataTypeHint hint, String name) {
} else if (param == Double.class || param == double.class) {
return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
} else if (param == BigDecimal.class) {
return Field.nullable(name, new ArrowType.Decimal(38, 0, 128));
return Field.nullable(name, new ArrowType.LargeBinary());
} else if (param == LocalDate.class) {
return Field.nullable(name, new ArrowType.Date(DateUnit.DAY));
} else if (param == LocalTime.class) {
Expand Down Expand Up @@ -240,12 +240,12 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
vector.set(i, (double) values[i]);
}
}
} else if (fieldVector instanceof DecimalVector) {
var vector = (DecimalVector) fieldVector;
} else if (fieldVector instanceof LargeVarBinaryVector) {
var vector = (LargeVarBinaryVector) fieldVector;
vector.allocateNew(values.length);
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
vector.set(i, (BigDecimal) values[i]);
vector.set(i, ((BigDecimal) values[i]).toString().getBytes());
}
}
} else if (fieldVector instanceof DateDayVector) {
Expand Down Expand Up @@ -329,9 +329,9 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
} else if (vector.getDataVector() instanceof Float8Vector) {
TypeUtils.<Float8Vector, Double>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof DecimalVector) {
TypeUtils.<DecimalVector, BigDecimal>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof LargeVarBinaryVector) {
TypeUtils.<LargeVarBinaryVector, BigDecimal>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val.toString().getBytes()));
} else if (vector.getDataVector() instanceof DateDayVector) {
TypeUtils.<DateDayVector, LocalDate>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, (int) val.toEpochDay()));
Expand Down Expand Up @@ -476,6 +476,10 @@ static Function<Object, Object> processFunc0(Field field, Class<?> targetClass)
} else if (field.getType() instanceof ArrowType.LargeUtf8 && targetClass == String.class) {
// object is org.apache.arrow.vector.util.Text
return obj -> obj.toString();
} else if (field.getType() instanceof ArrowType.LargeBinary
&& targetClass == BigDecimal.class) {
// object is byte[]
return obj -> new BigDecimal(new String((byte[]) obj));
} else if (field.getType() instanceof ArrowType.Date && targetClass == LocalDate.class) {
// object is Integer
return obj -> LocalDate.ofEpochDay((int) obj);
Expand Down Expand Up @@ -504,7 +508,7 @@ static Function<Object, Object> processFunc0(Field field, Class<?> targetClass)
} else if (subfield.getType()
.equals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Double[]::new);
} else if (subfield.getType() instanceof ArrowType.Decimal) {
} else if (subfield.getType() instanceof ArrowType.LargeBinary) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(BigDecimal[]::new);
} else if (subfield.getType() instanceof ArrowType.Date) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(LocalDate[]::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ public void all_types() throws Exception {
c5.set(0, 1);
c5.setValueCount(2);

var c6 = new DecimalVector("", allocator, 38, 0);
var c6 = new LargeVarBinaryVector("", allocator);
c6.allocateNew(2);
c6.set(0, BigDecimal.valueOf(10).pow(37));
c6.set(0, "1.234".getBytes());
c6.setValueCount(2);

var c7 = new DateDayVector("", allocator);
Expand Down Expand Up @@ -255,7 +255,7 @@ public void all_types() throws Exception {
var output = stream.getRoot();
assertTrue(stream.next());
assertEquals(
"{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}",
"{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}",
output.contentToTSVString().trim());
}
}
Expand Down
91 changes: 33 additions & 58 deletions src/common/src/array/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ converts_generic! {
{ arrow_array::Float64Array, Float64, ArrayImpl::Float64 },
{ arrow_array::StringArray, Utf8, ArrayImpl::Utf8 },
{ arrow_array::BooleanArray, Boolean, ArrayImpl::Bool },
{ arrow_array::Decimal128Array, Decimal128(_, _), ArrayImpl::Decimal },
// Arrow has no data type that can represent an unconstrained numeric (`DECIMAL` in RisingWave
// and Postgres), so we use the special type `LargeBinary` for it.
// Each value stored in the array is the string representation of the decimal, e.g. b"1.234", b"+inf".
{ arrow_array::LargeBinaryArray, LargeBinary, ArrayImpl::Decimal },
{ arrow_array::Decimal256Array, Decimal256(_, _), ArrayImpl::Int256 },
{ arrow_array::Date32Array, Date32, ArrayImpl::Date },
{ arrow_array::TimestampMicrosecondArray, Timestamp(Microsecond, None), ArrayImpl::Timestamp },
Expand All @@ -169,7 +172,7 @@ impl From<&arrow_schema::DataType> for DataType {
Int64 => Self::Int64,
Float32 => Self::Float32,
Float64 => Self::Float64,
Decimal128(_, _) => Self::Decimal,
LargeBinary => Self::Decimal,
Decimal256(_, _) => Self::Int256,
Date32 => Self::Date,
Time64(Microsecond) => Self::Time,
Expand Down Expand Up @@ -237,7 +240,7 @@ impl TryFrom<&DataType> for arrow_schema::DataType {
DataType::Varchar => Ok(Self::Utf8),
DataType::Jsonb => Ok(Self::LargeUtf8),
DataType::Bytea => Ok(Self::Binary),
DataType::Decimal => Ok(Self::Decimal128(38, 0)), // arrow precision can not be 0
DataType::Decimal => Ok(Self::LargeBinary),
DataType::Struct(struct_type) => Ok(Self::Struct(
struct_type
.iter()
Expand Down Expand Up @@ -462,53 +465,33 @@ impl FromIntoArrow for Interval {
}
}

// RisingWave Decimal type is self-contained, but Arrow is not.
// In Arrow DecimalArray, the scale is stored in data type as metadata, and the mantissa is stored
// as i128 in the array.
impl From<&DecimalArray> for arrow_array::Decimal128Array {
impl From<&DecimalArray> for arrow_array::LargeBinaryArray {
fn from(array: &DecimalArray) -> Self {
let max_scale = array
.iter()
.filter_map(|o| o.map(|v| v.scale().unwrap_or(0)))
.max()
.unwrap_or(0) as u32;
let mut builder = arrow_array::builder::Decimal128Builder::with_capacity(array.len())
.with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8));
let mut builder =
arrow_array::builder::LargeBinaryBuilder::with_capacity(array.len(), array.len() * 8);
for value in array.iter() {
builder.append_option(value.map(|d| decimal_to_i128(d, max_scale)));
builder.append_option(value.map(|d| d.to_string()));
}
builder.finish()
}
}

fn decimal_to_i128(value: Decimal, scale: u32) -> i128 {
match value {
Decimal::Normalized(mut d) => {
d.rescale(scale);
d.mantissa()
}
Decimal::NaN => i128::MIN + 1,
Decimal::PositiveInf => i128::MAX,
Decimal::NegativeInf => i128::MIN,
}
}
impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray {
type Error = ArrayError;

impl From<&arrow_array::Decimal128Array> for DecimalArray {
fn from(array: &arrow_array::Decimal128Array) -> Self {
assert!(array.scale() >= 0, "todo: support negative scale");
let from_arrow = |value| {
const NAN: i128 = i128::MIN + 1;
match value {
NAN => Decimal::NaN,
i128::MAX => Decimal::PositiveInf,
i128::MIN => Decimal::NegativeInf,
_ => Decimal::Normalized(rust_decimal::Decimal::from_i128_with_scale(
value,
array.scale() as u32,
)),
}
};
array.iter().map(|o| o.map(from_arrow)).collect()
fn try_from(array: &arrow_array::LargeBinaryArray) -> Result<Self, Self::Error> {
array
.iter()
.map(|o| {
o.map(|s| {
let s = std::str::from_utf8(s)
.map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}")))?;
s.parse()
.map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}")))
})
.transpose()
})
.try_collect()
}
}

Expand Down Expand Up @@ -631,20 +614,12 @@ impl TryFrom<&ListArray> for arrow_array::ListArray {
b.append_option(v)
})
}
ArrayImpl::Decimal(a) => {
let max_scale = a
.iter()
.filter_map(|o| o.map(|v| v.scale().unwrap_or(0)))
.max()
.unwrap_or(0) as u32;
build(
array,
a,
Decimal128Builder::with_capacity(a.len())
.with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8)),
|b, v| b.append_option(v.map(|d| decimal_to_i128(d, max_scale))),
)
}
ArrayImpl::Decimal(a) => build(
array,
a,
LargeBinaryBuilder::with_capacity(a.len(), a.len() * 8),
|b, v| b.append_option(v.map(|d| d.to_string())),
),
ArrayImpl::Interval(a) => build(
array,
a,
Expand Down Expand Up @@ -842,8 +817,8 @@ mod tests {
Some(Decimal::Normalized("123.4".parse().unwrap())),
Some(Decimal::Normalized("123.456".parse().unwrap())),
]);
let arrow = arrow_array::Decimal128Array::from(&array);
assert_eq!(DecimalArray::from(&arrow), array);
let arrow = arrow_array::LargeBinaryArray::from(&array);
assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array);
}

#[test]
Expand Down
6 changes: 6 additions & 0 deletions src/udf/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.0] - 2023-12-01

### Fixed

- Fix handling of the unconstrained `DECIMAL` type so that values with fractional digits round-trip through UDFs.

## [0.0.12] - 2023-11-28

### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/udf/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "risingwave"
version = "0.0.12"
version = "0.1.0"
authors = [{ name = "RisingWave Labs" }]
description = "RisingWave Python API"
readme = "README.md"
Expand Down
Loading

0 comments on commit 0bd10c4

Please sign in to comment.