Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(udf): fix decimal values #11839

Merged
merged 14 commits into from
Dec 1, 2023
Merged
1 change: 1 addition & 0 deletions ci/scripts/build-other.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cd java
mvn -B package -Dmaven.test.skip=true
mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am
mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test
mvn -B test --pl udf
cd ..

echo "--- Build rust binary for java binding integration test"
Expand Down
6 changes: 6 additions & 0 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]:
return dec


@udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL")
def decimal_add(a: Decimal, b: Decimal) -> Decimal:
return a + b


@udf(input_types=["VARCHAR[]", "INT"], result_type="VARCHAR")
def array_access(list: List[str], idx: int) -> Optional[str]:
if idx == 0 or idx > len(list):
Expand Down Expand Up @@ -212,6 +217,7 @@ def return_all_arrays(
server.add_function(split)
server.add_function(extract_tcp_info)
server.add_function(hex_to_dec)
server.add_function(decimal_add)
server.add_function(array_access)
server.add_function(jsonb_access)
server.add_function(jsonb_concat)
Expand Down
20 changes: 16 additions & 4 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ create function split(varchar) returns table (word varchar, length int) as split
statement ok
create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815';

statement ok
create function decimal_add(decimal, decimal) returns decimal as decimal_add using link 'http://localhost:8815';

statement ok
create function array_access(varchar[], int) returns varchar as array_access using link 'http://localhost:8815';

Expand Down Expand Up @@ -72,6 +75,7 @@ query TTTTT rowsort
show functions
----
array_access character varying[], integer character varying (empty) http://localhost:8815
decimal_add numeric, numeric numeric (empty) http://localhost:8815
extract_tcp_info bytea struct<src_ip character varying, dst_ip character varying, src_port smallint, dst_port smallint> (empty) http://localhost:8815
gcd integer, integer integer (empty) http://localhost:8815
gcd integer, integer, integer integer (empty) http://localhost:8815
Expand Down Expand Up @@ -106,6 +110,11 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90
----
233276425899864771438119478

query R
select decimal_add(1.11, 2.22);
----
3.33

query T
select array_access(ARRAY['a', 'b', 'c'], 2);
----
Expand Down Expand Up @@ -142,7 +151,7 @@ select (return_all(
1 ::bigint,
1 ::float4,
1 ::float8,
1234567890123456789012345678 ::decimal,
12345678901234567890.12345678 ::decimal,
date '2023-06-01',
time '01:02:03.456789',
timestamp '2023-06-01 01:02:03.456789',
Expand All @@ -153,7 +162,7 @@ select (return_all(
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)
t 1 1 1 1 1 12345678901234567890.12345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)

query T
select (return_all_arrays(
Expand All @@ -163,7 +172,7 @@ select (return_all_arrays(
array[null, 1 ::bigint],
array[null, 1 ::float4],
array[null, 1 ::float8],
array[null, 1234567890123456789012345678 ::decimal],
array[null, 12345678901234567890.12345678 ::decimal],
array[null, date '2023-06-01'],
array[null, time '01:02:03.456789'],
array[null, timestamp '2023-06-01 01:02:03.456789'],
Expand All @@ -174,7 +183,7 @@ select (return_all_arrays(
array[null, row(1, 2)::struct<f1 int, f2 int>]
)).*;
----
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,12345678901234567890.12345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}

query I
select series(5);
Expand Down Expand Up @@ -315,6 +324,9 @@ drop function split;
statement ok
drop function hex_to_dec;

statement ok
drop function decimal_add;

statement ok
drop function array_access;

Expand Down
7 changes: 7 additions & 0 deletions java/udf-example/src/main/java/com/example/UdfExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static void main(String[] args) throws IOException {
server.addFunction("gcd3", new Gcd3());
server.addFunction("extract_tcp_info", new ExtractTcpInfo());
server.addFunction("hex_to_dec", new HexToDec());
server.addFunction("decimal_add", new DecimalAdd());
server.addFunction("array_access", new ArrayAccess());
server.addFunction("jsonb_access", new JsonbAccess());
server.addFunction("jsonb_concat", new JsonbConcat());
Expand Down Expand Up @@ -126,6 +127,12 @@ public BigDecimal eval(String hex) {
}
}

public static class DecimalAdd implements ScalarFunction {
public BigDecimal eval(BigDecimal a, BigDecimal b) {
return a.add(b);
}
}

public static class ArrayAccess implements ScalarFunction {
public String eval(String[] array, int index) {
return array[index - 1];
Expand Down
24 changes: 14 additions & 10 deletions java/udf/src/main/java/com/risingwave/functions/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ static Field stringToField(String typeStr, String name) {
return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
} else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) {
return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
} else if (typeStr.startsWith("DECIMAL") || typeStr.startsWith("NUMERIC")) {
return Field.nullable(name, new ArrowType.Decimal(38, 0, 128));
} else if (typeStr.equals("DECIMAL") || typeStr.equals("NUMERIC")) {
return Field.nullable(name, new ArrowType.LargeBinary());
} else if (typeStr.equals("DATE")) {
return Field.nullable(name, new ArrowType.Date(DateUnit.DAY));
} else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) {
Expand Down Expand Up @@ -110,7 +110,7 @@ static Field classToField(Class<?> param, DataTypeHint hint, String name) {
} else if (param == Double.class || param == double.class) {
return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
} else if (param == BigDecimal.class) {
return Field.nullable(name, new ArrowType.Decimal(38, 0, 128));
return Field.nullable(name, new ArrowType.LargeBinary());
} else if (param == LocalDate.class) {
return Field.nullable(name, new ArrowType.Date(DateUnit.DAY));
} else if (param == LocalTime.class) {
Expand Down Expand Up @@ -240,12 +240,12 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
vector.set(i, (double) values[i]);
}
}
} else if (fieldVector instanceof DecimalVector) {
var vector = (DecimalVector) fieldVector;
} else if (fieldVector instanceof LargeVarBinaryVector) {
var vector = (LargeVarBinaryVector) fieldVector;
vector.allocateNew(values.length);
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
vector.set(i, (BigDecimal) values[i]);
vector.set(i, ((BigDecimal) values[i]).toString().getBytes());
}
}
} else if (fieldVector instanceof DateDayVector) {
Expand Down Expand Up @@ -329,9 +329,9 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
} else if (vector.getDataVector() instanceof Float8Vector) {
TypeUtils.<Float8Vector, Double>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof DecimalVector) {
TypeUtils.<DecimalVector, BigDecimal>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof LargeVarBinaryVector) {
TypeUtils.<LargeVarBinaryVector, BigDecimal>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val.toString().getBytes()));
} else if (vector.getDataVector() instanceof DateDayVector) {
TypeUtils.<DateDayVector, LocalDate>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, (int) val.toEpochDay()));
Expand Down Expand Up @@ -476,6 +476,10 @@ static Function<Object, Object> processFunc0(Field field, Class<?> targetClass)
} else if (field.getType() instanceof ArrowType.LargeUtf8 && targetClass == String.class) {
// object is org.apache.arrow.vector.util.Text
return obj -> obj.toString();
} else if (field.getType() instanceof ArrowType.LargeBinary
&& targetClass == BigDecimal.class) {
// object is byte[]
return obj -> new BigDecimal(new String((byte[]) obj));
} else if (field.getType() instanceof ArrowType.Date && targetClass == LocalDate.class) {
// object is Integer
return obj -> LocalDate.ofEpochDay((int) obj);
Expand Down Expand Up @@ -504,7 +508,7 @@ static Function<Object, Object> processFunc0(Field field, Class<?> targetClass)
} else if (subfield.getType()
.equals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Double[]::new);
} else if (subfield.getType() instanceof ArrowType.Decimal) {
} else if (subfield.getType() instanceof ArrowType.LargeBinary) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(BigDecimal[]::new);
} else if (subfield.getType() instanceof ArrowType.Date) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(LocalDate[]::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ public void all_types() throws Exception {
c5.set(0, 1);
c5.setValueCount(2);

var c6 = new DecimalVector("", allocator, 38, 0);
var c6 = new LargeVarBinaryVector("", allocator);
c6.allocateNew(2);
c6.set(0, BigDecimal.valueOf(10).pow(37));
c6.set(0, "1.234".getBytes());
c6.setValueCount(2);

var c7 = new DateDayVector("", allocator);
Expand Down Expand Up @@ -255,7 +255,7 @@ public void all_types() throws Exception {
var output = stream.getRoot();
assertTrue(stream.next());
assertEquals(
"{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}",
"{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}",
output.contentToTSVString().trim());
}
}
Expand Down
91 changes: 33 additions & 58 deletions src/common/src/array/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ converts_generic! {
{ arrow_array::Float64Array, Float64, ArrayImpl::Float64 },
{ arrow_array::StringArray, Utf8, ArrayImpl::Utf8 },
{ arrow_array::BooleanArray, Boolean, ArrayImpl::Bool },
{ arrow_array::Decimal128Array, Decimal128(_, _), ArrayImpl::Decimal },
// Arrow doesn't have a data type to represent unconstrained numeric (`DECIMAL` in RisingWave and
// Postgres). So we pick a special type `LargeBinary` for it.
// Values stored in the array are the string representation of the decimal. e.g. b"1.234", b"+inf"
{ arrow_array::LargeBinaryArray, LargeBinary, ArrayImpl::Decimal },
{ arrow_array::Decimal256Array, Decimal256(_, _), ArrayImpl::Int256 },
{ arrow_array::Date32Array, Date32, ArrayImpl::Date },
{ arrow_array::TimestampMicrosecondArray, Timestamp(Microsecond, None), ArrayImpl::Timestamp },
Expand All @@ -169,7 +172,7 @@ impl From<&arrow_schema::DataType> for DataType {
Int64 => Self::Int64,
Float32 => Self::Float32,
Float64 => Self::Float64,
Decimal128(_, _) => Self::Decimal,
LargeBinary => Self::Decimal,
Decimal256(_, _) => Self::Int256,
Date32 => Self::Date,
Time64(Microsecond) => Self::Time,
Expand Down Expand Up @@ -237,7 +240,7 @@ impl TryFrom<&DataType> for arrow_schema::DataType {
DataType::Varchar => Ok(Self::Utf8),
DataType::Jsonb => Ok(Self::LargeUtf8),
DataType::Bytea => Ok(Self::Binary),
DataType::Decimal => Ok(Self::Decimal128(38, 0)), // arrow precision can not be 0
DataType::Decimal => Ok(Self::LargeBinary),
DataType::Struct(struct_type) => Ok(Self::Struct(
struct_type
.iter()
Expand Down Expand Up @@ -462,53 +465,33 @@ impl FromIntoArrow for Interval {
}
}

// RisingWave Decimal type is self-contained, but Arrow is not.
// In Arrow DecimalArray, the scale is stored in data type as metadata, and the mantissa is stored
// as i128 in the array.
impl From<&DecimalArray> for arrow_array::Decimal128Array {
impl From<&DecimalArray> for arrow_array::LargeBinaryArray {
fn from(array: &DecimalArray) -> Self {
let max_scale = array
.iter()
.filter_map(|o| o.map(|v| v.scale().unwrap_or(0)))
.max()
.unwrap_or(0) as u32;
let mut builder = arrow_array::builder::Decimal128Builder::with_capacity(array.len())
.with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8));
let mut builder =
arrow_array::builder::LargeBinaryBuilder::with_capacity(array.len(), array.len() * 8);
for value in array.iter() {
builder.append_option(value.map(|d| decimal_to_i128(d, max_scale)));
builder.append_option(value.map(|d| d.to_string()));
}
builder.finish()
}
}

fn decimal_to_i128(value: Decimal, scale: u32) -> i128 {
match value {
Decimal::Normalized(mut d) => {
d.rescale(scale);
d.mantissa()
}
Decimal::NaN => i128::MIN + 1,
Decimal::PositiveInf => i128::MAX,
Decimal::NegativeInf => i128::MIN,
}
}
impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray {
type Error = ArrayError;

impl From<&arrow_array::Decimal128Array> for DecimalArray {
fn from(array: &arrow_array::Decimal128Array) -> Self {
assert!(array.scale() >= 0, "todo: support negative scale");
let from_arrow = |value| {
const NAN: i128 = i128::MIN + 1;
match value {
NAN => Decimal::NaN,
i128::MAX => Decimal::PositiveInf,
i128::MIN => Decimal::NegativeInf,
_ => Decimal::Normalized(rust_decimal::Decimal::from_i128_with_scale(
value,
array.scale() as u32,
)),
}
};
array.iter().map(|o| o.map(from_arrow)).collect()
fn try_from(array: &arrow_array::LargeBinaryArray) -> Result<Self, Self::Error> {
array
.iter()
.map(|o| {
o.map(|s| {
let s = std::str::from_utf8(s)
.map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}")))?;
s.parse()
.map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}")))
})
.transpose()
})
.try_collect()
}
}

Expand Down Expand Up @@ -631,20 +614,12 @@ impl TryFrom<&ListArray> for arrow_array::ListArray {
b.append_option(v)
})
}
ArrayImpl::Decimal(a) => {
let max_scale = a
.iter()
.filter_map(|o| o.map(|v| v.scale().unwrap_or(0)))
.max()
.unwrap_or(0) as u32;
build(
array,
a,
Decimal128Builder::with_capacity(a.len())
.with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8)),
|b, v| b.append_option(v.map(|d| decimal_to_i128(d, max_scale))),
)
}
ArrayImpl::Decimal(a) => build(
array,
a,
LargeBinaryBuilder::with_capacity(a.len(), a.len() * 8),
|b, v| b.append_option(v.map(|d| d.to_string())),
),
ArrayImpl::Interval(a) => build(
array,
a,
Expand Down Expand Up @@ -842,8 +817,8 @@ mod tests {
Some(Decimal::Normalized("123.4".parse().unwrap())),
Some(Decimal::Normalized("123.456".parse().unwrap())),
]);
let arrow = arrow_array::Decimal128Array::from(&array);
assert_eq!(DecimalArray::from(&arrow), array);
let arrow = arrow_array::LargeBinaryArray::from(&array);
assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array);
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/udf/python/risingwave/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_all_types():
pa.array([1], type=pa.int64()),
pa.array([1], type=pa.float32()),
pa.array([1], type=pa.float64()),
pa.array([10**37], type=pa.decimal128(38)),
pa.array(["12345678901234567890.1234567890"], type=pa.large_binary()),
pa.array([datetime.date(2023, 6, 1)], type=pa.date32()),
pa.array([datetime.time(1, 2, 3, 456789)], type=pa.time64("us")),
pa.array(
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_all_types():
1,
1.0,
1.0,
10**37,
b"12345678901234567890.1234567890",
datetime.date(2023, 6, 1),
datetime.time(1, 2, 3, 456789),
datetime.datetime(2023, 6, 1, 1, 2, 3, 456789),
Expand Down
Loading
Loading