From 0127a211e7898e83002332ce57072a6a37fec182 Mon Sep 17 00:00:00 2001 From: chendapao Date: Fri, 19 Jan 2024 10:45:10 +0800 Subject: [PATCH] fix-predicate-timestamp-and-decimal (#31) --- .../paimon/presto/PrestoFilterConverter.java | 55 +++-- .../paimon/presto/PrestoPageSourceBase.java | 9 +- .../paimon/presto/TestPrestoITCase.java | 225 +++++++++++++++++- 3 files changed, 261 insertions(+), 28 deletions(-) diff --git a/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoFilterConverter.java b/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoFilterConverter.java index a2064d8..9515d62 100644 --- a/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoFilterConverter.java +++ b/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoFilterConverter.java @@ -19,9 +19,10 @@ package org.apache.paimon.presto; import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.Timestamp; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.predicate.PredicateBuilder; -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; import org.apache.paimon.types.RowType; import com.facebook.presto.common.predicate.Domain; @@ -38,20 +39,23 @@ import com.facebook.presto.common.type.IntegerType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SqlTimestampWithTimeZone; import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TimestampWithTimeZoneType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.common.type.VarcharType; import io.airlift.slice.Slice; import java.math.BigDecimal; +import java.time.Instant; +import java.time.ZoneId; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.TimeUnit; /** Presto filter to Paimon predicate. */ public class PrestoFilterConverter { @@ -208,8 +212,23 @@ private Object getLiteralValue(Type type, Object prestoNativeValue) { return Math.toIntExact(((Long) prestoNativeValue)); } - if (type instanceof TimestampType || type instanceof TimeType) { - return TimeUnit.MILLISECONDS.toMicros((Long) prestoNativeValue); + if (type instanceof TimeType) { + return (int) ((long) prestoNativeValue / 1_000); + } + + if (type instanceof TimestampType) { + return Timestamp.fromLocalDateTime( + Instant.ofEpochMilli((Long) prestoNativeValue) + .atZone(ZoneId.systemDefault()) + .toLocalDateTime()); + } + + if (type instanceof TimestampWithTimeZoneType) { + if (prestoNativeValue instanceof Long) { + return prestoNativeValue; + } + return Timestamp.fromEpochMillis( + ((SqlTimestampWithTimeZone) prestoNativeValue).getMillisUtc()); } if (type instanceof VarcharType || type instanceof CharType) { @@ -221,23 +240,21 @@ private Object getLiteralValue(Type type, Object prestoNativeValue) { } if (type instanceof DecimalType) { + // Refer to trino. DecimalType decimalType = (DecimalType) type; - Object value = - Objects.requireNonNull( - prestoNativeValue, "The prestoNativeValue must be non-null"); - if (Decimals.isShortDecimal(decimalType)) { - Preconditions.checkArgument( - value instanceof Long, - "A short decimal should be represented by a Long value but was %s", - value.getClass().getName()); - return BigDecimal.valueOf((long) value).movePointLeft(decimalType.getScale()); + BigDecimal bigDecimal; + if (prestoNativeValue instanceof Long) { + bigDecimal = + BigDecimal.valueOf((long) prestoNativeValue) + .movePointLeft(decimalType.getScale()); + } else { + bigDecimal = + new BigDecimal( + Decimals.decodeUnscaledValue((Slice) prestoNativeValue), + decimalType.getScale()); } - Preconditions.checkArgument( - value instanceof Slice, - "A long decimal should be represented by a Slice value but was %s", - value.getClass().getName()); - return new BigDecimal( - Decimals.decodeUnscaledValue((Slice) value), decimalType.getScale()); + return Decimal.fromBigDecimal( + bigDecimal, decimalType.getPrecision(), decimalType.getScale()); } throw new UnsupportedOperationException("Unsupported type: " + type); diff --git a/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoPageSourceBase.java b/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoPageSourceBase.java index 8df72fb..a56de27 100644 --- a/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoPageSourceBase.java +++ b/paimon-presto-common/src/main/java/org/apache/paimon/presto/PrestoPageSourceBase.java @@ -54,6 +54,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.math.BigDecimal; +import java.time.ZoneId; import java.util.ArrayList; import java.util.List; @@ -206,7 +207,13 @@ private void appendTo(Type prestoType, DataType paimonType, Object value, BlockB prestoType.writeLong( output, encodeShortScaledValue(decimal, decimalType.getScale())); } else if (prestoType.equals(TIMESTAMP)) { - prestoType.writeLong(output, ((Timestamp) value).toSQLTimestamp().getTime()); + prestoType.writeLong( + output, + ((Timestamp) value) + .toLocalDateTime() + .atZone(ZoneId.systemDefault()) + .toInstant() + .toEpochMilli()); } else if (prestoType.equals(TIME)) { prestoType.writeLong(output, (int) value * 1_000); } else { diff --git a/paimon-presto-common/src/test/java/org/apache/paimon/presto/TestPrestoITCase.java b/paimon-presto-common/src/test/java/org/apache/paimon/presto/TestPrestoITCase.java index b740e6c..0a0f0cb 100644 --- a/paimon-presto-common/src/test/java/org/apache/paimon/presto/TestPrestoITCase.java +++ b/paimon-presto-common/src/test/java/org/apache/paimon/presto/TestPrestoITCase.java @@ -41,19 +41,26 @@ import org.apache.paimon.types.TimestampType; import org.apache.paimon.types.VarCharType; +import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.DistributedQueryRunner; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeSuite; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; +import java.io.IOException; import java.math.BigDecimal; import java.nio.file.Files; +import java.time.Instant; import java.time.LocalDateTime; +import java.time.ZoneId; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.TimeZone; import java.util.UUID; import static com.facebook.airlift.testing.Closeables.closeAllSuppress; @@ -158,7 +165,9 @@ protected QueryRunner createQueryRunner() throws Exception { Path tablePath5 = new Path(warehouse, "default.db/test_timestamp"); RowType rowType = new RowType( - Collections.singletonList(new DataField(0, "ts", new TimestampType()))); + Arrays.asList( + new DataField(0, "ts", new TimestampType()), + new DataField(1, "ts_long_0", new TimestampType()))); new SchemaManager(LocalFileIO.create(), tablePath5) .createTable( new Schema( @@ -173,7 +182,11 @@ protected QueryRunner createQueryRunner() throws Exception { writer.write( GenericRow.of( Timestamp.fromLocalDateTime( - LocalDateTime.parse("2023-01-01T01:01:01.123")))); + LocalDateTime.parse("2023-01-01T01:01:01.123")), + Timestamp.fromLocalDateTime( + Instant.ofEpochMilli(1672534861123L) + .atZone(ZoneId.systemDefault()) + .toLocalDateTime()))); // 2023-01-01T01:01:01.123 UTC commit.commit(0, writer.prepareCommit(true, 0)); } @@ -207,7 +220,13 @@ protected QueryRunner createQueryRunner() throws Exception { try { queryRunner = DistributedQueryRunner.builder( - testSessionBuilder().setCatalog(CATALOG).setSchema(DB).build()) + testSessionBuilder() + .setTimeZoneKey( + TimeZoneKey.getTimeZoneKey( + ZoneId.systemDefault().getId())) + .setCatalog(CATALOG) + .setSchema(DB) + .build()) .build(); queryRunner.installPlugin(new PrestoPlugin()); Map options = new HashMap<>(); @@ -232,11 +251,23 @@ private static SimpleTableTestHelper createTestHelper(Path tablePath) throws Exc return new SimpleTableTestHelper(tablePath, rowType); } + @BeforeSuite + public void setup() throws Exception { + // Change the default time zone for presto-tests, like Trino. + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + } + @BeforeTest public void init() throws Exception { queryRunner = createQueryRunner(); } + @AfterTest + public void clear() throws IOException { + // TODO Delete default.db + queryRunner.close(); + } + @Test public void testComplexTypes() throws Exception { assertThat(sql("SELECT * FROM paimon.default.t4")).isEqualTo("[[1, {1=2}]]"); @@ -279,12 +310,12 @@ public void testGroupByWithCast() throws Exception { .isEqualTo("[[1, 1, 3, 3], [2, 3, 3, 3]]"); } - // Due to the inconsistency between the testing behavior and the real production environment, - // we are temporarily disabling timestamp testing here. - @Test(enabled = false) + @Test public void testTimestampFormat() throws Exception { - assertThat(sql("SELECT ts FROM paimon.default.test_timestamp")) - .isEqualTo("[[2023-01-01T01:01:01.123]]"); + assertThat( + sql( + "SELECT ts, format_datetime(ts, 'yyyy-MM-dd HH:mm:ss') FROM paimon.default.test_timestamp")) + .isEqualTo("[[2023-01-01T01:01:01.123, 2023-01-01 01:01:01]]"); } @Test @@ -293,6 +324,184 @@ public void testDecimal() throws Exception { .isEqualTo("[[10000000000, 123.456]]"); } + @Test + public void testTimestampPredicateWithTimezone() throws Exception { + // Pacific/Apia + assertThat( + sql( + "SELECT ts, ts_long_0 FROM paimon.default.test_timestamp " + + "where ts = TIMESTAMP '2023-01-01 01:01:01.123 Pacific/Apia'")) + .isEqualTo("[]"); + + // UTC + assertThat( + sql( + "SELECT ts, ts_long_0 FROM paimon.default.test_timestamp " + + "where ts = TIMESTAMP '2023-01-01 01:01:01.123 UTC'")) + .isEqualTo("[[2023-01-01T01:01:01.123, 2023-01-01T01:01:01.123]]"); + } + + @Test + public void testTimestampPredicateEq() throws Exception { + // In UT 1672534861123 is 2023-01-01T01:01:01.123 UTC. + + assertThat( + sql( + "SELECT ts, ts_long_0 FROM paimon.default.test_timestamp " + + "where ts = TIMESTAMP '2023-01-01 01:01:01.123'")) + .isEqualTo("[[2023-01-01T01:01:01.123, 2023-01-01T01:01:01.123]]"); + + assertThat( + sql( + "SELECT ts, ts_long_0 FROM paimon.default.test_timestamp " + + "where ts = TIMESTAMP '2023-01-01 01:01:01.123'")) + .isEqualTo("[[2023-01-01T01:01:01.123, 2023-01-01T01:01:01.123]]"); + + assertThat( + sql( + "SELECT ts, ts_long_0 FROM paimon.default.test_timestamp " + + "WHERE ts_long_0 = date_add(" + + "'millisecond', " + + "CAST(1672534861123 % 1000 AS INTEGER), " + + "from_unixtime(CAST(1672534861123 / 1000 AS BIGINT))" + + ")")) + .isEqualTo("[[2023-01-01T01:01:01.123, 2023-01-01T01:01:01.123]]"); + + assertThat( + sql( + "SELECT ts, ts_long_0 FROM paimon.default.test_timestamp " + + "WHERE ts = TIMESTAMP '2023-01-01 01:01:01.123' " + + "AND ts_long_0 = date_add(" + + "'millisecond', " + + "CAST(1672534861123 % 1000 AS INTEGER), " + + "from_unixtime(CAST(1672534861123 / 1000 AS BIGINT)))")) + .isEqualTo("[[2023-01-01T01:01:01.123, 2023-01-01T01:01:01.123]]"); + } + + @Test + public void testTimestampPredicate() throws Exception { + // Test gt and gte. + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts > TIMESTAMP '2023-01-01 01:01:01'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts >= TIMESTAMP '2023-01-01 01:01:01.123'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + // Test lt and lte. + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts < TIMESTAMP '2023-01-01 01:01:02'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts <= TIMESTAMP '2023-01-01 01:01:01.123'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + // Test gt and lt. + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts > TIMESTAMP '2023-01-01 01:01:00' " + + "and ts < TIMESTAMP '2023-01-01 01:01:02'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + // Test gt and lte. + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts > TIMESTAMP '2023-01-01 01:01:00' " + + "and ts <= TIMESTAMP '2023-01-01 01:01:01.123'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + // Test gte and lte. + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts >= TIMESTAMP '2023-01-01 01:01:01.123' " + + "and ts <= TIMESTAMP '2023-01-01 01:01:01.123'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + + // Test gte and lt. + assertThat( + sql( + "SELECT ts FROM paimon.default.test_timestamp " + + "where ts >= TIMESTAMP '2023-01-01 01:01:01' " + + "and ts < TIMESTAMP '2023-01-01 01:01:02'")) + .isEqualTo("[[2023-01-01T01:01:01.123]]"); + } + + @Test + public void testDecimalPredicate() throws Exception { + // Test eq. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 = 123.456")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c1 FROM paimon.default.test_decimal where c1 = 10000000000")) + .isEqualTo("[[10000000000]]"); + + // Test gt and gte. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 > 123")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 > 123.455")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 >= 123")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 >= 123.456")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c1 FROM paimon.default.test_decimal where c1 >= 10000000000")) + .isEqualTo("[[10000000000]]"); + + // Test lt and lte. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 < 124")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 < 123.457")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 <= 124")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 <= 123.457")) + .isEqualTo("[[123.456]]"); + + assertThat(sql("SELECT c1 FROM paimon.default.test_decimal where c1 <= 10000000000")) + .isEqualTo("[[10000000000]]"); + + // Test gt and lt. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 > 123 and c2 < 666")) + .isEqualTo("[[123.456]]"); + + // Test gt and lte. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 > 123 and c2 <= 666")) + .isEqualTo("[[123.456]]"); + + // Test gte and lte. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 >= 123 and c2 <= 666")) + .isEqualTo("[[123.456]]"); + + // Test gte and lt. + assertThat(sql("SELECT c2 FROM paimon.default.test_decimal where c2 >= 123 and c2 < 666")) + .isEqualTo("[[123.456]]"); + + assertThat( + sql( + "SELECT c1 FROM paimon.default.test_decimal where c1 >= 10000000000 and c1 < 10000000001")) + .isEqualTo("[[10000000000]]"); + } + private String sql(String sql) throws Exception { MaterializedResult result = queryRunner.execute(sql); return result.getMaterializedRows().toString();