From 330a1ad949479e003c5695a0bdfbad785d5754f8 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 22 Aug 2023 21:27:33 +0800 Subject: [PATCH 01/11] use large binary type to represent unconstrained numeric in arrow Signed-off-by: Runji Wang --- src/common/src/array/arrow.rs | 91 ++++++++++++-------------------- src/udf/python/risingwave/udf.py | 27 ++++++++-- 2 files changed, 55 insertions(+), 63 deletions(-) diff --git a/src/common/src/array/arrow.rs b/src/common/src/array/arrow.rs index 9aaec9acbf1fe..935aeccff3e35 100644 --- a/src/common/src/array/arrow.rs +++ b/src/common/src/array/arrow.rs @@ -108,7 +108,10 @@ converts_generic! { { arrow_array::Float64Array, Float64, ArrayImpl::Float64 }, { arrow_array::StringArray, Utf8, ArrayImpl::Utf8 }, { arrow_array::BooleanArray, Boolean, ArrayImpl::Bool }, - { arrow_array::Decimal128Array, Decimal128(_, _), ArrayImpl::Decimal }, + // Arrow doesn't have a data type to represent unconstrained numeric (`DECIMAL` in RisingWave and + // Postgres). So we pick a special type `LargeBinary` for it. + // Values stored in the array are the string representation of the decimal. e.g. b"1.234", b"+inf" + { arrow_array::LargeBinaryArray, LargeBinary, ArrayImpl::Decimal }, { arrow_array::Decimal256Array, Decimal256(_, _), ArrayImpl::Int256 }, { arrow_array::Date32Array, Date32, ArrayImpl::Date }, { arrow_array::TimestampMicrosecondArray, Timestamp(Microsecond, None), ArrayImpl::Timestamp }, @@ -134,7 +137,7 @@ impl From<&arrow_schema::DataType> for DataType { Int64 => Self::Int64, Float32 => Self::Float32, Float64 => Self::Float64, - Decimal128(_, _) => Self::Decimal, + LargeBinary => Self::Decimal, Decimal256(_, _) => Self::Int256, Date32 => Self::Date, Time64(Microsecond) => Self::Time, @@ -191,7 +194,7 @@ impl TryFrom<&DataType> for arrow_schema::DataType { DataType::Varchar => Ok(Self::Utf8), DataType::Jsonb => Ok(Self::LargeUtf8), DataType::Bytea => Ok(Self::Binary), - DataType::Decimal => Ok(Self::Decimal128(38, 0)), // arrow precision can not be 0 + DataType::Decimal => Ok(Self::LargeBinary), DataType::Struct(struct_type) => Ok(Self::Struct( struct_type .iter() @@ -408,53 +411,33 @@ impl FromIntoArrow for Interval { } } -// RisingWave Decimal type is self-contained, but Arrow is not. -// In Arrow DecimalArray, the scale is stored in data type as metadata, and the mantissa is stored -// as i128 in the array. -impl From<&DecimalArray> for arrow_array::Decimal128Array { +impl From<&DecimalArray> for arrow_array::LargeBinaryArray { fn from(array: &DecimalArray) -> Self { - let max_scale = array - .iter() - .filter_map(|o| o.map(|v| v.scale().unwrap_or(0))) - .max() - .unwrap_or(0) as u32; - let mut builder = arrow_array::builder::Decimal128Builder::with_capacity(array.len()) - .with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8)); + let mut builder = + arrow_array::builder::LargeBinaryBuilder::with_capacity(array.len(), array.len() * 8); for value in array.iter() { - builder.append_option(value.map(|d| decimal_to_i128(d, max_scale))); + builder.append_option(value.map(|d| d.to_string())); } builder.finish() } } -fn decimal_to_i128(value: Decimal, scale: u32) -> i128 { - match value { - Decimal::Normalized(mut d) => { - d.rescale(scale); - d.mantissa() - } - Decimal::NaN => i128::MIN + 1, - Decimal::PositiveInf => i128::MAX, - Decimal::NegativeInf => i128::MIN, - } -} +impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray { + type Error = ArrayError; -impl From<&arrow_array::Decimal128Array> for DecimalArray { - fn from(array: &arrow_array::Decimal128Array) -> Self { - assert!(array.scale() >= 0, "todo: support negative scale"); - let from_arrow = |value| { - const NAN: i128 = i128::MIN + 1; - match value { - NAN => Decimal::NaN, - i128::MAX => Decimal::PositiveInf, - i128::MIN => Decimal::NegativeInf, - _ => Decimal::Normalized(rust_decimal::Decimal::from_i128_with_scale( - value, - array.scale() as u32, - )), - } - }; - array.iter().map(|o| o.map(from_arrow)).collect() + fn try_from(array: &arrow_array::LargeBinaryArray) -> Result { + array + .iter() + .map(|o| { + o.map(|s| { + let s = std::str::from_utf8(s) + .map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}")))?; + s.parse() + .map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}"))) + }) + .transpose() + }) + .try_collect() } } @@ -575,20 +558,12 @@ impl From<&ListArray> for arrow_array::ListArray { b.append_option(v) }) } - ArrayImpl::Decimal(a) => { - let max_scale = a - .iter() - .filter_map(|o| o.map(|v| v.scale().unwrap_or(0))) - .max() - .unwrap_or(0) as u32; - build( - array, - a, - Decimal128Builder::with_capacity(a.len()) - .with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8)), - |b, v| b.append_option(v.map(|d| decimal_to_i128(d, max_scale))), - ) - } + ArrayImpl::Decimal(a) => build( + array, + a, + LargeBinaryBuilder::with_capacity(a.len(), a.len() * 8), + |b, v| b.append_option(v.map(|d| d.to_string())), + ), ArrayImpl::Interval(a) => build( array, a, @@ -772,8 +747,8 @@ mod tests { Some(Decimal::Normalized("123.4".parse().unwrap())), Some(Decimal::Normalized("123.456".parse().unwrap())), ]); - let arrow = arrow_array::Decimal128Array::from(&array); - assert_eq!(DecimalArray::from(&arrow), array); + let arrow = arrow_array::LargeBinaryArray::from(&array); + assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array); } #[test] diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index 03dbe1a4224a3..3e6ed319b9e73 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -7,6 +7,7 @@ import json from concurrent.futures import ThreadPoolExecutor import concurrent +from decimal import Decimal class UserDefinedFunction: @@ -81,6 +82,7 @@ def _process_func(type: pa.DataType, output: bool) -> Callable: if pa.types.is_list(type): func = _process_func(type.value_type, output) return lambda array: [(func(v) if v is not None else None) for v in array] + if pa.types.is_struct(type): funcs = [_process_func(field.type, output) for field in type] if output: @@ -94,11 +96,19 @@ def _process_func(type: pa.DataType, output: bool) -> Callable: (func(v) if v is not None else None) for v, func in zip(map.values(), funcs) ) - if pa.types.is_large_string(type): + + if type.equals(JSONB): if output: return lambda v: json.dumps(v) else: return lambda v: json.loads(v) + + if type.equals(UNCONSTRAINED_DECIMAL): + if output: + return lambda v: str(v).encode("utf-8") + else: + return lambda v: Decimal(v.decode("utf-8")) + return lambda v: v @@ -396,6 +406,11 @@ def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType: return t +# we use `large_binary` to represent unconstrained decimal type +UNCONSTRAINED_DECIMAL = pa.large_binary() +JSONB = pa.large_string() + + def _string_to_data_type(type_str: str): """ Convert a SQL data type string to `pyarrow.DataType`. @@ -417,7 +432,7 @@ def _string_to_data_type(type_str: str): return pa.float64() elif type_str.startswith("DECIMAL") or type_str.startswith("NUMERIC"): if type_str == "DECIMAL" or type_str == "NUMERIC": - return pa.decimal128(38) + return UNCONSTRAINED_DECIMAL rest = type_str[8:-1] # remove "DECIMAL(" and ")" if "," in rest: precision, scale = rest.split(",") @@ -435,7 +450,7 @@ def _string_to_data_type(type_str: str): elif type_str in ("VARCHAR"): return pa.string() elif type_str in ("JSONB"): - return pa.large_string() + return JSONB elif type_str in ("BYTEA"): return pa.binary() elif type_str.startswith("STRUCT"): @@ -468,8 +483,10 @@ def _data_type_to_string(t: pa.DataType) -> str: return "FLOAT4" elif t.equals(pa.float64()): return "FLOAT8" - elif t.equals(pa.decimal128(38)): + elif t.equals(UNCONSTRAINED_DECIMAL): return "DECIMAL" + elif pa.types.is_decimal(t): + return f"DECIMAL({t.precision},{t.scale})" elif t.equals(pa.date32()): return "DATE" elif t.equals(pa.time64("us")): @@ -480,7 +497,7 @@ def _data_type_to_string(t: pa.DataType) -> str: return "INTERVAL" elif t.equals(pa.string()): return "VARCHAR" - elif t.equals(pa.large_string()): + elif t.equals(JSONB): return "JSONB" elif t.equals(pa.binary()): return "BYTEA" From 7b1eec613d7583e51ef98688e083cbf6764f8851 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 22 Aug 2023 21:31:29 +0800 Subject: [PATCH 02/11] add decimal_add for python udf Signed-off-by: Runji Wang --- e2e_test/udf/test.py | 6 ++++++ e2e_test/udf/udf.slt | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index ed1f49e7d4dc5..5a5f2dfa05650 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -63,6 +63,11 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]: return dec +@udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL") +def decimal_add(a: Decimal, b: Decimal) -> Decimal: + return a + b + + @udf(input_types=["VARCHAR[]", "INT"], result_type="VARCHAR") def array_access(list: List[str], idx: int) -> Optional[str]: if idx == 0 or idx > len(list): @@ -184,6 +189,7 @@ def return_all_arrays( server.add_function(split) server.add_function(extract_tcp_info) server.add_function(hex_to_dec) + server.add_function(decimal_add) server.add_function(array_access) server.add_function(jsonb_access) server.add_function(jsonb_concat) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index 33579a825832e..0ab50993d94be 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -42,6 +42,9 @@ create function split(varchar) returns table (word varchar, length int) as split statement ok create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815'; +statement ok +create function decimal_add(decimal, decimal) returns decimal as decimal_add using link 'http://localhost:8815'; + statement ok create function array_access(varchar[], int) returns varchar as array_access using link 'http://localhost:8815'; @@ -72,6 +75,7 @@ query TTTTT rowsort show functions ---- array_access varchar[], integer varchar (empty) http://localhost:8815 +decimal_add numeric, numeric numeric (empty) http://localhost:8815 extract_tcp_info bytea struct (empty) http://localhost:8815 gcd integer, integer integer (empty) http://localhost:8815 gcd integer, integer, integer integer (empty) http://localhost:8815 @@ -106,6 +110,11 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90 ---- 233276425899864771438119478 +query R +select decimal_add(1.11, 2.22); +---- +3.33 + query T select array_access(ARRAY['a', 'b', 'c'], 2); ---- @@ -142,7 +151,7 @@ select return_all( 1 ::bigint, 1 ::float4, 1 ::float8, - 1234567890123456789012345678 ::decimal, + 12345678901234567890.12345678 ::decimal, date '2023-06-01', time '01:02:03.456789', timestamp '2023-06-01 01:02:03.456789', @@ -152,7 +161,7 @@ select return_all( '{"key":1}'::jsonb ); ---- -(t,1,1,1,1,1,1234567890123456789012345678,2023-06-01,01:02:03.456789,2023-06-01 01:02:03.456789,1 mon 2 days 00:00:03,string,\x6279746573,{"key": 1}) +(t,1,1,1,1,1,12345678901234567890.12345678,2023-06-01,01:02:03.456789,2023-06-01 01:02:03.456789,1 mon 2 days 00:00:03,string,\x6279746573,{"key": 1}) query T select return_all_arrays( @@ -162,7 +171,7 @@ select return_all_arrays( array[null, 1 ::bigint], array[null, 1 ::float4], array[null, 1 ::float8], - array[null, 1234567890123456789012345678 ::decimal], + array[null, 12345678901234567890.12345678 ::decimal], array[null, date '2023-06-01'], array[null, time '01:02:03.456789'], array[null, timestamp '2023-06-01 01:02:03.456789'], @@ -172,7 +181,7 @@ select return_all_arrays( array[null, '{"key":1}'::jsonb] ); ---- -({NULL,t},{NULL,1},{NULL,1},{NULL,1},{NULL,1},{NULL,1},{NULL,1234567890123456789012345678},{NULL,2023-06-01},{NULL,01:02:03.456789},{NULL,"2023-06-01 01:02:03.456789"},{NULL,"1 mon 2 days 00:00:03"},{NULL,string},{NULL,"\\x6279746573"},{NULL,"{\"key\": 1}"}) +({NULL,t},{NULL,1},{NULL,1},{NULL,1},{NULL,1},{NULL,1},{NULL,12345678901234567890.12345678},{NULL,2023-06-01},{NULL,01:02:03.456789},{NULL,"2023-06-01 01:02:03.456789"},{NULL,"1 mon 2 days 00:00:03"},{NULL,string},{NULL,"\\x6279746573"},{NULL,"{\"key\": 1}"}) query I select series(5); From 16d4005a915b322b5e1b4ef5f5da35808023159f Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 22 Aug 2023 23:13:56 +0800 Subject: [PATCH 03/11] detach java udf and udf-example from parent Signed-off-by: Runji Wang --- java/pom.xml | 2 -- java/udf-example/pom.xml | 14 ++++---------- java/udf/pom.xml | 22 ++++++++++++++-------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/java/pom.xml b/java/pom.xml index 401db2aed123a..41d5776996b88 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -8,8 +8,6 @@ 1.0-SNAPSHOT proto - udf - udf-example java-binding common-utils java-binding-integration-test diff --git a/java/udf-example/pom.xml b/java/udf-example/pom.xml index 781d89db7bbe5..e3309994b7f1d 100644 --- a/java/udf-example/pom.xml +++ b/java/udf-example/pom.xml @@ -1,16 +1,10 @@ - 4.0.0 - - java-parent - com.risingwave.java - 1.0-SNAPSHOT - ../pom.xml - - com.example udf-example 1.0-SNAPSHOT @@ -37,7 +31,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M6 + 3.0.0 --add-opens=java.base/java.nio=ALL-UNNAMED @@ -71,4 +65,4 @@ - + \ No newline at end of file diff --git a/java/udf/pom.xml b/java/udf/pom.xml index c589136b8b302..690e1f3d575a5 100644 --- a/java/udf/pom.xml +++ b/java/udf/pom.xml @@ -1,16 +1,12 @@ - 4.0.0 com.risingwave.java risingwave-udf jar 0.0.1 - - java-parent - com.risingwave.java - 1.0-SNAPSHOT - ../pom.xml - + risingwave-udf http://maven.apache.org @@ -60,5 +56,15 @@ 1.7.0 + + + org.apache.maven.plugins + maven-surefire-plugin + 3.0.0 + + --add-opens=java.base/java.nio=ALL-UNNAMED + + + - + \ No newline at end of file From 9eee762b0634645df687c0297d172f61ca923f90 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 22 Aug 2023 23:18:59 +0800 Subject: [PATCH 04/11] support decimal in java udf Signed-off-by: Runji Wang --- .../src/main/java/com/example/UdfExample.java | 7 +++ .../com/risingwave/functions/TypeUtils.java | 43 +++++++++++-------- .../risingwave/functions/TestUdfServer.java | 6 +-- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/java/udf-example/src/main/java/com/example/UdfExample.java b/java/udf-example/src/main/java/com/example/UdfExample.java index eed88d3dda281..23e8032a72ad5 100644 --- a/java/udf-example/src/main/java/com/example/UdfExample.java +++ b/java/udf-example/src/main/java/com/example/UdfExample.java @@ -39,6 +39,7 @@ public static void main(String[] args) throws IOException { server.addFunction("gcd3", new Gcd3()); server.addFunction("extract_tcp_info", new ExtractTcpInfo()); server.addFunction("hex_to_dec", new HexToDec()); + server.addFunction("decimal_add", new DecimalAdd()); server.addFunction("array_access", new ArrayAccess()); server.addFunction("jsonb_access", new JsonbAccess()); server.addFunction("jsonb_concat", new JsonbConcat()); @@ -114,6 +115,12 @@ public BigDecimal eval(String hex) { } } + public static class DecimalAdd implements ScalarFunction { + public BigDecimal eval(BigDecimal a, BigDecimal b) { + return a.add(b); + } + } + public static class ArrayAccess implements ScalarFunction { public String eval(String[] array, int index) { return array[index - 1]; diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java index 9485ace03e2db..4c5dea27c61e2 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java @@ -52,8 +52,8 @@ static Field stringToField(String typeStr, String name) { return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); } else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) { return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); - } else if (typeStr.startsWith("DECIMAL") || typeStr.startsWith("NUMERIC")) { - return Field.nullable(name, new ArrowType.Decimal(38, 0, 128)); + } else if (typeStr.equals("DECIMAL") || typeStr.equals("NUMERIC")) { + return Field.nullable(name, new ArrowType.LargeBinary()); } else if (typeStr.equals("DATE")) { return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); } else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) { @@ -75,10 +75,9 @@ static Field stringToField(String typeStr, String name) { } else if (typeStr.startsWith("STRUCT")) { // extract "STRUCT" var typeList = typeStr.substring(7, typeStr.length() - 1); - var fields = - Arrays.stream(typeList.split(",")) - .map(s -> stringToField(s.trim(), "")) - .collect(Collectors.toList()); + var fields = Arrays.stream(typeList.split(",")) + .map(s -> stringToField(s.trim(), "")) + .collect(Collectors.toList()); return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); } else { throw new IllegalArgumentException("Unsupported type: " + typeStr); @@ -89,8 +88,8 @@ static Field stringToField(String typeStr, String name) { * Convert a Java class to an Arrow type. * * @param param The Java class. - * @param hint An optional DataTypeHint annotation. - * @param name The name of the field. + * @param hint An optional DataTypeHint annotation. + * @param name The name of the field. * @return The Arrow type. */ static Field classToField(Class param, DataTypeHint hint, String name) { @@ -109,7 +108,7 @@ static Field classToField(Class param, DataTypeHint hint, String name) { } else if (param == Double.class || param == double.class) { return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); } else if (param == BigDecimal.class) { - return Field.nullable(name, new ArrowType.Decimal(38, 0, 128)); + return Field.nullable(name, new ArrowType.LargeBinary()); } else if (param == LocalDate.class) { return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); } else if (param == LocalTime.class) { @@ -163,8 +162,7 @@ static Schema tableFunctionToOutputSchema(Method method) { if (!Iterator.class.isAssignableFrom(type)) { throw new IllegalArgumentException("Table function must return Iterator"); } - var typeArguments = - ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments(); + var typeArguments = ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments(); type = (Class) typeArguments[0]; var rowIndex = Field.nullable("row_index", new ArrowType.Int(32, true)); return new Schema(Arrays.asList(rowIndex, classToField(type, hint, ""))); @@ -239,12 +237,12 @@ static void fillVector(FieldVector fieldVector, Object[] values) { vector.set(i, (double) values[i]); } } - } else if (fieldVector instanceof DecimalVector) { - var vector = (DecimalVector) fieldVector; + } else if (fieldVector instanceof LargeVarBinaryVector) { + var vector = (LargeVarBinaryVector) fieldVector; vector.allocateNew(values.length); for (int i = 0; i < values.length; i++) { if (values[i] != null) { - vector.set(i, (BigDecimal) values[i]); + vector.set(i, ((BigDecimal) values[i]).toString().getBytes()); } } } else if (fieldVector instanceof DateDayVector) { @@ -328,9 +326,9 @@ static void fillVector(FieldVector fieldVector, Object[] values) { } else if (vector.getDataVector() instanceof Float8Vector) { TypeUtils.fillListVector( vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof DecimalVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); + } else if (vector.getDataVector() instanceof LargeVarBinaryVector) { + TypeUtils.fillListVector( + vector, values, (vec, i, val) -> vec.set(i, val.toString().getBytes())); } else if (vector.getDataVector() instanceof DateDayVector) { TypeUtils.fillListVector( vector, values, (vec, i, val) -> vec.set(i, (int) val.toEpochDay())); @@ -425,7 +423,10 @@ static long timestampToMicros(LocalDateTime timestamp) { return date * 24 * 3600 * 1000 * 1000 + time / 1000; } - /** Return a function that converts the object get from input array to the correct type. */ + /** + * Return a function that converts the object get from input array to the + * correct type. + */ static Function processFunc(Field field, Class targetClass) { var inner = processFunc0(field, targetClass); return obj -> obj == null ? null : inner.apply(obj); @@ -438,6 +439,10 @@ static Function processFunc0(Field field, Class targetClass) } else if (field.getType() instanceof ArrowType.LargeUtf8 && targetClass == String.class) { // object is org.apache.arrow.vector.util.Text return obj -> obj.toString(); + } else if (field.getType() instanceof ArrowType.LargeBinary + && targetClass == BigDecimal.class) { + // object is byte[] + return obj -> new BigDecimal(new String((byte[]) obj)); } else if (field.getType() instanceof ArrowType.Date && targetClass == LocalDate.class) { // object is Integer return obj -> LocalDate.ofEpochDay((int) obj); @@ -466,7 +471,7 @@ static Function processFunc0(Field field, Class targetClass) } else if (subfield.getType() .equals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { return obj -> ((List) obj).stream().map(subfunc).toArray(Double[]::new); - } else if (subfield.getType() instanceof ArrowType.Decimal) { + } else if (subfield.getType() instanceof ArrowType.LargeBinary) { return obj -> ((List) obj).stream().map(subfunc).toArray(BigDecimal[]::new); } else if (subfield.getType() instanceof ArrowType.Date) { return obj -> ((List) obj).stream().map(subfunc).toArray(LocalDate[]::new); diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java index daba383ab561e..2ac4448d3e9ff 100644 --- a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java +++ b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java @@ -172,9 +172,9 @@ public void all_types() throws Exception { c5.set(0, 1); c5.setValueCount(2); - var c6 = new DecimalVector("", allocator, 38, 0); + var c6 = new LargeVarBinaryVector("", allocator); c6.allocateNew(2); - c6.set(0, BigDecimal.valueOf(10).pow(37)); + c6.set(0, "1.234".getBytes()); c6.setValueCount(2); var c7 = new DateDayVector("", allocator); @@ -226,7 +226,7 @@ public void all_types() throws Exception { var output = stream.getRoot(); assertTrue(stream.next()); assertEquals( - "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\"}\n{}", + "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\"}\n{}", output.contentToTSVString().trim()); } } From 236c87e829b8dfdc17e9a91a827f871a95bdd4fa Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 22 Aug 2023 23:23:27 +0800 Subject: [PATCH 05/11] bump version Signed-off-by: Runji Wang --- src/udf/python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/udf/python/setup.py b/src/udf/python/setup.py index bb05b4cc3ef5d..93d5993540478 100644 --- a/src/udf/python/setup.py +++ b/src/udf/python/setup.py @@ -5,7 +5,7 @@ setup( name="risingwave", - version="0.0.9", + version="0.0.10", author="RisingWave Labs", description="RisingWave Python API", long_description=long_description, From a3c12643942369baf5d86c1a9576da91b5e61439 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 23 Aug 2023 11:46:04 +0800 Subject: [PATCH 06/11] fix pytest Signed-off-by: Runji Wang --- src/udf/python/risingwave/test_udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/udf/python/risingwave/test_udf.py b/src/udf/python/risingwave/test_udf.py index d1507438e734f..8e72b09f043fe 100644 --- a/src/udf/python/risingwave/test_udf.py +++ b/src/udf/python/risingwave/test_udf.py @@ -185,7 +185,7 @@ def test_all_types(): pa.array([1], type=pa.int64()), pa.array([1], type=pa.float32()), pa.array([1], type=pa.float64()), - pa.array([10**37], type=pa.decimal128(38)), + pa.array(["12345678901234567890.1234567890"], type=pa.large_binary()), pa.array([datetime.date(2023, 6, 1)], type=pa.date32()), pa.array([datetime.time(1, 2, 3, 456789)], type=pa.time64("us")), pa.array( @@ -215,7 +215,7 @@ def test_all_types(): 1, 1.0, 1.0, - 10**37, + b"12345678901234567890.1234567890", datetime.date(2023, 6, 1), datetime.time(1, 2, 3, 456789), datetime.datetime(2023, 6, 1, 1, 2, 3, 456789), From 3d0a11b737ed6c4b171c8ef74dd838e12af34933 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 23 Aug 2023 11:55:28 +0800 Subject: [PATCH 07/11] ci: fix build java udf Signed-off-by: Runji Wang --- ci/scripts/build-other.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci/scripts/build-other.sh b/ci/scripts/build-other.sh index f5b6692c6cbf9..874cbd4891fa5 100755 --- a/ci/scripts/build-other.sh +++ b/ci/scripts/build-other.sh @@ -11,6 +11,8 @@ cd java mvn -B package -Dmaven.test.skip=true mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test +cd udf && mvn -B install && cd .. # don't skip test for udf +cd udf-example && mvn -B package && cd .. cd .. echo "--- Build rust binary for java binding integration test" From 2dcf944e5beb877bf11e468e7e91f2de5ed07b7a Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 23 Aug 2023 13:32:47 +0800 Subject: [PATCH 08/11] fix udf e2e test Signed-off-by: Runji Wang --- e2e_test/udf/udf.slt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index 0ab50993d94be..fa1f48ee368f6 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -293,6 +293,9 @@ drop function split; statement ok drop function hex_to_dec; +statement ok +drop function decimal_add; + statement ok drop function array_access; From 9275297e360e656d46922c61a89da0d072fb67a8 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 1 Dec 2023 13:21:39 +0800 Subject: [PATCH 09/11] revert pom Signed-off-by: Runji Wang --- java/pom.xml | 2 ++ java/udf/pom.xml | 10 ---------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/java/pom.xml b/java/pom.xml index 293db0e6709ae..79309e3c5d3ec 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -37,6 +37,8 @@ proto + udf + udf-example java-binding common-utils java-binding-integration-test diff --git a/java/udf/pom.xml b/java/udf/pom.xml index 910692a49804d..7e9814b4af41e 100644 --- a/java/udf/pom.xml +++ b/java/udf/pom.xml @@ -54,15 +54,5 @@ 1.7.0 - - - org.apache.maven.plugins - maven-surefire-plugin - 3.0.0 - - --add-opens=java.base/java.nio=ALL-UNNAMED - - - \ No newline at end of file From db06a22d87085ba869b29159933e1c6a1aa75d8f Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 1 Dec 2023 14:09:52 +0800 Subject: [PATCH 10/11] fix java udf Signed-off-by: Runji Wang --- ci/scripts/build-other.sh | 3 +-- .../com/risingwave/functions/TypeUtils.java | 19 +++++++++---------- .../risingwave/functions/TestUdfServer.java | 2 +- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/ci/scripts/build-other.sh b/ci/scripts/build-other.sh index 874cbd4891fa5..9cd44dc78c95a 100755 --- a/ci/scripts/build-other.sh +++ b/ci/scripts/build-other.sh @@ -11,8 +11,7 @@ cd java mvn -B package -Dmaven.test.skip=true mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test -cd udf && mvn -B install && cd .. # don't skip test for udf -cd udf-example && mvn -B package && cd .. +mvn -B test --pl udf cd .. echo "--- Build rust binary for java binding integration test" diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java index 43bd7a96393dc..5a70a7ed5973b 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java @@ -76,9 +76,10 @@ static Field stringToField(String typeStr, String name) { } else if (typeStr.startsWith("STRUCT")) { // extract "STRUCT" var typeList = typeStr.substring(7, typeStr.length() - 1); - var fields = Arrays.stream(typeList.split(",")) - .map(s -> stringToField(s.trim(), "")) - .collect(Collectors.toList()); + var fields = + Arrays.stream(typeList.split(",")) + .map(s -> stringToField(s.trim(), "")) + .collect(Collectors.toList()); return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); } else { throw new IllegalArgumentException("Unsupported type: " + typeStr); @@ -89,8 +90,8 @@ static Field stringToField(String typeStr, String name) { * Convert a Java class to an Arrow type. * * @param param The Java class. - * @param hint An optional DataTypeHint annotation. - * @param name The name of the field. + * @param hint An optional DataTypeHint annotation. + * @param name The name of the field. * @return The Arrow type. */ static Field classToField(Class param, DataTypeHint hint, String name) { @@ -163,7 +164,8 @@ static Schema tableFunctionToOutputSchema(Method method) { if (!Iterator.class.isAssignableFrom(type)) { throw new IllegalArgumentException("Table function must return Iterator"); } - var typeArguments = ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments(); + var typeArguments = + ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments(); type = (Class) typeArguments[0]; var rowIndex = Field.nullable("row_index", new ArrowType.Int(32, true)); return new Schema(Arrays.asList(rowIndex, classToField(type, hint, ""))); @@ -461,10 +463,7 @@ static long timestampToMicros(LocalDateTime timestamp) { return date * 24 * 3600 * 1000 * 1000 + time / 1000; } - /** - * Return a function that converts the object get from input array to the - * correct type. - */ + /** Return a function that converts the object get from input array to the correct type. */ static Function processFunc(Field field, Class targetClass) { var inner = processFunc0(field, targetClass); return obj -> obj == null ? null : inner.apply(obj); diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java index 952b84578fcd2..25977ec15e45c 100644 --- a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java +++ b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java @@ -255,7 +255,7 @@ public void all_types() throws Exception { var output = stream.getRoot(); assertTrue(stream.next()); assertEquals( - "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}", + "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}", output.contentToTSVString().trim()); } } From 9964686b2e8c5d06b71a694d6511cb9469cc2350 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 1 Dec 2023 14:18:51 +0800 Subject: [PATCH 11/11] bump version and update changelog Signed-off-by: Runji Wang --- java/udf/CHANGELOG.md | 6 +++++- src/udf/python/CHANGELOG.md | 6 ++++++ src/udf/python/pyproject.toml | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md index 5f46b7ae339b5..48f2b014271a7 100644 --- a/java/udf/CHANGELOG.md +++ b/java/udf/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.1.1] - 2023-11-28 +## [0.1.1] - 2023-12-01 ### Added @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Bump Arrow version to 14. +### Fixed + +- Fix unconstrained decimal type. + ## [0.1.0] - 2023-09-01 - Initial release. \ No newline at end of file diff --git a/src/udf/python/CHANGELOG.md b/src/udf/python/CHANGELOG.md index e035aab4ebb9e..9255d727c1e1d 100644 --- a/src/udf/python/CHANGELOG.md +++ b/src/udf/python/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.0] - 2023-12-01 + +### Fixed + +- Fix unconstrained decimal type. + ## [0.0.12] - 2023-11-28 ### Changed diff --git a/src/udf/python/pyproject.toml b/src/udf/python/pyproject.toml index 67d17db55dadc..97003bc29157a 100644 --- a/src/udf/python/pyproject.toml +++ b/src/udf/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "risingwave" -version = "0.0.12" +version = "0.1.0" authors = [{ name = "RisingWave Labs" }] description = "RisingWave Python API" readme = "README.md"