From f5a1c50c7fd832f7dbff8f0cf9196dc9d3881ad3 Mon Sep 17 00:00:00 2001 From: Ankit Kothari Date: Wed, 11 Oct 2023 22:44:13 -0700 Subject: [PATCH] Change *VectorAggregator to support pair serdes --- ...PairLongObjectDeltaEncodedStagedSerde.java | 6 +- ...izablePairLongObjectSimpleStagedSerde.java | 6 +- ...zablePairLongDoubleComplexMetricSerde.java | 30 +-- ...PairLongDoubleDeltaEncodedStagedSerde.java | 4 +- ...izablePairLongDoubleSimpleStagedSerde.java | 4 +- ...izablePairLongFloatComplexMetricSerde.java | 30 +-- ...ePairLongFloatDeltaEncodedStagedSerde.java | 4 +- ...lizablePairLongFloatSimpleStagedSerde.java | 4 +- ...lizablePairLongLongComplexMetricSerde.java | 30 +-- ...lePairLongLongDeltaEncodedStagedSerde.java | 4 +- ...alizablePairLongLongSimpleStagedSerde.java | 4 +- .../aggregation/any/NilVectorAggregator.java | 10 +- .../first/DoubleFirstAggregatorFactory.java | 24 +- .../first/DoubleFirstVectorAggregator.java | 9 +- .../aggregation/first/FirstLastUtils.java | 83 +++++++ .../first/FloatFirstAggregatorFactory.java | 23 +- .../first/FloatFirstVectorAggregator.java | 10 +- .../first/LongFirstAggregatorFactory.java | 25 +- .../first/LongFirstVectorAggregator.java | 11 +- .../first/NumericFirstVectorAggregator.java | 98 ++++++-- .../first/StringFirstAggregatorFactory.java | 4 +- .../first/StringFirstLastUtils.java | 41 ---- .../first/StringFirstVectorAggregator.java | 4 +- .../last/DoubleLastAggregatorFactory.java | 28 ++- .../last/DoubleLastVectorAggregator.java | 9 +- .../last/FloatLastAggregatorFactory.java | 28 ++- .../last/FloatLastVectorAggregator.java | 11 +- .../last/LongLastAggregatorFactory.java | 29 ++- .../last/LongLastVectorAggregator.java | 10 +- .../last/NumericLastVectorAggregator.java | 98 ++++++-- .../last/StringLastAggregatorFactory.java | 5 +- .../last/StringLastVectorAggregator.java | 5 +- ...alizablePairLongDoubleBufferStoreTest.java | 2 +- ...ePairLongDoubleComplexMetricSerdeTest.java | 10 +- ...LongDoubleDeltaEncodedStagedSerdeTest.java | 12 +- ...ializablePairLongFloatBufferStoreTest.java | 2 +- ...lePairLongFloatComplexMetricSerdeTest.java | 12 +- ...rLongFloatDeltaEncodedStagedSerdeTest.java | 12 +- ...rializablePairLongLongBufferStoreTest.java | 2 +- ...blePairLongLongComplexMetricSerdeTest.java | 8 +- ...irLongLongDeltaEncodedStagedSerdeTest.java | 12 +- ...ablePairLongLongSimpleStagedSerdeTest.java | 6 +- .../DoubleFirstVectorAggregationTest.java | 134 ++++++---- .../FloatFirstVectorAggregationTest.java | 113 ++++++--- .../first/LongFirstVectorAggregationTest.java | 120 ++++++--- .../last/DoubleLastVectorAggregatorTest.java | 224 ++++++++++++++--- .../last/FloatLastVectorAggregatorTest.java | 230 ++++++++++++++--- .../last/LongLastVectorAggregatorTest.java | 232 +++++++++++++++--- 48 files changed, 1352 insertions(+), 470 deletions(-) create mode 100644 processing/src/main/java/org/apache/druid/query/aggregation/first/FirstLastUtils.java diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectDeltaEncodedStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectDeltaEncodedStagedSerde.java index 2abbc88e104a..62b0a52642e7 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectDeltaEncodedStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectDeltaEncodedStagedSerde.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.segment.serde.cell.StagedSerde; import org.apache.druid.segment.serde.cell.StorableBuffer; @@ -75,6 +76,7 @@ public void store(ByteBuffer byteBuffer) } if (rhsObject != null) { + byteBuffer.put(NullHandling.IS_NOT_NULL_BYTE); if (pairClass.isAssignableFrom(SerializablePairLongLong.class)) { byteBuffer.putLong((long) rhsObject); } else if (pairClass.isAssignableFrom(SerializablePairLongDouble.class)) { @@ -82,6 +84,8 @@ public void store(ByteBuffer byteBuffer) } else if (pairClass.isAssignableFrom(SerializablePairLongFloat.class)) { byteBuffer.putFloat((float) rhsObject); } + } else { + byteBuffer.put(NullHandling.IS_NULL_BYTE); } } @@ -100,7 +104,7 @@ public int getSerializedSize() } } - return (useIntegerDelta ? Integer.BYTES : Long.BYTES) + rhsBytes; + return (useIntegerDelta ? Integer.BYTES : Long.BYTES) + Byte.BYTES + rhsBytes; } }; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectSimpleStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectSimpleStagedSerde.java index 95a1b8cc3097..cf43cbc34c5c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectSimpleStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/AbstractSerializablePairLongObjectSimpleStagedSerde.java @@ -21,6 +21,7 @@ import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.segment.serde.cell.StagedSerde; import org.apache.druid.segment.serde.cell.StorableBuffer; @@ -57,6 +58,7 @@ public void store(ByteBuffer byteBuffer) Preconditions.checkNotNull(value.getLhs(), String.format(Locale.ENGLISH, "Long in %s must be non-null", pairCLass.getSimpleName())); byteBuffer.putLong(value.getLhs()); if (rhsObject != null) { + byteBuffer.put(NullHandling.IS_NOT_NULL_BYTE); if (pairCLass.isAssignableFrom(SerializablePairLongLong.class)) { byteBuffer.putLong((long) rhsObject); } else if (pairCLass.isAssignableFrom(SerializablePairLongDouble.class)) { @@ -64,6 +66,8 @@ public void store(ByteBuffer byteBuffer) } else if (pairCLass.isAssignableFrom(SerializablePairLongFloat.class)) { byteBuffer.putFloat((float) rhsObject); } + } else { + byteBuffer.put(NullHandling.IS_NULL_BYTE); } } @@ -81,7 +85,7 @@ public int getSerializedSize() rhsBytes = Float.BYTES; } } - return Long.BYTES + rhsBytes; + return Long.BYTES + Byte.BYTES + rhsBytes; } }; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerde.java index 9506524f9d49..2fd854c1da97 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerde.java @@ -20,7 +20,6 @@ package org.apache.druid.query.aggregation; import org.apache.druid.collections.SerializablePair; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.segment.GenericColumnSerializer; import org.apache.druid.segment.column.ColumnBuilder; import org.apache.druid.segment.data.ObjectStrategy; @@ -35,6 +34,8 @@ public class SerializablePairLongDoubleComplexMetricSerde extends AbstractSerial { public static final String TYPE_NAME = "serializablePairLongDouble"; + private static final SerializablePairLongDoubleSimpleStagedSerde SERDE = new SerializablePairLongDoubleSimpleStagedSerde(); + private static final Comparator> COMPARATOR = SerializablePair.createNullHandlingComparator( Double::compare, true @@ -90,32 +91,17 @@ public Class getClazz() @Override public SerializablePairLongDouble fromByteBuffer(ByteBuffer buffer, int numBytes) { - final ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); - long lhs = readOnlyBuffer.getLong(); - boolean isNotNull = readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE; - if (isNotNull) { - return new SerializablePairLongDouble(lhs, readOnlyBuffer.getDouble()); - } else { - return new SerializablePairLongDouble(lhs, null); - } + ByteBuffer readOnlyByteBuffer = buffer.asReadOnlyBuffer().order(buffer.order()); + + readOnlyByteBuffer.limit(buffer.position() + numBytes); + + return SERDE.deserialize(readOnlyByteBuffer); } @Override public byte[] toBytes(@Nullable SerializablePairLongDouble inPair) { - if (inPair == null) { - return new byte[]{}; - } - - ByteBuffer bbuf = ByteBuffer.allocate(Long.BYTES + Byte.BYTES + Double.BYTES); - bbuf.putLong(inPair.lhs); - if (inPair.rhs == null) { - bbuf.put(NullHandling.IS_NULL_BYTE); - } else { - bbuf.put(NullHandling.IS_NOT_NULL_BYTE); - bbuf.putDouble(inPair.rhs); - } - return bbuf.array(); + return SERDE.serialize(inPair); } }; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerde.java index 5278ef691175..ce087ec623e0 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerde.java @@ -19,6 +19,8 @@ package org.apache.druid.query.aggregation; +import org.apache.druid.common.config.NullHandling; + import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -51,7 +53,7 @@ public SerializablePairLongDouble deserialize(ByteBuffer byteBuffer) lhs += minValue; Double rhs = null; - if (readOnlyBuffer.hasRemaining()) { + if (readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE) { rhs = readOnlyBuffer.getDouble(); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleSimpleStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleSimpleStagedSerde.java index 5f62a7ed3a65..bf5c60e0c5b5 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleSimpleStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleSimpleStagedSerde.java @@ -19,6 +19,8 @@ package org.apache.druid.query.aggregation; +import org.apache.druid.common.config.NullHandling; + import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -42,7 +44,7 @@ public SerializablePairLongDouble deserialize(ByteBuffer byteBuffer) long lhs = readOnlyBuffer.getLong(); Double rhs = null; - if (readOnlyBuffer.hasRemaining()) { + if (readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE) { rhs = readOnlyBuffer.getDouble(); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerde.java index 137e8e58268f..2ffcb635b82e 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerde.java @@ -20,7 +20,6 @@ package org.apache.druid.query.aggregation; import org.apache.druid.collections.SerializablePair; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.segment.GenericColumnSerializer; import org.apache.druid.segment.column.ColumnBuilder; import org.apache.druid.segment.data.ObjectStrategy; @@ -35,6 +34,8 @@ public class SerializablePairLongFloatComplexMetricSerde extends AbstractSeriali { public static final String TYPE_NAME = "serializablePairLongFloat"; + private static final SerializablePairLongFloatSimpleStagedSerde SERDE = new SerializablePairLongFloatSimpleStagedSerde(); + private static final Comparator> COMPARATOR = SerializablePair.createNullHandlingComparator( Float::compare, true @@ -91,32 +92,17 @@ public Class getClazz() @Override public SerializablePairLongFloat fromByteBuffer(ByteBuffer buffer, int numBytes) { - final ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); - long lhs = readOnlyBuffer.getLong(); - boolean isNotNull = readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE; - if (isNotNull) { - return new SerializablePairLongFloat(lhs, readOnlyBuffer.getFloat()); - } else { - return new SerializablePairLongFloat(lhs, null); - } + ByteBuffer readOnlyByteBuffer = buffer.asReadOnlyBuffer().order(buffer.order()); + + readOnlyByteBuffer.limit(buffer.position() + numBytes); + + return SERDE.deserialize(readOnlyByteBuffer); } @Override public byte[] toBytes(@Nullable SerializablePairLongFloat inPair) { - if (inPair == null) { - return new byte[]{}; - } - - ByteBuffer bbuf = ByteBuffer.allocate(Long.BYTES + Byte.BYTES + Float.BYTES); - bbuf.putLong(inPair.lhs); - if (inPair.rhs == null) { - bbuf.put(NullHandling.IS_NULL_BYTE); - } else { - bbuf.put(NullHandling.IS_NOT_NULL_BYTE); - bbuf.putFloat(inPair.rhs); - } - return bbuf.array(); + return SERDE.serialize(inPair); } }; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerde.java index 059fa4bd0b6d..0f5e4f7f4a93 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerde.java @@ -19,6 +19,8 @@ package org.apache.druid.query.aggregation; +import org.apache.druid.common.config.NullHandling; + import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -51,7 +53,7 @@ public SerializablePairLongFloat deserialize(ByteBuffer byteBuffer) lhs += minValue; Float rhs = null; - if (readOnlyBuffer.hasRemaining()) { + if (readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE) { rhs = readOnlyBuffer.getFloat(); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatSimpleStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatSimpleStagedSerde.java index d8d6f9b5ef47..d8390974a95d 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatSimpleStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongFloatSimpleStagedSerde.java @@ -19,6 +19,8 @@ package org.apache.druid.query.aggregation; +import org.apache.druid.common.config.NullHandling; + import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -42,7 +44,7 @@ public SerializablePairLongFloat deserialize(ByteBuffer byteBuffer) long lhs = readOnlyBuffer.getLong(); Float rhs = null; - if (readOnlyBuffer.hasRemaining()) { + if (readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE) { rhs = readOnlyBuffer.getFloat(); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerde.java index fececb873101..a5f79f007279 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerde.java @@ -20,7 +20,6 @@ package org.apache.druid.query.aggregation; import org.apache.druid.collections.SerializablePair; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.segment.GenericColumnSerializer; import org.apache.druid.segment.column.ColumnBuilder; import org.apache.druid.segment.data.ObjectStrategy; @@ -35,6 +34,8 @@ public class SerializablePairLongLongComplexMetricSerde extends AbstractSerializ { public static final String TYPE_NAME = "serializablePairLongLong"; + private static final SerializablePairLongLongSimpleStagedSerde SERDE = new SerializablePairLongLongSimpleStagedSerde(); + private static final Comparator> COMPARATOR = SerializablePair.createNullHandlingComparator( Long::compare, true @@ -90,32 +91,17 @@ public Class getClazz() @Override public SerializablePairLongLong fromByteBuffer(ByteBuffer buffer, int numBytes) { - final ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); - long lhs = readOnlyBuffer.getLong(); - boolean isNotNull = readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE; - if (isNotNull) { - return new SerializablePairLongLong(lhs, readOnlyBuffer.getLong()); - } else { - return new SerializablePairLongLong(lhs, null); - } + ByteBuffer readOnlyByteBuffer = buffer.asReadOnlyBuffer().order(buffer.order()); + + readOnlyByteBuffer.limit(buffer.position() + numBytes); + + return SERDE.deserialize(readOnlyByteBuffer); } @Override public byte[] toBytes(@Nullable SerializablePairLongLong inPair) { - if (inPair == null) { - return new byte[]{}; - } - - ByteBuffer bbuf = ByteBuffer.allocate(Long.BYTES + Byte.BYTES + Long.BYTES); - bbuf.putLong(inPair.lhs); - if (inPair.rhs == null) { - bbuf.put(NullHandling.IS_NULL_BYTE); - } else { - bbuf.put(NullHandling.IS_NOT_NULL_BYTE); - bbuf.putLong(inPair.rhs); - } - return bbuf.array(); + return SERDE.serialize(inPair); } }; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerde.java index bf0414031a88..dad3711c3c73 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerde.java @@ -19,6 +19,8 @@ package org.apache.druid.query.aggregation; +import org.apache.druid.common.config.NullHandling; + import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -51,7 +53,7 @@ public SerializablePairLongLong deserialize(ByteBuffer byteBuffer) lhs += minValue; Long rhs = null; - if (readOnlyBuffer.hasRemaining()) { + if (readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE) { rhs = readOnlyBuffer.getLong(); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerde.java index c94e842b5289..587ce18a0b4c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerde.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerde.java @@ -19,6 +19,8 @@ package org.apache.druid.query.aggregation; +import org.apache.druid.common.config.NullHandling; + import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -42,7 +44,7 @@ public SerializablePairLongLong deserialize(ByteBuffer byteBuffer) long lhs = readOnlyBuffer.getLong(); Long rhs = null; - if (readOnlyBuffer.hasRemaining()) { + if (readOnlyBuffer.get() == NullHandling.IS_NOT_NULL_BYTE) { rhs = readOnlyBuffer.getLong(); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java index ac6c5c7a75e4..80a1e0f7e12a 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java @@ -19,8 +19,10 @@ package org.apache.druid.query.aggregation.any; -import org.apache.druid.collections.SerializablePair; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.aggregation.SerializablePairLongDouble; +import org.apache.druid.query.aggregation.SerializablePairLongFloat; +import org.apache.druid.query.aggregation.SerializablePairLongLong; import org.apache.druid.query.aggregation.VectorAggregator; import javax.annotation.Nullable; @@ -43,9 +45,9 @@ public class NilVectorAggregator implements VectorAggregator NullHandling.defaultLongValue() ); - public static final SerializablePair DOUBLE_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultDoubleValue()); - public static final SerializablePair LONG_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultLongValue()); - public static final SerializablePair FLOAT_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultFloatValue()); + public static final SerializablePairLongDouble DOUBLE_NIL_PAIR = new SerializablePairLongDouble(0L, NullHandling.defaultDoubleValue()); + public static final SerializablePairLongLong LONG_NIL_PAIR = new SerializablePairLongLong(0L, NullHandling.defaultLongValue()); + public static final SerializablePairLongFloat FLOAT_NIL_PAIR = new SerializablePairLongFloat(0L, NullHandling.defaultFloatValue()); /** * @return A vectorized aggregator that returns the default double value. diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java index 8c9f13a6ced0..20251410f38f 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java @@ -43,7 +43,9 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -125,7 +127,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new DoubleFirstAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongDouble.class @@ -144,7 +146,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) return new DoubleFirstBufferAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongDouble.class @@ -158,12 +160,24 @@ public VectorAggregator factorizeVector( VectorColumnSelectorFactory columnSelectorFactory ) { + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - return new DoubleFirstVectorAggregator(timeSelector, valueSelector); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + columnSelectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.DOUBLE + ); + return new DoubleFirstVectorAggregator(timeSelector, objectSelector); + } + + VectorObjectSelector vSelector = columnSelectorFactory.makeObjectSelector(fieldName); + if (capabilities != null) { + return new DoubleFirstVectorAggregator(timeSelector, vSelector); } return NilVectorAggregator.of(NilVectorAggregator.DOUBLE_NIL_PAIR); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregator.java index 5221bc74b3e4..d58882e8439d 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregator.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.first; import org.apache.druid.query.aggregation.SerializablePairLongDouble; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -28,9 +29,9 @@ public class DoubleFirstVectorAggregator extends NumericFirstVectorAggregator { - public DoubleFirstVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + public DoubleFirstVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector) { - super(timeSelector, valueSelector); + super(timeSelector, valueSelector, SerializablePairLongDouble.class); } @Override @@ -41,9 +42,9 @@ public void initValue(ByteBuffer buf, int position) @Override - void putValue(ByteBuffer buf, int position, int index) + void putValue(ByteBuffer buf, int position, Number number) { - double firstValue = valueSelector.getDoubleVector()[index]; + double firstValue = number.doubleValue(); buf.putDouble(position, firstValue); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/FirstLastUtils.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/FirstLastUtils.java new file mode 100644 index 000000000000..7120d3a4605f --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/FirstLastUtils.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.first; + +import org.apache.druid.segment.BaseObjectColumnValueSelector; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ValueType; + +import javax.annotation.Nullable; + +public class FirstLastUtils +{ + + /** + * Returns whether a given value selector *might* contain SerializablePairLongString objects. + */ + public static boolean selectorNeedsFoldCheck( + final BaseObjectColumnValueSelector valueSelector, + @Nullable final ColumnCapabilities valueSelectorCapabilities, + Class pairClass + ) + { + if (valueSelectorCapabilities != null && !valueSelectorCapabilities.is(ValueType.COMPLEX)) { + // Known, non-complex type. + return false; + } + + if (valueSelector instanceof NilColumnValueSelector) { + // Nil column, definitely no SerializablePairLongStrings. + return false; + } + + // Check if the selector class could possibly be a SerializablePairLongString (either a superclass or subclass). + final Class clazz = valueSelector.classOfObject(); + return clazz.isAssignableFrom(pairClass) + || pairClass.isAssignableFrom(clazz); + } + + /** + * Returns whether an object *might* contain SerializablePairLongString objects. + */ + public static boolean objectNeedsFoldCheck(Object obj, Class pairClass) + { + if (obj == null) { + return false; + } + final Class clazz = obj.getClass(); + return clazz.isAssignableFrom(pairClass) + || pairClass.isAssignableFrom(clazz); + } + + + public static boolean[] getNullVector(Object[] objectVector) + { + boolean containsNonNullValues = false; + boolean[] nullValueVector = new boolean[objectVector.length]; + for (int i = 0; i < objectVector.length; i++) { + if (objectVector[i] != null) { + containsNonNullValues = true; + nullValueVector[i] = false; + } + } + return containsNonNullValues ? null : nullValueVector; + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java index 56556ab13745..51eb0345d77c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java @@ -43,7 +43,9 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -116,7 +118,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new FloatFirstAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongFloat.class @@ -135,7 +137,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) return new FloatFirstBufferAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongFloat.class @@ -147,12 +149,25 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) @Override public VectorAggregator factorizeVector(VectorColumnSelectorFactory columnSelectorFactory) { + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); - return new FloatFirstVectorAggregator(timeSelector, valueSelector); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + columnSelectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.FLOAT + ); + return new FloatFirstVectorAggregator(timeSelector, objectSelector); } + + VectorObjectSelector vSelector = columnSelectorFactory.makeObjectSelector(fieldName); + if (capabilities != null) { + return new FloatFirstVectorAggregator(timeSelector, vSelector); + } + return NilVectorAggregator.of(NilVectorAggregator.FLOAT_NIL_PAIR); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregator.java index a1671d9c7a24..44ddf24c0319 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregator.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.first; import org.apache.druid.query.aggregation.SerializablePairLongFloat; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -28,9 +29,9 @@ public class FloatFirstVectorAggregator extends NumericFirstVectorAggregator { - public FloatFirstVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + public FloatFirstVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector) { - super(timeSelector, valueSelector); + super(timeSelector, valueSelector, SerializablePairLongFloat.class); } @Override @@ -39,11 +40,10 @@ public void initValue(ByteBuffer buf, int position) buf.putFloat(position, 0); } - @Override - void putValue(ByteBuffer buf, int position, int index) + void putValue(ByteBuffer buf, int position, Number number) { - float firstValue = valueSelector.getFloatVector()[index]; + float firstValue = number.floatValue(); buf.putFloat(position, firstValue); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java index 9be911d35a25..39afef69c300 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -43,7 +44,9 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -56,6 +59,7 @@ public class LongFirstAggregatorFactory extends AggregatorFactory { public static final ColumnType TYPE = ColumnType.ofComplex(SerializablePairLongLongComplexMetricSerde.TYPE_NAME); + private static final Logger log = new Logger(LongFirstAggregatorFactory.class); private static final Aggregator NIL_AGGREGATOR = new LongFirstAggregator( NilColumnValueSelector.instance(), @@ -115,7 +119,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new LongFirstAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongLong.class @@ -134,7 +138,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) return new LongFirstBufferAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongLong.class @@ -146,12 +150,23 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) @Override public VectorAggregator factorizeVector(VectorColumnSelectorFactory columnSelectorFactory) { + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - return new LongFirstVectorAggregator(timeSelector, valueSelector); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + columnSelectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.LONG + ); + return new LongFirstVectorAggregator(timeSelector, objectSelector); + } + + VectorObjectSelector vSelector = columnSelectorFactory.makeObjectSelector(fieldName); + if (capabilities != null) { + return new LongFirstVectorAggregator(timeSelector, vSelector); } return NilVectorAggregator.of(NilVectorAggregator.LONG_NIL_PAIR); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregator.java index 0a40e5ad870c..961f124c5df3 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregator.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.first; import org.apache.druid.query.aggregation.SerializablePairLongLong; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -27,9 +28,9 @@ public class LongFirstVectorAggregator extends NumericFirstVectorAggregator { - public LongFirstVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + public LongFirstVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector) { - super(timeSelector, valueSelector); + super(timeSelector, valueSelector, SerializablePairLongLong.class); } @Override @@ -38,15 +39,13 @@ public void initValue(ByteBuffer buf, int position) buf.putLong(position, 0); } - @Override - void putValue(ByteBuffer buf, int position, int index) + void putValue(ByteBuffer buf, int position, Number number) { - long firstValue = valueSelector.getLongVector()[index]; + long firstValue = number.longValue(); buf.putLong(position, firstValue); } - /** * @return The object as a pair with the position and the value stored at the position in the buffer. */ diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstVectorAggregator.java index 7fcd10352da9..46d311ed45fa 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstVectorAggregator.java @@ -19,8 +19,10 @@ package org.apache.druid.query.aggregation.first; +import org.apache.druid.collections.SerializablePair; import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -33,15 +35,17 @@ public abstract class NumericFirstVectorAggregator implements VectorAggregator { static final int NULL_OFFSET = Long.BYTES; static final int VALUE_OFFSET = NULL_OFFSET + Byte.BYTES; - final VectorValueSelector valueSelector; + final VectorObjectSelector valueSelector; + private final Class pairClass; private final boolean useDefault = NullHandling.replaceWithDefault(); private final VectorValueSelector timeSelector; private long firstTime; - public NumericFirstVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + NumericFirstVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector, Class pairClass) { this.timeSelector = timeSelector; this.valueSelector = valueSelector; + this.pairClass = pairClass; firstTime = Long.MAX_VALUE; } @@ -58,7 +62,8 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) { final long[] timeVector = timeSelector.getLongVector(); final boolean[] nullTimeVector = timeSelector.getNullVector(); - final boolean[] nullValueVector = valueSelector.getNullVector(); + final Object[] objectsWhichMightBeNumeric = valueSelector.getObjectVector(); + final boolean[] nullValueVector = FirstLastUtils.getNullVector(objectsWhichMightBeNumeric); firstTime = buf.getLong(position); // the time vector is already sorted @@ -68,22 +73,42 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) // A possible optimization here is to have 2 paths one for earliest where // we can take advantage of the sorted nature of time // and the earliest_by where we have to go over all elements. - int index; + int index; for (int i = startRow; i < endRow; i++) { - index = i; - if (nullTimeVector != null && nullTimeVector[index]) { + if (nullTimeVector != null && nullTimeVector[i]) { continue; } - final long earliestTime = timeVector[index]; - if (earliestTime >= firstTime) { + + if (timeVector[i] >= firstTime) { continue; } - firstTime = earliestTime; - if (useDefault || nullValueVector == null || !nullValueVector[index]) { - updateTimeWithValue(buf, position, firstTime, index); + index = i; + + final boolean foldNeeded = FirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeNumeric[index], pairClass); + if (foldNeeded) { + final SerializablePair inPair = (SerializablePair) objectsWhichMightBeNumeric[index]; + + if (inPair.lhs < firstTime) { + firstTime = inPair.lhs; + if (useDefault || inPair.rhs != null) { + updateTimeWithValue(buf, position, firstTime, inPair.getRhs()); + } else { + updateTimeWithNull(buf, position, firstTime); + } + } } else { - updateTimeWithNull(buf, position, firstTime); + final long earliestTime = timeVector[index]; + + if (earliestTime < firstTime) { + firstTime = earliestTime; + + if (useDefault || nullValueVector == null || !nullValueVector[index]) { + updateTimeWithValue(buf, position, earliestTime, (Number) objectsWhichMightBeNumeric[index]); + } else { + updateTimeWithNull(buf, position, earliestTime); + } + } } } } @@ -110,20 +135,45 @@ public void aggregate( int positionOffset ) { - boolean[] nulls = useDefault ? null : valueSelector.getNullVector(); - long[] timeVector = timeSelector.getLongVector(); + final long[] timeVector = timeSelector.getLongVector(); + final Object[] objectsWhichMightBeNumeric = valueSelector.getObjectVector(); + final boolean[] nullValueVector = FirstLastUtils.getNullVector(objectsWhichMightBeNumeric); + boolean[] nulls = useDefault ? null : nullValueVector; + + // iterate once over the object vector to find first non null element and + // determine if the type is Pair or not + boolean foldNeeded = false; + for (Object obj : objectsWhichMightBeNumeric) { + if (obj != null) { + foldNeeded = FirstLastUtils.objectNeedsFoldCheck(obj, pairClass); + break; + } + } for (int i = 0; i < numRows; i++) { int position = positions[i] + positionOffset; int row = rows == null ? i : rows[i]; - long firstTime = buf.getLong(position); - if (timeVector[row] < firstTime) { - if (useDefault || nulls == null || !nulls[row]) { - updateTimeWithValue(buf, position, timeVector[row], row); + firstTime = buf.getLong(position); + + if (foldNeeded) { + final SerializablePair inPair = (SerializablePair) objectsWhichMightBeNumeric[row]; + if (useDefault || inPair != null) { + if (inPair.lhs < firstTime) { + updateTimeWithValue(buf, position, inPair.lhs, inPair.rhs); + } } else { - updateTimeWithNull(buf, position, timeVector[row]); + updateTimeWithNull(buf, position, inPair.lhs); + } + } else { + if (timeVector[row] < firstTime) { + if (useDefault || nulls == null || objectsWhichMightBeNumeric[row] != null) { + updateTimeWithValue(buf, position, timeVector[row], (Number) objectsWhichMightBeNumeric[row]); + } else { + updateTimeWithNull(buf, position, timeVector[row]); + } } } + } } @@ -132,14 +182,14 @@ public void aggregate( * * @param buf byte buffer storing the byte array representation of the aggregate * @param position offset within the byte buffer at which the current aggregate value is stored - * @param time the time to be updated in the buffer as the last time - * @param index the index of the vectorized vector which is the last value + * @param time the time to be updated in the buffer as the first time + * @param number number which is the first value */ - void updateTimeWithValue(ByteBuffer buf, int position, long time, int index) + void updateTimeWithValue(ByteBuffer buf, int position, long time, Number number) { buf.putLong(position, time); buf.put(position + NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE); - putValue(buf, position + VALUE_OFFSET, index); + putValue(buf, position + VALUE_OFFSET, number); } /** @@ -164,7 +214,7 @@ void updateTimeWithNull(ByteBuffer buf, int position, long time) * Abstract function which needs to be overridden by subclasses to set the * latest value in the buffer depending on the datatype */ - abstract void putValue(ByteBuffer buf, int position, int index); + abstract void putValue(ByteBuffer buf, int position, Number number); @Override public void close() diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java index c1eaf2abd1e1..8dbe12d4ddc5 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java @@ -165,7 +165,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) metricFactory.makeColumnValueSelector(timeColumn), valueSelector, maxStringBytes, - StringFirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) + FirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) ); } } @@ -181,7 +181,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) metricFactory.makeColumnValueSelector(timeColumn), valueSelector, maxStringBytes, - StringFirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) + FirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) ); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java index 4e5809f3cc8f..ff1113c91132 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java @@ -24,9 +24,6 @@ import org.apache.druid.segment.BaseLongColumnValueSelector; import org.apache.druid.segment.BaseObjectColumnValueSelector; import org.apache.druid.segment.DimensionHandlerUtils; -import org.apache.druid.segment.NilColumnValueSelector; -import org.apache.druid.segment.column.ColumnCapabilities; -import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; @@ -37,44 +34,6 @@ public class StringFirstLastUtils { private static final int NULL_VALUE = -1; - /** - * Returns whether a given value selector *might* contain SerializablePairLongString objects. - */ - public static boolean selectorNeedsFoldCheck( - final BaseObjectColumnValueSelector valueSelector, - @Nullable final ColumnCapabilities valueSelectorCapabilities, - Class pairClass - ) - { - if (valueSelectorCapabilities != null && !valueSelectorCapabilities.is(ValueType.COMPLEX)) { - // Known, non-complex type. - return false; - } - - if (valueSelector instanceof NilColumnValueSelector) { - // Nil column, definitely no SerializablePairLongStrings. - return false; - } - - // Check if the selector class could possibly be a SerializablePairLongString (either a superclass or subclass). - final Class clazz = valueSelector.classOfObject(); - return clazz.isAssignableFrom(pairClass) - || pairClass.isAssignableFrom(clazz); - } - - /** - * Returns whether an object *might* contain SerializablePairLongString objects. - */ - public static boolean objectNeedsFoldCheck(Object obj) - { - if (obj == null) { - return false; - } - final Class clazz = obj.getClass(); - return clazz.isAssignableFrom(SerializablePairLongString.class) - || SerializablePairLongString.class.isAssignableFrom(clazz); - } - /** * Return the object at a particular index from the vector selectors. * index of bounds issues is the responsibility of the caller diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregator.java index 3e31300bad8b..fd2260b8d665 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregator.java @@ -76,7 +76,7 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) continue; } index = i; - final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]); + final boolean foldNeeded = FirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index], SerializablePairLongString.class); if (foldNeeded) { final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex( timeSelector, @@ -125,7 +125,7 @@ public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable in if (obj == null) { continue; } else { - foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj); + foldNeeded = FirstLastUtils.objectNeedsFoldCheck(obj, SerializablePairLongString.class); break; } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java index 63c1385f2f5f..bac1109e570c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java @@ -24,7 +24,6 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -35,7 +34,7 @@ import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory; -import org.apache.druid.query.aggregation.first.StringFirstLastUtils; +import org.apache.druid.query.aggregation.first.FirstLastUtils; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnSelectorFactory; @@ -46,7 +45,9 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -116,7 +117,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new DoubleLastAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongDouble.class @@ -136,13 +137,26 @@ public VectorAggregator factorizeVector( VectorColumnSelectorFactory columnSelectorFactory ) { + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); - return new DoubleLastVectorAggregator(timeSelector, valueSelector); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + columnSelectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.DOUBLE + ); + return new DoubleLastVectorAggregator(timeSelector, objectSelector); + } + + VectorObjectSelector vSelector = columnSelectorFactory.makeObjectSelector(fieldName); + if (capabilities != null) { + return new DoubleLastVectorAggregator(timeSelector, vSelector); } else { - return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultDoubleValue())); + return NilVectorAggregator.of(NilVectorAggregator.DOUBLE_NIL_PAIR); } } @@ -156,7 +170,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) return new DoubleLastBufferAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongDouble.class diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregator.java index a5e9cf9d324e..ab02d8e86d22 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregator.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.last; import org.apache.druid.query.aggregation.SerializablePairLongDouble; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -32,16 +33,16 @@ public class DoubleLastVectorAggregator extends NumericLastVectorAggregator { double lastValue; - public DoubleLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + public DoubleLastVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector) { - super(timeSelector, valueSelector); + super(timeSelector, valueSelector, SerializablePairLongDouble.class); lastValue = 0; } @Override - void putValue(ByteBuffer buf, int position, int index) + void putValue(ByteBuffer buf, int position, Number number) { - lastValue = valueSelector.getDoubleVector()[index]; + lastValue = number.doubleValue(); buf.putDouble(position, lastValue); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java index 7977b4ff601e..d8d71765a25c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java @@ -24,7 +24,6 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -35,7 +34,7 @@ import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory; -import org.apache.druid.query.aggregation.first.StringFirstLastUtils; +import org.apache.druid.query.aggregation.first.FirstLastUtils; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnSelectorFactory; @@ -46,7 +45,9 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -114,7 +115,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new FloatLastAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongFloat.class @@ -133,7 +134,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) return new FloatLastBufferAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongFloat.class @@ -154,12 +155,25 @@ public VectorAggregator factorizeVector( ) { final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); + if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); - return new FloatLastVectorAggregator(timeSelector, valueSelector); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + columnSelectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.FLOAT + ); + return new FloatLastVectorAggregator(timeSelector, objectSelector); + } + + VectorObjectSelector vSelector = columnSelectorFactory.makeObjectSelector(fieldName); + if (capabilities != null) { + return new FloatLastVectorAggregator(timeSelector, vSelector); } else { - return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultFloatValue())); + return NilVectorAggregator.of(NilVectorAggregator.FLOAT_NIL_PAIR); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregator.java index b1814b5dc07e..c9bde26266c9 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregator.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.last; import org.apache.druid.query.aggregation.SerializablePairLongFloat; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -32,20 +33,20 @@ public class FloatLastVectorAggregator extends NumericLastVectorAggregator { float lastValue; - public FloatLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + public FloatLastVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector) { - super(timeSelector, valueSelector); + super(timeSelector, valueSelector, SerializablePairLongFloat.class); lastValue = 0; } - @Override - void putValue(ByteBuffer buf, int position, int index) + void putValue(ByteBuffer buf, int position, Number number) { - lastValue = valueSelector.getFloatVector()[index]; + lastValue = number.floatValue(); buf.putFloat(position, lastValue); } + @Override public void initValue(ByteBuffer buf, int position) { diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java index 2637d1a26c7b..751de9fbb7ef 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java @@ -24,7 +24,6 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -35,7 +34,7 @@ import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory; -import org.apache.druid.query.aggregation.first.StringFirstLastUtils; +import org.apache.druid.query.aggregation.first.FirstLastUtils; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnSelectorFactory; @@ -46,7 +45,9 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -115,7 +116,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new LongLastAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongLong.class @@ -134,7 +135,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) return new LongLastBufferAggregator( metricFactory.makeColumnValueSelector(timeColumn), valueSelector, - StringFirstLastUtils.selectorNeedsFoldCheck( + FirstLastUtils.selectorNeedsFoldCheck( valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongLong.class @@ -155,12 +156,26 @@ public VectorAggregator factorizeVector( ) { final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); + if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); - return new LongLastVectorAggregator(timeSelector, valueSelector); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + columnSelectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.LONG + ); + + return new LongLastVectorAggregator(timeSelector, objectSelector); + } + + VectorObjectSelector vSelector = columnSelectorFactory.makeObjectSelector(fieldName); + if (capabilities != null) { + return new LongLastVectorAggregator(timeSelector, vSelector); } else { - return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultLongValue())); + return NilVectorAggregator.of(NilVectorAggregator.LONG_NIL_PAIR); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregator.java index ea91430c80f1..c0446e9e8e32 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregator.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.last; import org.apache.druid.query.aggregation.SerializablePairLongLong; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -32,9 +33,9 @@ public class LongLastVectorAggregator extends NumericLastVectorAggregator { long lastValue; - public LongLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + public LongLastVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector) { - super(timeSelector, valueSelector); + super(timeSelector, valueSelector, SerializablePairLongLong.class); lastValue = 0; } @@ -44,11 +45,10 @@ public void initValue(ByteBuffer buf, int position) buf.putLong(position, 0); } - @Override - void putValue(ByteBuffer buf, int position, int index) + void putValue(ByteBuffer buf, int position, Number number) { - lastValue = valueSelector.getLongVector()[index]; + lastValue = number.longValue(); buf.putLong(position, lastValue); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastVectorAggregator.java index 717470e8921d..0cda99dfba56 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastVectorAggregator.java @@ -19,8 +19,11 @@ package org.apache.druid.query.aggregation.last; +import org.apache.druid.collections.SerializablePair; import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.aggregation.first.FirstLastUtils; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import javax.annotation.Nullable; @@ -33,15 +36,18 @@ public abstract class NumericLastVectorAggregator implements VectorAggregator { static final int NULL_OFFSET = Long.BYTES; static final int VALUE_OFFSET = NULL_OFFSET + Byte.BYTES; - final VectorValueSelector valueSelector; + final VectorObjectSelector valueSelector; + private final Class pairClass; private final boolean useDefault = NullHandling.replaceWithDefault(); private final VectorValueSelector timeSelector; private long lastTime; - public NumericLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector) + + NumericLastVectorAggregator(VectorValueSelector timeSelector, VectorObjectSelector valueSelector, Class pairClass) { this.timeSelector = timeSelector; this.valueSelector = valueSelector; + this.pairClass = pairClass; lastTime = Long.MIN_VALUE; } @@ -56,13 +62,17 @@ public void init(ByteBuffer buf, int position) @Override public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) { + if (timeSelector == null) { + return; + } + final long[] timeVector = timeSelector.getLongVector(); - final boolean[] nullValueVector = valueSelector.getNullVector(); + final Object[] objectsWhichMightBeNumeric = valueSelector.getObjectVector(); + final boolean[] nullValueVector = FirstLastUtils.getNullVector(objectsWhichMightBeNumeric); + boolean nullAbsent = false; lastTime = buf.getLong(position); - //check if nullVector is found or not - // the nullVector is null if no null values are found - // set the nullAbsent flag accordingly + if (nullValueVector == null) { nullAbsent = true; } @@ -79,14 +89,26 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) } } - //find the first non-null value - final long latestTime = timeVector[index]; - if (latestTime >= lastTime) { - lastTime = latestTime; - if (useDefault || nullValueVector == null || !nullValueVector[index]) { - updateTimeWithValue(buf, position, lastTime, index); - } else { - updateTimeWithNull(buf, position, lastTime); + final boolean foldNeeded = FirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeNumeric[index], pairClass); + if (foldNeeded) { + final SerializablePair inPair = (SerializablePair) objectsWhichMightBeNumeric[index]; + if (inPair.lhs >= lastTime) { + lastTime = inPair.lhs; + if (useDefault || inPair.rhs != null) { + updateTimeWithValue(buf, position, lastTime, inPair.getRhs()); + } else { + updateTimeWithNull(buf, position, lastTime); + } + } + } else { + final long latestTime = timeVector[index]; + if (latestTime >= lastTime) { + lastTime = latestTime; + if (useDefault || nullValueVector == null || !nullValueVector[index]) { + updateTimeWithValue(buf, position, lastTime, (Number) objectsWhichMightBeNumeric[index]); + } else { + updateTimeWithNull(buf, position, lastTime); + } } } } @@ -113,21 +135,47 @@ public void aggregate( int positionOffset ) { + if (timeSelector == null) { + return; + } + + final long[] timeVector = timeSelector.getLongVector(); - boolean[] nulls = useDefault ? null : valueSelector.getNullVector(); - long[] timeVector = timeSelector.getLongVector(); + final Object[] objectsWhichMightBeNumeric = valueSelector.getObjectVector(); + boolean[] nulls = useDefault ? null : FirstLastUtils.getNullVector(objectsWhichMightBeNumeric); + + boolean foldNeeded = false; + for (Object obj : objectsWhichMightBeNumeric) { + if (obj != null) { + foldNeeded = FirstLastUtils.objectNeedsFoldCheck(obj, pairClass); + break; + } + } for (int i = 0; i < numRows; i++) { int position = positions[i] + positionOffset; int row = rows == null ? i : rows[i]; long lastTime = buf.getLong(position); - if (timeVector[row] >= lastTime) { - if (useDefault || nulls == null || !nulls[row]) { - updateTimeWithValue(buf, position, timeVector[row], row); - } else { - updateTimeWithNull(buf, position, timeVector[row]); + + if (foldNeeded) { + final SerializablePair inPair = (SerializablePair) objectsWhichMightBeNumeric[row]; + if (inPair.lhs >= lastTime) { + if (useDefault || inPair.rhs != null) { + updateTimeWithValue(buf, position, inPair.lhs, inPair.rhs); + } else { + updateTimeWithNull(buf, position, inPair.lhs); + } + } + } else { + if (timeVector[row] >= lastTime) { + if (useDefault || nulls == null || !nulls[row]) { + updateTimeWithValue(buf, position, timeVector[row], (Number) objectsWhichMightBeNumeric[row]); + } else { + updateTimeWithNull(buf, position, timeVector[row]); + } } } + } } @@ -136,13 +184,13 @@ public void aggregate( * @param buf byte buffer storing the byte array representation of the aggregate * @param position offset within the byte buffer at which the current aggregate value is stored * @param time the time to be updated in the buffer as the last time - * @param index the index of the vectorized vector which is the last value + * @param number number which is the last value */ - void updateTimeWithValue(ByteBuffer buf, int position, long time, int index) + void updateTimeWithValue(ByteBuffer buf, int position, long time, Number number) { buf.putLong(position, time); buf.put(position + NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE); - putValue(buf, position + VALUE_OFFSET, index); + putValue(buf, position + VALUE_OFFSET, number); } /** @@ -166,7 +214,7 @@ void updateTimeWithNull(ByteBuffer buf, int position, long time) *Abstract function which needs to be overridden by subclasses to set the * latest value in the buffer depending on the datatype */ - abstract void putValue(ByteBuffer buf, int position, int index); + abstract void putValue(ByteBuffer buf, int position, Number number); @Override public void close() diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java index 78cba0f3e034..c8282529091c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java @@ -32,6 +32,7 @@ import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.aggregation.first.FirstLastUtils; import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory; import org.apache.druid.query.aggregation.first.StringFirstLastUtils; import org.apache.druid.query.cache.CacheKeyBuilder; @@ -130,7 +131,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) metricFactory.makeColumnValueSelector(timeColumn), valueSelector, maxStringBytes, - StringFirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) + FirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) ); } } @@ -146,7 +147,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) metricFactory.makeColumnValueSelector(timeColumn), valueSelector, maxStringBytes, - StringFirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) + FirstLastUtils.selectorNeedsFoldCheck(valueSelector, metricFactory.getColumnCapabilities(fieldName), SerializablePairLongString.class) ); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java index 00e70c78098e..09ef10572b59 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java @@ -22,6 +22,7 @@ import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.aggregation.first.FirstLastUtils; import org.apache.druid.query.aggregation.first.StringFirstLastUtils; import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.vector.VectorObjectSelector; @@ -81,7 +82,7 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) continue; } index = i; - final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]); + final boolean foldNeeded = FirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index], SerializablePairLongString.class); if (foldNeeded) { // Less efficient code path when folding is a possibility (we must read the value selector first just in case // it's a foldable object). @@ -140,7 +141,7 @@ public void aggregate( boolean foldNeeded = false; for (Object obj : objectsWhichMightBeStrings) { if (obj != null) { - foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj); + foldNeeded = FirstLastUtils.objectNeedsFoldCheck(obj, SerializablePairLongString.class); break; } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleBufferStoreTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleBufferStoreTest.java index 69ee858cbd78..5e6344dd3368 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleBufferStoreTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleBufferStoreTest.java @@ -241,7 +241,7 @@ public void testOverflowTransfer() throws Exception writeOutMedium ); - Assert.assertEquals(93, transferredBuffer.getSerializedSize()); + Assert.assertEquals(94, transferredBuffer.getSerializedSize()); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerdeTest.java index 178aab7d1cc0..1f5900ce1c99 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleComplexMetricSerdeTest.java @@ -59,7 +59,7 @@ public class SerializablePairLongDoubleComplexMetricSerdeTest @Test public void testSingle() throws Exception { - assertExpected(ImmutableList.of(new SerializablePairLongDouble(100L, 10D)), 78); + assertExpected(ImmutableList.of(new SerializablePairLongDouble(100L, 10D)), 75); } @Test @@ -86,7 +86,7 @@ public void testCompressable() throws Exception valueList.add(new SerializablePairLongDouble(Integer.MAX_VALUE + (long) i, doubleList.get(i % numLongs))); } - assertExpected(valueList, 80258); + assertExpected(valueList, 80509); } @Test @@ -99,7 +99,7 @@ public void testHighlyCompressable() throws Exception valueList.add(new SerializablePairLongDouble(Integer.MAX_VALUE + (long) i, doubleValue)); } - assertExpected(valueList, 80024); + assertExpected(valueList, 80274); } @Test @@ -111,13 +111,13 @@ public void testRandom() throws Exception valueList.add(new SerializablePairLongDouble(random.nextLong(), random.nextDouble())); } - assertExpected(valueList, 200612); + assertExpected(valueList, 210958); } @Test public void testNullRHS() throws Exception { - assertExpected(ImmutableList.of(new SerializablePairLongDouble(100L, null)), 70); + assertExpected(ImmutableList.of(new SerializablePairLongDouble(100L, null)), 71); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerdeTest.java index 85e17b14b8fa..9618404419af 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongDoubleDeltaEncodedStagedSerdeTest.java @@ -44,13 +44,13 @@ public void testNull() @Test public void testSimpleInteger() { - assertValueEquals(new SerializablePairLongDouble(100L, 1000000000000.12312312312D), 12, INTEGER_SERDE); + assertValueEquals(new SerializablePairLongDouble(100L, 1000000000000.12312312312D), 13, INTEGER_SERDE); } @Test public void testNullRHSInteger() { - assertValueEquals(new SerializablePairLongDouble(100L, null), 4, INTEGER_SERDE); + assertValueEquals(new SerializablePairLongDouble(100L, null), 5, INTEGER_SERDE); } @Test @@ -58,7 +58,7 @@ public void testLargeRHSInteger() { assertValueEquals( new SerializablePairLongDouble(100L, random.nextDouble()), - 12, + 13, INTEGER_SERDE ); } @@ -66,13 +66,13 @@ public void testLargeRHSInteger() @Test public void testSimpleLong() { - assertValueEquals(new SerializablePairLongDouble(100L, 1000000000000.12312312312D), 16, LONG_SERDE); + assertValueEquals(new SerializablePairLongDouble(100L, 1000000000000.12312312312D), 17, LONG_SERDE); } @Test public void testNullRHSLong() { - assertValueEquals(new SerializablePairLongDouble(100L, null), 8, LONG_SERDE); + assertValueEquals(new SerializablePairLongDouble(100L, null), 9, LONG_SERDE); } @Test @@ -80,7 +80,7 @@ public void testLargeRHSLong() { assertValueEquals( new SerializablePairLongDouble(100L, random.nextDouble()), - 16, + 17, LONG_SERDE ); } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatBufferStoreTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatBufferStoreTest.java index cb71b49b44b2..7cf12af184dc 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatBufferStoreTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatBufferStoreTest.java @@ -241,7 +241,7 @@ public void testOverflowTransfer() throws Exception writeOutMedium ); - Assert.assertEquals(92, transferredBuffer.getSerializedSize()); + Assert.assertEquals(90, transferredBuffer.getSerializedSize()); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerdeTest.java index 97d6493110c0..7fc270dae1fc 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatComplexMetricSerdeTest.java @@ -59,7 +59,7 @@ public class SerializablePairLongFloatComplexMetricSerdeTest @Test public void testSingle() throws Exception { - assertExpected(ImmutableList.of(new SerializablePairLongFloat(100L, 10F)), 74); + assertExpected(ImmutableList.of(new SerializablePairLongFloat(100L, 10F)), 75); } @Test @@ -69,7 +69,7 @@ public void testLargeRHS() throws Exception assertExpected(ImmutableList.of(new SerializablePairLongFloat( 100L, random.nextFloat() - )), 74); + )), 75); } @Test @@ -86,7 +86,7 @@ public void testCompressable() throws Exception valueList.add(new SerializablePairLongFloat(Integer.MAX_VALUE + (long) i, floatList.get(i % numLongs))); } - assertExpected(valueList, 80124); + assertExpected(valueList, 80418); } @Test @@ -99,7 +99,7 @@ public void testHighlyCompressable() throws Exception valueList.add(new SerializablePairLongFloat(Integer.MAX_VALUE + (long) i, floatValue)); } - assertExpected(valueList, 79970); + assertExpected(valueList, 80260); } @Test @@ -111,13 +111,13 @@ public void testRandom() throws Exception valueList.add(new SerializablePairLongFloat(random.nextLong(), random.nextFloat())); } - assertExpected(valueList, 160464); + assertExpected(valueList, 170749); } @Test public void testNullRHS() throws Exception { - assertExpected(ImmutableList.of(new SerializablePairLongFloat(100L, null)), 70); + assertExpected(ImmutableList.of(new SerializablePairLongFloat(100L, null)), 71); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerdeTest.java index d76d4d7772f6..7bf898f5b142 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongFloatDeltaEncodedStagedSerdeTest.java @@ -44,13 +44,13 @@ public void testNull() @Test public void testSimpleInteger() { - assertValueEquals(new SerializablePairLongFloat(100L, 10F), 8, INTEGER_SERDE); + assertValueEquals(new SerializablePairLongFloat(100L, 10F), 9, INTEGER_SERDE); } @Test public void testNullRHSInteger() { - assertValueEquals(new SerializablePairLongFloat(100L, null), 4, INTEGER_SERDE); + assertValueEquals(new SerializablePairLongFloat(100L, null), 5, INTEGER_SERDE); } @Test @@ -58,7 +58,7 @@ public void testLargeRHSInteger() { assertValueEquals( new SerializablePairLongFloat(100L, random.nextFloat()), - 8, + 9, INTEGER_SERDE ); } @@ -66,13 +66,13 @@ public void testLargeRHSInteger() @Test public void testSimpleLong() { - assertValueEquals(new SerializablePairLongFloat(100L, 10F), 12, LONG_SERDE); + assertValueEquals(new SerializablePairLongFloat(100L, 10F), 13, LONG_SERDE); } @Test public void testNullRHSLong() { - assertValueEquals(new SerializablePairLongFloat(100L, null), 8, LONG_SERDE); + assertValueEquals(new SerializablePairLongFloat(100L, null), 9, LONG_SERDE); } @Test @@ -80,7 +80,7 @@ public void testLargeRHSLong() { assertValueEquals( new SerializablePairLongFloat(100L, random.nextFloat()), - 12, + 13, LONG_SERDE ); } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongBufferStoreTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongBufferStoreTest.java index 20c1437297b3..61d4391c45d8 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongBufferStoreTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongBufferStoreTest.java @@ -241,7 +241,7 @@ public void testOverflowTransfer() throws Exception writeOutMedium ); - Assert.assertEquals(92, transferredBuffer.getSerializedSize()); + Assert.assertEquals(94, transferredBuffer.getSerializedSize()); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerdeTest.java index 3365cee5b4d5..68429ad99f4b 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongComplexMetricSerdeTest.java @@ -86,7 +86,7 @@ public void testCompressable() throws Exception valueList.add(new SerializablePairLongLong(Integer.MAX_VALUE + (long) i, longList.get(i % numLongs))); } - assertExpected(valueList, 80258); + assertExpected(valueList, 80509); } @Test @@ -99,7 +99,7 @@ public void testHighlyCompressable() throws Exception valueList.add(new SerializablePairLongLong(Integer.MAX_VALUE + (long) i, longValue)); } - assertExpected(valueList, 80023); + assertExpected(valueList, 80274); } @Test @@ -111,13 +111,13 @@ public void testRandom() throws Exception valueList.add(new SerializablePairLongLong(random.nextLong(), random.nextLong())); } - assertExpected(valueList, 200618); + assertExpected(valueList, 210967); } @Test public void testNullRHS() throws Exception { - assertExpected(ImmutableList.of(new SerializablePairLongLong(100L, null)), 70); + assertExpected(ImmutableList.of(new SerializablePairLongLong(100L, null)), 71); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerdeTest.java index 6edfb022d811..6b6c0c9b037f 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongDeltaEncodedStagedSerdeTest.java @@ -44,13 +44,13 @@ public void testNull() @Test public void testSimpleInteger() { - assertValueEquals(new SerializablePairLongLong(100L, 10L), 12, INTEGER_SERDE); + assertValueEquals(new SerializablePairLongLong(100L, 10L), 13, INTEGER_SERDE); } @Test public void testNullRHSInteger() { - assertValueEquals(new SerializablePairLongLong(100L, null), 4, INTEGER_SERDE); + assertValueEquals(new SerializablePairLongLong(100L, null), 5, INTEGER_SERDE); } @Test @@ -58,7 +58,7 @@ public void testLargeRHSInteger() { assertValueEquals( new SerializablePairLongLong(100L, random.nextLong()), - 12, + 13, INTEGER_SERDE ); } @@ -66,13 +66,13 @@ public void testLargeRHSInteger() @Test public void testSimpleLong() { - assertValueEquals(new SerializablePairLongLong(100L, 10L), 16, LONG_SERDE); + assertValueEquals(new SerializablePairLongLong(100L, 10L), 17, LONG_SERDE); } @Test public void testNullRHSLong() { - assertValueEquals(new SerializablePairLongLong(100L, null), 8, LONG_SERDE); + assertValueEquals(new SerializablePairLongLong(100L, null), 9, LONG_SERDE); } @Test @@ -80,7 +80,7 @@ public void testLargeRHSLong() { assertValueEquals( new SerializablePairLongLong(100L, random.nextLong()), - 16, + 17, LONG_SERDE ); } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerdeTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerdeTest.java index e903087105f0..2e0f3e8b0ef7 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerdeTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/SerializablePairLongLongSimpleStagedSerdeTest.java @@ -35,7 +35,7 @@ public class SerializablePairLongLongSimpleStagedSerdeTest @Test public void testSimple() { - assertValueEquals(new SerializablePairLongLong(Long.MAX_VALUE, 10L), 16); + assertValueEquals(new SerializablePairLongLong(Long.MAX_VALUE, 10L), 17); } @Test @@ -47,7 +47,7 @@ public void testNull() @Test public void testNullString() { - assertValueEquals(new SerializablePairLongLong(Long.MAX_VALUE, null), 8); + assertValueEquals(new SerializablePairLongLong(Long.MAX_VALUE, null), 9); } @Test @@ -55,7 +55,7 @@ public void testLargeRHS() { assertValueEquals( new SerializablePairLongLong(Long.MAX_VALUE, random.nextLong()), - 16 + 17 ); } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregationTest.java index 49c15fa22a51..e055930a2169 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregationTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/first/DoubleFirstVectorAggregationTest.java @@ -19,14 +19,13 @@ package org.apache.druid.query.aggregation.first; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.query.aggregation.SerializablePairLongDouble; import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilitiesImpl; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.vector.BaseDoubleVectorValueSelector; import org.apache.druid.segment.vector.BaseLongVectorValueSelector; import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; import org.apache.druid.segment.vector.NoFilterVectorOffset; @@ -48,23 +47,29 @@ public class DoubleFirstVectorAggregationTest extends InitializedNullHandlingTes { private static final double EPSILON = 1e-5; private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60}; - private static final boolean[] NULLS = new boolean[]{false, false, true, false}; - private long[] times = {2436, 6879, 7888, 8224}; - + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; + private final long[] times = {2345001L, 2345100L, 2345200L, 2345300L}; + private final SerializablePairLongDouble[] pairs = { + new SerializablePairLongDouble(2345001L, 1D), + new SerializablePairLongDouble(2345100L, 2D), + new SerializablePairLongDouble(2345200L, 3D), + new SerializablePairLongDouble(2345300L, 4D) + }; - private VectorValueSelector selector; - + private VectorObjectSelector selector; private BaseLongVectorValueSelector timeSelector; private ByteBuffer buf; - private DoubleFirstVectorAggregator target; private DoubleFirstAggregatorFactory doubleFirstAggregatorFactory; - private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonLongValueSelector; @Before public void setup() @@ -86,39 +91,80 @@ public long[] getLongVector() @Override public boolean[] getNullVector() { - return NULLS; + return null; } }; - selector = new BaseDoubleVectorValueSelector(new NoFilterVectorOffset(VALUES.length, 0, VALUES.length) + selector = new VectorObjectSelector() { + @Override + public Object[] getObjectVector() + { + return pairs; + } - }) + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonLongValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + @Override public double[] getDoubleVector() { - return VALUES; + return DOUBLE_VALUES; } @Nullable @Override public boolean[] getNullVector() { - if (!NullHandling.replaceWithDefault()) { - return NULLS; - } return null; } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } }; - target = new DoubleFirstVectorAggregator(timeSelector, selector); - clearBufferForPositions(0, 0); selectorFactory = new VectorColumnSelectorFactory() { @Override public ReadableVectorInspector getReadableVectorInspector() { - return null; + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); } @Override @@ -138,17 +184,21 @@ public VectorValueSelector makeValueSelector(String column) { if (TIME_COL.equals(column)) { return timeSelector; - } else if (FIELD_NAME.equals(column)) { - return selector; - } else { - return null; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonLongValueSelector; } + return null; } + @Override public VectorObjectSelector makeObjectSelector(String column) { - return null; + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } } @Nullable @@ -157,11 +207,16 @@ public ColumnCapabilities getColumnCapabilities(String column) { if (FIELD_NAME.equals(column)) { return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); } return null; } }; + target = new DoubleFirstVectorAggregator(timeSelector, selector); + clearBufferForPositions(0, 0); + doubleFirstAggregatorFactory = new DoubleFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL); } @@ -185,19 +240,19 @@ public void initValueShouldInitZero() @Test public void aggregate() { - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[0], result.lhs.longValue()); - Assert.assertEquals(VALUES[0], result.rhs, EPSILON); + Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[0].rhs, result.rhs, EPSILON); } @Test public void aggregateWithNulls() { - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[0], result.lhs.longValue()); - Assert.assertEquals(VALUES[0], result.rhs, EPSILON); + Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[0].rhs, result.rhs, EPSILON); } @Test @@ -209,12 +264,8 @@ public void aggregateBatchWithoutRows() target.aggregate(buf, 3, positions, null, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[i], result.lhs.longValue()); - if (!NullHandling.replaceWithDefault() && NULLS[i]) { - Assert.assertNull(result.rhs); - } else { - Assert.assertEquals(VALUES[i], result.rhs, EPSILON); - } + Assert.assertEquals(pairs[i].getLhs().longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[i].rhs, result.rhs, EPSILON); } } @@ -222,18 +273,15 @@ public void aggregateBatchWithoutRows() public void aggregateBatchWithRows() { int[] positions = new int[]{0, 43, 70}; - int[] rows = new int[]{3, 2, 0}; + int[] rows = new int[]{3, 0, 2}; int positionOffset = 2; clearBufferForPositions(positionOffset, positions); target.aggregate(buf, 3, positions, rows, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[rows[i]], result.lhs.longValue()); - if (!NullHandling.replaceWithDefault() && NULLS[rows[i]]) { - Assert.assertNull(result.rhs); - } else { - Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON); - } + Assert.assertEquals(pairs[rows[i]].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[rows[i]].rhs, result.rhs, EPSILON); + } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregationTest.java index 6b02037824a9..c7e8210d6c6d 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregationTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/first/FloatFirstVectorAggregationTest.java @@ -21,12 +21,12 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.query.aggregation.SerializablePairLongFloat; import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilitiesImpl; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.vector.BaseFloatVectorValueSelector; import org.apache.druid.segment.vector.BaseLongVectorValueSelector; import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; import org.apache.druid.segment.vector.NoFilterVectorOffset; @@ -48,24 +48,32 @@ public class FloatFirstVectorAggregationTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final float[] VALUES = new float[]{7.2f, 15.6f, 2.1f, 150.0f}; - private static final boolean[] NULLS = new boolean[]{false, false, true, false}; - private long[] times = {2436, 6879, 7888, 8224}; - + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; + private static final boolean[] NULLS = new boolean[]{false, false, false, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; + private final long[] times = {2345001L, 2345100L, 2345200L, 2345300L}; + private final SerializablePairLongFloat[] pairs = { + new SerializablePairLongFloat(2345001L, 1.2F), + new SerializablePairLongFloat(2345100L, 2.2F), + new SerializablePairLongFloat(2345200L, 3.2F), + new SerializablePairLongFloat(2345300L, 4.2F) + }; - private VectorValueSelector selector; + private VectorObjectSelector selector; private BaseLongVectorValueSelector timeSelector; private ByteBuffer buf; - private FloatFirstVectorAggregator target; private FloatFirstAggregatorFactory floatFirstAggregatorFactory; - private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonFloatValueSelector; @Before public void setup() @@ -87,41 +95,80 @@ public long[] getLongVector() @Override public boolean[] getNullVector() { - return NULLS; + return null; } }; - selector = new BaseFloatVectorValueSelector(new NoFilterVectorOffset(VALUES.length, 0, VALUES.length) + selector = new VectorObjectSelector() { + @Override + public Object[] getObjectVector() + { + return pairs; + } - }) + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonFloatValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } @Override public float[] getFloatVector() { - return VALUES; + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; } @Nullable @Override public boolean[] getNullVector() { - if (!NullHandling.replaceWithDefault()) { - return NULLS; - } - return null; + return NULLS; } - }; - target = new FloatFirstVectorAggregator(timeSelector, selector); - clearBufferForPositions(0, 0); + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; selectorFactory = new VectorColumnSelectorFactory() { @Override public ReadableVectorInspector getReadableVectorInspector() { - return null; + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); } @Override @@ -142,7 +189,7 @@ public VectorValueSelector makeValueSelector(String column) if (TIME_COL.equals(column)) { return timeSelector; } else if (FIELD_NAME.equals(column)) { - return selector; + return nonFloatValueSelector; } else { return null; } @@ -151,7 +198,11 @@ public VectorValueSelector makeValueSelector(String column) @Override public VectorObjectSelector makeObjectSelector(String column) { - return null; + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } } @Nullable @@ -160,10 +211,16 @@ public ColumnCapabilities getColumnCapabilities(String column) { if (FIELD_NAME.equals(column)) { return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.FLOAT); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); } return null; } }; + + target = new FloatFirstVectorAggregator(timeSelector, selector); + clearBufferForPositions(0, 0); + floatFirstAggregatorFactory = new FloatFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL); } @@ -191,8 +248,8 @@ public void aggregate() target.init(buf, 0); target.aggregate(buf, 0, 0, VALUES.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[0], result.lhs.longValue()); - Assert.assertEquals(VALUES[0], result.rhs, EPSILON); + Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[0].rhs, result.rhs, EPSILON); } @Test @@ -200,8 +257,8 @@ public void aggregateWithNulls() { target.aggregate(buf, 0, 0, VALUES.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[0], result.lhs.longValue()); - Assert.assertEquals(VALUES[0], result.rhs, EPSILON); + Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[0].rhs, result.rhs, EPSILON); } @Test @@ -213,11 +270,11 @@ public void aggregateBatchWithoutRows() target.aggregate(buf, 3, positions, null, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[i], result.lhs.longValue()); + Assert.assertEquals(pairs[i].getLhs().longValue(), result.lhs.longValue()); if (!NullHandling.replaceWithDefault() && NULLS[i]) { Assert.assertNull(result.rhs); } else { - Assert.assertEquals(VALUES[i], result.rhs, EPSILON); + Assert.assertEquals(pairs[i].rhs, result.rhs, EPSILON); } } } @@ -236,7 +293,7 @@ public void aggregateBatchWithRows() if (!NullHandling.replaceWithDefault() && NULLS[rows[i]]) { Assert.assertNull(result.rhs); } else { - Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON); + Assert.assertEquals(pairs[rows[i]].rhs, result.rhs, EPSILON); } } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregationTest.java index ec4017600628..a45e0f25563d 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregationTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/first/LongFirstVectorAggregationTest.java @@ -21,6 +21,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.query.aggregation.SerializablePairLongLong; import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnCapabilities; @@ -48,18 +49,30 @@ public class LongFirstVectorAggregationTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final long[] VALUES = new long[]{7, 15, 2, 150}; - private static final boolean[] NULLS = new boolean[]{false, false, true, false}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; + private static final boolean[] NULLS = new boolean[]{false, false, false, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; - private long[] times = {2436, 6879, 7888, 8224}; - private VectorValueSelector selector; + private final long[] times = {2345001L, 2345100L, 2345200L, 2345300L}; + private final SerializablePairLongLong[] pairs = { + new SerializablePairLongLong(2345001L, 1L), + new SerializablePairLongLong(2345100L, 2L), + new SerializablePairLongLong(2345200L, 3L), + new SerializablePairLongLong(2345300L, 4L) + }; + + private VectorObjectSelector selector; private BaseLongVectorValueSelector timeSelector; private ByteBuffer buf; private LongFirstVectorAggregator target; private LongFirstAggregatorFactory longFirstAggregatorFactory; private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonLongValueSelector; @Before public void setup() @@ -81,40 +94,81 @@ public long[] getLongVector() @Override public boolean[] getNullVector() { - return NULLS; + return null; } }; - selector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(VALUES.length, 0, VALUES.length) + + selector = new VectorObjectSelector() { + @Override + public Object[] getObjectVector() + { + return pairs; + } - }) + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonLongValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) { @Override public long[] getLongVector() { - return VALUES; + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; } @Nullable @Override public boolean[] getNullVector() { - if (!NullHandling.replaceWithDefault()) { - return NULLS; - } return null; } - }; - target = new LongFirstVectorAggregator(timeSelector, selector); - clearBufferForPositions(0, 0); + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; selectorFactory = new VectorColumnSelectorFactory() { @Override public ReadableVectorInspector getReadableVectorInspector() { - return null; + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); } @Override @@ -134,17 +188,21 @@ public VectorValueSelector makeValueSelector(String column) { if (TIME_COL.equals(column)) { return timeSelector; - } else if (FIELD_NAME.equals(column)) { - return selector; - } else { - return null; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonLongValueSelector; } + return null; } + @Override public VectorObjectSelector makeObjectSelector(String column) { - return null; + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } } @Nullable @@ -153,10 +211,16 @@ public ColumnCapabilities getColumnCapabilities(String column) { if (FIELD_NAME.equals(column)) { return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE); } return null; } }; + + target = new LongFirstVectorAggregator(timeSelector, selector); + clearBufferForPositions(0, 0); + longFirstAggregatorFactory = new LongFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL); } @@ -180,19 +244,19 @@ public void initValueShouldInitZero() @Test public void aggregate() { - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[0], result.lhs.longValue()); - Assert.assertEquals(VALUES[0], result.rhs, EPSILON); + Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[0].rhs, result.rhs, EPSILON); } @Test public void aggregateWithNulls() { - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[0], result.lhs.longValue()); - Assert.assertEquals(VALUES[0], result.rhs, EPSILON); + Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[0].rhs, result.rhs, EPSILON); } @Test @@ -204,11 +268,11 @@ public void aggregateBatchWithoutRows() target.aggregate(buf, 3, positions, null, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[i], result.lhs.longValue()); + Assert.assertEquals(pairs[i].getLhs().longValue(), result.lhs.longValue()); if (!NullHandling.replaceWithDefault() && NULLS[i]) { Assert.assertNull(result.rhs); } else { - Assert.assertEquals(VALUES[i], result.rhs, EPSILON); + Assert.assertEquals(pairs[i].rhs, result.rhs, EPSILON); } } } @@ -227,7 +291,7 @@ public void aggregateBatchWithRows() if (!NullHandling.replaceWithDefault() && NULLS[rows[i]]) { Assert.assertNull(result.rhs); } else { - Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON); + Assert.assertEquals(pairs[rows[i]].rhs, result.rhs, EPSILON); } } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregatorTest.java index 391aa60866bc..d3125f190e1c 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/DoubleLastVectorAggregatorTest.java @@ -19,47 +19,215 @@ package org.apache.druid.query.aggregation.last; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.query.aggregation.SerializablePairLongDouble; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class DoubleLastVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; private static final boolean[] NULLS = new boolean[]{false, false, true, false}; - private long[] times = {2436, 6879, 7888, 8224}; + private static final String NAME = "NAME"; + private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; + private static final String TIME_COL = "__time"; + private final long[] times = {2345001L, 2345100L, 2345200L, 2345300L}; + private final SerializablePairLongDouble[] pairs = { + new SerializablePairLongDouble(2345001L, 1D), + new SerializablePairLongDouble(2345100L, 2D), + new SerializablePairLongDouble(2345200L, 3D), + new SerializablePairLongDouble(2345300L, 4D) + }; - @Mock - private VectorValueSelector selector; - @Mock - private VectorValueSelector timeSelector; + private VectorObjectSelector selector; + private BaseLongVectorValueSelector timeSelector; private ByteBuffer buf; - private DoubleLastVectorAggregator target; + private DoubleLastAggregatorFactory doubleLastAggregatorFactory; + private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonLongValueSelector; + @Before public void setup() { byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getDoubleVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length) + { + }) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + }; + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonLongValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonLongValueSelector; + } + return null; + } + + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } + return null; + } + }; + target = new DoubleLastVectorAggregator(timeSelector, selector); clearBufferForPositions(0, 0); + + doubleLastAggregatorFactory = new DoubleLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL); + } + + @Test + public void testFactory() + { + Assert.assertTrue(doubleLastAggregatorFactory.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = doubleLastAggregatorFactory.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(DoubleLastVectorAggregator.class, vectorAggregator.getClass()); } @Test @@ -73,20 +241,19 @@ public void initValueShouldInitZero() @Test public void aggregate() { - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[3], result.lhs.longValue()); - Assert.assertEquals(VALUES[3], result.rhs, EPSILON); + Assert.assertEquals(pairs[3].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[3].rhs, result.rhs, EPSILON); } @Test public void aggregateWithNulls() { - mockNullsVector(); - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[3], result.lhs.longValue()); - Assert.assertEquals(VALUES[3], result.rhs, EPSILON); + Assert.assertEquals(pairs[3].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[3].rhs, result.rhs, EPSILON); } @Test @@ -98,8 +265,8 @@ public void aggregateBatchWithoutRows() target.aggregate(buf, 3, positions, null, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[i], result.lhs.longValue()); - Assert.assertEquals(VALUES[i], result.rhs, EPSILON); + Assert.assertEquals(pairs[i].getLhs().longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[i].rhs, result.rhs, EPSILON); } } @@ -113,8 +280,8 @@ public void aggregateBatchWithRows() target.aggregate(buf, 3, positions, rows, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[rows[i]], result.lhs.longValue()); - Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON); + Assert.assertEquals(pairs[rows[i]].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[rows[i]].rhs, result.rhs, EPSILON); } } @@ -124,11 +291,4 @@ private void clearBufferForPositions(int offset, int... positions) target.init(buf, offset + position); } } - - private void mockNullsVector() - { - if (!NullHandling.replaceWithDefault()) { - Mockito.doReturn(NULLS).when(selector).getNullVector(); - } - } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregatorTest.java index 82615bcd7fe2..957d585e5fd9 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/FloatLastVectorAggregatorTest.java @@ -21,45 +21,217 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.query.aggregation.SerializablePairLongFloat; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class FloatLastVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final float[] VALUES = new float[]{7.2f, 15.6f, 2.1f, 150.0f}; - private static final boolean[] NULLS = new boolean[]{false, false, true, false}; - private long[] times = {2436, 6879, 7888, 8224}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; + private static final boolean[] NULLS = new boolean[]{false, false, false, false}; + private static final String NAME = "NAME"; + private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; + private static final String TIME_COL = "__time"; + private final long[] times = {2345001L, 2345100L, 2345200L, 2345300L}; + private final SerializablePairLongFloat[] pairs = { + new SerializablePairLongFloat(2345001L, 1.2F), + new SerializablePairLongFloat(2345100L, 2.2F), + new SerializablePairLongFloat(2345200L, 3.2F), + new SerializablePairLongFloat(2345300L, 4.2F) + }; + - @Mock - private VectorValueSelector selector; - @Mock - private VectorValueSelector timeSelector; - private ByteBuffer buf; + private VectorObjectSelector selector; + private BaseLongVectorValueSelector timeSelector; + private ByteBuffer buf; private FloatLastVectorAggregator target; + private FloatLastAggregatorFactory floatLastAggregatorFactory; + private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonFloatValueSelector; + @Before public void setup() { byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getFloatVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length) + { + }) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + }; + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonFloatValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return NULLS; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME.equals(column)) { + return nonFloatValueSelector; + } else { + return null; + } + } + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.FLOAT); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } + return null; + } + }; + target = new FloatLastVectorAggregator(timeSelector, selector); clearBufferForPositions(0, 0); + + floatLastAggregatorFactory = new FloatLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL); + + } + + @Test + public void testFactory() + { + Assert.assertTrue(floatLastAggregatorFactory.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = floatLastAggregatorFactory.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(FloatLastVectorAggregator.class, vectorAggregator.getClass()); } @Test @@ -76,18 +248,17 @@ public void aggregate() target.init(buf, 0); target.aggregate(buf, 0, 0, VALUES.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[3], result.lhs.longValue()); - Assert.assertEquals(VALUES[3], result.rhs, EPSILON); + Assert.assertEquals(pairs[3].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[3].rhs, result.rhs, EPSILON); } @Test public void aggregateWithNulls() { - mockNullsVector(); target.aggregate(buf, 0, 0, VALUES.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[3], result.lhs.longValue()); - Assert.assertEquals(VALUES[3], result.rhs, EPSILON); + Assert.assertEquals(pairs[3].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[3].rhs, result.rhs, EPSILON); } @Test @@ -99,8 +270,12 @@ public void aggregateBatchWithoutRows() target.aggregate(buf, 3, positions, null, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[i], result.lhs.longValue()); - Assert.assertEquals(VALUES[i], result.rhs, EPSILON); + Assert.assertEquals(pairs[i].getLhs().longValue(), result.lhs.longValue()); + if (!NullHandling.replaceWithDefault() && NULLS[i]) { + Assert.assertNull(result.rhs); + } else { + Assert.assertEquals(pairs[i].rhs, result.rhs, EPSILON); + } } } @@ -115,7 +290,11 @@ public void aggregateBatchWithRows() for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); Assert.assertEquals(times[rows[i]], result.lhs.longValue()); - Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON); + if (!NullHandling.replaceWithDefault() && NULLS[rows[i]]) { + Assert.assertNull(result.rhs); + } else { + Assert.assertEquals(pairs[rows[i]].rhs, result.rhs, EPSILON); + } } } @@ -125,12 +304,5 @@ private void clearBufferForPositions(int offset, int... positions) target.init(buf, offset + position); } } - - private void mockNullsVector() - { - if (!NullHandling.replaceWithDefault()) { - Mockito.doReturn(NULLS).when(selector).getNullVector(); - } - } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregatorTest.java index acbc1a8f2480..6cc8bfa9f5a2 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/LongLastVectorAggregatorTest.java @@ -21,45 +21,215 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.query.aggregation.SerializablePairLongLong; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class LongLastVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final long[] VALUES = new long[]{7, 15, 2, 150}; - private static final boolean[] NULLS = new boolean[]{false, false, true, false}; - private long[] times = {2436, 6879, 7888, 8224}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; + private static final boolean[] NULLS = new boolean[]{false, false, false, false}; + private static final String NAME = "NAME"; + private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; + private static final String TIME_COL = "__time"; + private final long[] times = {2345001L, 2345100L, 2345200L, 2345300L}; + private final SerializablePairLongLong[] pairs = { + new SerializablePairLongLong(2345001L, 1L), + new SerializablePairLongLong(2345100L, 2L), + new SerializablePairLongLong(2345200L, 3L), + new SerializablePairLongLong(2345300L, 4L) + }; - @Mock - private VectorValueSelector selector; - @Mock - private VectorValueSelector timeSelector; + private VectorObjectSelector selector; + private BaseLongVectorValueSelector timeSelector; private ByteBuffer buf; - private LongLastVectorAggregator target; + private LongLastAggregatorFactory longLastAggregatorFactory; + private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonLongValueSelector; + @Before public void setup() { byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getLongVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length) + { + }) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + }; + + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonLongValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonLongValueSelector; + } + return null; + } + + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE); + } + return null; + } + }; + target = new LongLastVectorAggregator(timeSelector, selector); clearBufferForPositions(0, 0); + + longLastAggregatorFactory = new LongLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL); + } + + @Test + public void testFactory() + { + Assert.assertTrue(longLastAggregatorFactory.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = longLastAggregatorFactory.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(LongLastVectorAggregator.class, vectorAggregator.getClass()); } @Test @@ -73,20 +243,19 @@ public void initValueShouldInitZero() @Test public void aggregate() { - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[3], result.lhs.longValue()); - Assert.assertEquals(VALUES[3], result.rhs, EPSILON); + Assert.assertEquals(pairs[3].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[3].rhs, result.rhs, EPSILON); } @Test public void aggregateWithNulls() { - mockNullsVector(); - target.aggregate(buf, 0, 0, VALUES.length); + target.aggregate(buf, 0, 0, pairs.length); Pair result = (Pair) target.get(buf, 0); - Assert.assertEquals(times[3], result.lhs.longValue()); - Assert.assertEquals(VALUES[3], result.rhs, EPSILON); + Assert.assertEquals(pairs[3].lhs.longValue(), result.lhs.longValue()); + Assert.assertEquals(pairs[3].rhs, result.rhs, EPSILON); } @Test @@ -98,8 +267,12 @@ public void aggregateBatchWithoutRows() target.aggregate(buf, 3, positions, null, positionOffset); for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); - Assert.assertEquals(times[i], result.lhs.longValue()); - Assert.assertEquals(VALUES[i], result.rhs, EPSILON); + Assert.assertEquals(pairs[i].getLhs().longValue(), result.lhs.longValue()); + if (!NullHandling.replaceWithDefault() && NULLS[i]) { + Assert.assertNull(result.rhs); + } else { + Assert.assertEquals(pairs[i].rhs, result.rhs, EPSILON); + } } } @@ -114,7 +287,11 @@ public void aggregateBatchWithRows() for (int i = 0; i < positions.length; i++) { Pair result = (Pair) target.get(buf, positions[i] + positionOffset); Assert.assertEquals(times[rows[i]], result.lhs.longValue()); - Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON); + if (!NullHandling.replaceWithDefault() && NULLS[rows[i]]) { + Assert.assertNull(result.rhs); + } else { + Assert.assertEquals(pairs[rows[i]].rhs, result.rhs, EPSILON); + } } } @@ -124,11 +301,4 @@ private void clearBufferForPositions(int offset, int... positions) target.init(buf, offset + position); } } - - private void mockNullsVector() - { - if (!NullHandling.replaceWithDefault()) { - Mockito.doReturn(NULLS).when(selector).getNullVector(); - } - } }