Skip to content

Commit

Permalink
fix decimal convert in iceberg sink
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME committed Mar 26, 2024
1 parent f3ebeaa commit 55c1261
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 35 deletions.
6 changes: 6 additions & 0 deletions e2e_test/iceberg/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import date
from datetime import datetime
from datetime import timezone
import decimal;


def strtobool(v):
Expand Down Expand Up @@ -81,6 +82,11 @@ def verify_result(args,verify_sql,verify_schema,verify_data):
tc.assertEqual(row1[idx], datetime.fromisoformat(row2[idx]))
elif ty == "string":
tc.assertEqual(row1[idx], row2[idx])
elif ty == "decimal":
if row2[idx] == "none":
tc.assert_(row1[idx] is None)
else:
tc.assertEqual(row1[idx], decimal.Decimal(row2[idx]))
else:
tc.fail(f"Unsupported type {ty}")

Expand Down
13 changes: 7 additions & 6 deletions e2e_test/iceberg/test_case/iceberg_sink_append_only.slt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ v_varchar varchar,
v_bool boolean,
v_date date,
v_timestamp timestamptz,
v_ts_ntz timestamp
v_ts_ntz timestamp,
v_decimal decimal
);

statement ok
Expand All @@ -36,10 +37,10 @@ CREATE SINK s6 AS select * from mv6 WITH (

statement ok
INSERT INTO t6 VALUES
(1, 1, 1000, 1.1, 1.11, '1-1', true, '2022-03-11', '2022-03-11 01:00:00Z'::timestamptz, '2022-03-11 01:00:00'),
(2, 2, 2000, 2.2, 2.22, '2-2', false, '2022-03-12', '2022-03-12 02:00:00Z'::timestamptz, '2022-03-12 02:00:00'),
(3, 3, 3000, 3.3, 3.33, '3-3', true, '2022-03-13', '2022-03-13 03:00:00Z'::timestamptz, '2022-03-13 03:00:00'),
(4, 4, 4000, 4.4, 4.44, '4-4', false, '2022-03-14', '2022-03-14 04:00:00Z'::timestamptz, '2022-03-14 04:00:00');
(1, 1, 1000, 1.1, 1.11, '1-1', true, '2022-03-11', '2022-03-11 01:00:00Z'::timestamptz, '2022-03-11 01:00:00',1.11),
(2, 2, 2000, 2.2, 2.22, '2-2', false, '2022-03-12', '2022-03-12 02:00:00Z'::timestamptz, '2022-03-12 02:00:00',2.22),
(3, 3, 3000, 3.3, 3.33, '3-3', true, '2022-03-13', '2022-03-13 03:00:00Z'::timestamptz, '2022-03-13 03:00:00','inf'),
(4, 4, 4000, 4.4, 4.44, '4-4', false, '2022-03-14', '2022-03-14 04:00:00Z'::timestamptz, '2022-03-14 04:00:00','-inf');

statement ok
FLUSH;
Expand All @@ -48,7 +49,7 @@ sleep 5s

statement ok
INSERT INTO t6 VALUES
(5, 5, 5000, 5.5, 5.55, '5-5', true, '2022-03-15', '2022-03-15 05:00:00Z'::timestamptz, '2022-03-15 05:00:00');
(5, 5, 5000, 5.5, 5.55, '5-5', true, '2022-03-15', '2022-03-15 05:00:00Z'::timestamptz, '2022-03-15 05:00:00','nan');

statement ok
FLUSH;
Expand Down
15 changes: 8 additions & 7 deletions e2e_test/iceberg/test_case/no_partition_append_only.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,25 @@ init_sqls = [
v_bool boolean,
v_date date,
v_timestamp timestamp,
v_ts_ntz timestamp_ntz
v_ts_ntz timestamp_ntz,
v_decimal decimal(10,5)
) USING iceberg TBLPROPERTIES ('format-version'='2');
'''
]

slt = 'test_case/iceberg_sink_append_only.slt'

verify_schema = ['long', 'int', 'long', 'float', 'double', 'string', 'boolean', 'date', 'timestamp', 'timestamp_ntz']
verify_schema = ['long', 'int', 'long', 'float', 'double', 'string', 'boolean', 'date', 'timestamp', 'timestamp_ntz','decimal']

verify_sql = 'SELECT * FROM demo_db.demo_table ORDER BY id ASC'


verify_data = """
1,1,1000,1.1,1.11,1-1,true,2022-03-11,2022-03-11 01:00:00+00:00,2022-03-11 01:00:00
2,2,2000,2.2,2.22,2-2,false,2022-03-12,2022-03-12 02:00:00+00:00,2022-03-12 02:00:00
3,3,3000,3.3,3.33,3-3,true,2022-03-13,2022-03-13 03:00:00+00:00,2022-03-13 03:00:00
4,4,4000,4.4,4.44,4-4,false,2022-03-14,2022-03-14 04:00:00+00:00,2022-03-14 04:00:00
5,5,5000,5.5,5.55,5-5,true,2022-03-15,2022-03-15 05:00:00+00:00,2022-03-15 05:00:00
1,1,1000,1.1,1.11,1-1,true,2022-03-11,2022-03-11 01:00:00+00:00,2022-03-11 01:00:00,1.11
2,2,2000,2.2,2.22,2-2,false,2022-03-12,2022-03-12 02:00:00+00:00,2022-03-12 02:00:00,2.22
3,3,3000,3.3,3.33,3-3,true,2022-03-13,2022-03-13 03:00:00+00:00,2022-03-13 03:00:00,none
4,4,4000,4.4,4.44,4-4,false,2022-03-14,2022-03-14 04:00:00+00:00,2022-03-14 04:00:00,none
5,5,5000,5.5,5.55,5-5,true,2022-03-15,2022-03-15 05:00:00+00:00,2022-03-15 05:00:00,none
"""

drop_sqls = [
Expand Down
15 changes: 8 additions & 7 deletions e2e_test/iceberg/test_case/partition_append_only.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ init_sqls = [
v_bool boolean,
v_date date,
v_timestamp timestamp,
v_ts_ntz timestamp_ntz
v_ts_ntz timestamp_ntz,
v_decimal decimal(10,5)
)
PARTITIONED BY (v_int,bucket(10,v_long),truncate(30,v_long),years(v_date),months(v_timestamp),days(v_ts_ntz))
TBLPROPERTIES ('format-version'='2');
Expand All @@ -21,17 +22,17 @@ init_sqls = [

slt = 'test_case/iceberg_sink_append_only.slt'

verify_schema = ['long', 'int', 'long', 'float', 'double', 'string', 'boolean', 'date', 'timestamp', 'timestamp_ntz']
verify_schema = ['long', 'int', 'long', 'float', 'double', 'string', 'boolean', 'date', 'timestamp', 'timestamp_ntz','decimal']

verify_sql = 'SELECT * FROM demo_db.demo_table ORDER BY id ASC'


verify_data = """
1,1,1000,1.1,1.11,1-1,true,2022-03-11,2022-03-11 01:00:00+00:00,2022-03-11 01:00:00
2,2,2000,2.2,2.22,2-2,false,2022-03-12,2022-03-12 02:00:00+00:00,2022-03-12 02:00:00
3,3,3000,3.3,3.33,3-3,true,2022-03-13,2022-03-13 03:00:00+00:00,2022-03-13 03:00:00
4,4,4000,4.4,4.44,4-4,false,2022-03-14,2022-03-14 04:00:00+00:00,2022-03-14 04:00:00
5,5,5000,5.5,5.55,5-5,true,2022-03-15,2022-03-15 05:00:00+00:00,2022-03-15 05:00:00
1,1,1000,1.1,1.11,1-1,true,2022-03-11,2022-03-11 01:00:00+00:00,2022-03-11 01:00:00,1.11
2,2,2000,2.2,2.22,2-2,false,2022-03-12,2022-03-12 02:00:00+00:00,2022-03-12 02:00:00,2.22
3,3,3000,3.3,3.33,3-3,true,2022-03-13,2022-03-13 03:00:00+00:00,2022-03-13 03:00:00,none
4,4,4000,4.4,4.44,4-4,false,2022-03-14,2022-03-14 04:00:00+00:00,2022-03-14 04:00:00,none
5,5,5000,5.5,5.55,5-5,true,2022-03-15,2022-03-15 05:00:00+00:00,2022-03-15 05:00:00,none
"""

drop_sqls = [
Expand Down
105 changes: 103 additions & 2 deletions src/common/src/array/arrow/arrow_iceberg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,72 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::{Div, Mul};
use std::sync::Arc;

use arrow_array::{ArrayRef, StructArray};
use arrow_schema::DataType;
use itertools::Itertools;
use num_traits::abs;

use crate::array::{ArrayError, DataChunk};
use super::{ToArrowArrayWithTypeConvert, ToArrowTypeConvert};
use crate::array::{Array, ArrayError, DataChunk, DecimalArray};
use crate::util::iter_util::ZipEqFast;

struct IcebergArrowConvert;

impl ToArrowTypeConvert for IcebergArrowConvert {
#[inline]
fn decimal_type_to_arrow(&self) -> arrow_schema::DataType {
arrow_schema::DataType::Decimal128(arrow_schema::DECIMAL128_MAX_PRECISION, 0)
}
}

impl ToArrowArrayWithTypeConvert for IcebergArrowConvert {
fn decimal_to_arrow(
&self,
data_type: &arrow_schema::DataType,
array: &DecimalArray,
) -> Result<arrow_array::ArrayRef, ArrayError> {
let (precision, max_scale) = match data_type {
arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale),
_ => return Err(ArrayError::to_arrow("Invalid decimal type")),
};

// Convert Decimal to i128:
let values: Vec<Option<i128>> = array
.iter()
.map(|e| {
e.and_then(|e| match e {
crate::array::Decimal::Normalized(e) => {
let value = e.mantissa();
let scale = e.scale() as i8;
let diff_scale = abs(max_scale - scale);
let value = match scale {
_ if scale < max_scale => {
value.mul(10_i32.pow(diff_scale as u32) as i128)
}
_ if scale > max_scale => {
value.div(10_i32.pow(diff_scale as u32) as i128)
}
_ => value,
};
Some(value)
}
crate::array::Decimal::PositiveInf => None,
crate::array::Decimal::NegativeInf => None,
crate::array::Decimal::NaN => None,
})
})
.collect();

let array = arrow_array::Decimal128Array::from(values)
.with_precision_and_scale(precision, max_scale)
.map_err(ArrayError::from_arrow)?;
Ok(Arc::new(array) as ArrayRef)
}
}

/// Converts RisingWave array to Arrow array with the schema.
/// The behavior is specified for iceberg:
/// For different struct type, try to use fields in schema to cast.
Expand All @@ -37,7 +94,8 @@ pub fn to_iceberg_record_batch_with_schema(
.iter()
.zip_eq_fast(schema.fields().iter())
.map(|(column, field)| {
let column: arrow_array::ArrayRef = column.as_ref().try_into()?;
let column: arrow_array::ArrayRef =
IcebergArrowConvert {}.to_arrow_with_type(field.data_type(), column)?;
if column.data_type() == field.data_type() {
Ok(column)
} else if let DataType::Struct(actual) = column.data_type()
Expand Down Expand Up @@ -71,3 +129,46 @@ pub fn to_iceberg_record_batch_with_schema(
arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
.map_err(ArrayError::to_arrow)
}

pub fn iceberg_to_arrow_type(
data_type: &crate::array::DataType,
) -> Result<arrow_schema::DataType, ArrayError> {
IcebergArrowConvert {}.to_arrow_type(data_type)
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use arrow_array::ArrayRef;

use crate::array::arrow::arrow_iceberg::IcebergArrowConvert;
use crate::array::arrow::ToArrowArrayWithTypeConvert;
use crate::array::{Decimal, DecimalArray};

#[test]
fn decimal() {
let array = DecimalArray::from_iter([
None,
Some(Decimal::NaN),
Some(Decimal::PositiveInf),
Some(Decimal::NegativeInf),
Some(Decimal::Normalized("123.4".parse().unwrap())),
Some(Decimal::Normalized("123.456".parse().unwrap())),
]);
let ty = arrow_schema::DataType::Decimal128(38, 3);
let arrow_array = IcebergArrowConvert.decimal_to_arrow(&ty, &array).unwrap();
let expect_array = Arc::new(
arrow_array::Decimal128Array::from(vec![
None,
None,
None,
None,
Some(123400),
Some(123456),
])
.with_data_type(ty),
) as ArrayRef;
assert_eq!(&arrow_array, &expect_array);
}
}
6 changes: 4 additions & 2 deletions src/common/src/array/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod arrow_default;
mod arrow_deltalake;
mod arrow_iceberg;

pub use arrow_default::to_record_batch_with_schema;
pub use arrow_default::{
to_record_batch_with_schema, ToArrowArrayWithTypeConvert, ToArrowTypeConvert,
};
pub use arrow_deltalake::to_deltalake_record_batch_with_schema;
pub use arrow_iceberg::to_iceberg_record_batch_with_schema;
pub use arrow_iceberg::{iceberg_to_arrow_type, to_iceberg_record_batch_with_schema};
4 changes: 2 additions & 2 deletions src/common/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
mod arrow;
pub use arrow::{
to_deltalake_record_batch_with_schema, to_iceberg_record_batch_with_schema,
to_record_batch_with_schema,
iceberg_to_arrow_type, to_deltalake_record_batch_with_schema,
to_iceberg_record_batch_with_schema, to_record_batch_with_schema,
};
mod bool_array;
pub mod bytes_array;
Expand Down
14 changes: 5 additions & 9 deletions src/connector/src/sink/iceberg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ use icelake::transaction::Transaction;
use icelake::types::{data_file_from_json, data_file_to_json, Any, DataFile};
use icelake::{Table, TableIdentifier};
use itertools::Itertools;
use risingwave_common::array::{to_iceberg_record_batch_with_schema, Op, StreamChunk};
use risingwave_common::array::{
iceberg_to_arrow_type, to_iceberg_record_batch_with_schema, Op, StreamChunk,
};
use risingwave_common::bail;
use risingwave_common::buffer::Bitmap;
use risingwave_common::catalog::Schema;
Expand Down Expand Up @@ -998,17 +1000,11 @@ pub fn try_matches_arrow_schema(
// RisingWave decimal type cannot specify precision and scale, so we use the default value.
ArrowDataType::Decimal128(38, 0)
} else {
ArrowDataType::try_from(*our_field_type).map_err(|e| anyhow!(e))?
iceberg_to_arrow_type(our_field_type).map_err(|e| anyhow!(e))?
};

let compatible = match (&converted_arrow_data_type, arrow_field.data_type()) {
(ArrowDataType::Decimal128(p1, s1), ArrowDataType::Decimal128(p2, s2)) => {
if for_source {
true
} else {
*p1 >= *p2 && *s1 >= *s2
}
}
(ArrowDataType::Decimal128(_, _), ArrowDataType::Decimal128(_, _)) => true,
(left, right) => left == right,
};
if !compatible {
Expand Down

0 comments on commit 55c1261

Please sign in to comment.