From bf99d2c7b2e4436751c1d42c0e07b3e3ed36562e Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:15:14 -0700 Subject: [PATCH] Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly (#14924) * Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly * Fixing a failed test * Updating numericNilAgg * Moving to use default values in case of nil agg * Adding the same for first agg * Fixing a test * fixing vectorized string agg for last/first with cast if numeric * Updating tests to remove mockito and cover the case of string first/last on non string columns * Updating a test to vectorize * Addressing review comments: Name change to NilVectorAggregator and using static variables now * fixing intellij inspections --- .../any/DoubleAnyAggregatorFactory.java | 2 +- .../any/FloatAnyAggregatorFactory.java | 2 +- .../any/LongAnyAggregatorFactory.java | 2 +- ...gregator.java => NilVectorAggregator.java} | 26 +- .../first/DoubleFirstAggregatorFactory.java | 4 +- .../first/FloatFirstAggregatorFactory.java | 4 +- .../first/LongFirstAggregatorFactory.java | 4 +- .../first/StringFirstAggregatorFactory.java | 14 ++ .../last/DoubleLastAggregatorFactory.java | 14 +- .../last/FloatLastAggregatorFactory.java | 16 +- .../last/LongLastAggregatorFactory.java | 15 +- .../last/StringLastAggregatorFactory.java | 19 +- .../StringFirstVectorAggregatorTest.java | 227 +++++++++++++++-- .../last/StringLastVectorAggregatorTest.java | 230 ++++++++++++++++-- .../sql/calcite/CalciteSimpleQueryTest.java | 1 - 15 files changed, 497 insertions(+), 83 deletions(-) rename processing/src/main/java/org/apache/druid/query/aggregation/any/{NumericNilVectorAggregator.java => NilVectorAggregator.java} (66%) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java index 86f85455a6da..0a51e5633947 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java @@ -123,7 +123,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact if (capabilities == null || capabilities.isNumeric()) { return new DoubleAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); } else { - return NumericNilVectorAggregator.doubleNilVectorAggregator(); + return NilVectorAggregator.doubleNilVectorAggregator(); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java index 35495c9e30eb..a9ee3519b9e5 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java @@ -120,7 +120,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact if (capabilities == null || capabilities.isNumeric()) { return new FloatAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); } else { - return NumericNilVectorAggregator.floatNilVectorAggregator(); + return NilVectorAggregator.floatNilVectorAggregator(); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java index 9af417a600db..9b220337e353 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java @@ -119,7 +119,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact if (capabilities == null || capabilities.isNumeric()) { return new LongAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); } else { - return NumericNilVectorAggregator.longNilVectorAggregator(); + return NilVectorAggregator.longNilVectorAggregator(); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/NumericNilVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java similarity index 66% rename from processing/src/main/java/org/apache/druid/query/aggregation/any/NumericNilVectorAggregator.java rename to processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java index 3034524c1853..ac6c5c7a75e4 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/NumericNilVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java @@ -19,6 +19,7 @@ package org.apache.druid.query.aggregation.any; +import org.apache.druid.collections.SerializablePair; import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.VectorAggregator; @@ -28,24 +29,28 @@ /** * A vector aggregator that returns the default numeric value. */ -public class NumericNilVectorAggregator implements VectorAggregator +public class NilVectorAggregator implements VectorAggregator { - private static final NumericNilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( + private static final NilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator( NullHandling.defaultDoubleValue() ); - private static final NumericNilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( + private static final NilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator( NullHandling.defaultFloatValue() ); - private static final NumericNilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( + private static final NilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator( NullHandling.defaultLongValue() ); + public static final SerializablePair DOUBLE_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultDoubleValue()); + public static final SerializablePair LONG_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultLongValue()); + public static final SerializablePair FLOAT_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultFloatValue()); + /** * @return A vectorized aggregator that returns the default double value. */ - public static NumericNilVectorAggregator doubleNilVectorAggregator() + public static NilVectorAggregator doubleNilVectorAggregator() { return DOUBLE_NIL_VECTOR_AGGREGATOR; } @@ -53,7 +58,7 @@ public static NumericNilVectorAggregator doubleNilVectorAggregator() /** * @return A vectorized aggregator that returns the default float value. */ - public static NumericNilVectorAggregator floatNilVectorAggregator() + public static NilVectorAggregator floatNilVectorAggregator() { return FLOAT_NIL_VECTOR_AGGREGATOR; } @@ -61,7 +66,7 @@ public static NumericNilVectorAggregator floatNilVectorAggregator() /** * @return A vectorized aggregator that returns the default long value. */ - public static NumericNilVectorAggregator longNilVectorAggregator() + public static NilVectorAggregator longNilVectorAggregator() { return LONG_NIL_VECTOR_AGGREGATOR; } @@ -69,7 +74,12 @@ public static NumericNilVectorAggregator longNilVectorAggregator() @Nullable private final Object returnValue; - private NumericNilVectorAggregator(@Nullable Object returnValue) + public static NilVectorAggregator of(Object returnValue) + { + return new NilVectorAggregator(returnValue); + } + + private NilVectorAggregator(@Nullable Object returnValue) { this.returnValue = returnValue; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java index 7cee6f5ca64d..4e9ecd2523b4 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java @@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseDoubleColumnValueSelector; @@ -149,7 +149,7 @@ public VectorAggregator factorizeVector( timeColumn); return new DoubleFirstVectorAggregator(timeSelector, valueSelector); } - return NumericNilVectorAggregator.doubleNilVectorAggregator(); + return NilVectorAggregator.of(NilVectorAggregator.DOUBLE_NIL_PAIR); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java index 68826bc2c0a1..ee28cd35dc6a 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java @@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseFloatColumnValueSelector; @@ -138,7 +138,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory columnSelect VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new FloatFirstVectorAggregator(timeSelector, valueSelector); } - return NumericNilVectorAggregator.floatNilVectorAggregator(); + return NilVectorAggregator.of(NilVectorAggregator.FLOAT_NIL_PAIR); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java index 729a1bef26e2..c8aee33f511e 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java @@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseLongColumnValueSelector; @@ -138,7 +138,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory columnSelect timeColumn); return new LongFirstVectorAggregator(timeSelector, valueSelector); } - return NumericNilVectorAggregator.longNilVectorAggregator(); + return NilVectorAggregator.of(NilVectorAggregator.LONG_NIL_PAIR); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java index 193acaf25a99..1bc625258593 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java @@ -42,11 +42,14 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; + import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -188,6 +191,17 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact { final VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + selectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.STRING + ); + return new StringFirstVectorAggregator(timeSelector, objectSelector, maxStringBytes); + } if (capabilities != null) { if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) { // Case 1: Single value string with dimension selector diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java index d3770fb3ae47..4b5f965a854e 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; @@ -42,6 +43,7 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorValueSelector; @@ -125,14 +127,12 @@ public VectorAggregator factorizeVector( ) { ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); - VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - if (capabilities == null || capabilities.isNumeric()) { + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new DoubleLastVectorAggregator(timeSelector, valueSelector); } else { - return NumericNilVectorAggregator.doubleNilVectorAggregator(); + return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultDoubleValue())); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java index dff50b095bcf..bc0ee23a08bc 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; @@ -42,6 +43,7 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorValueSelector; @@ -136,15 +138,13 @@ public VectorAggregator factorizeVector( VectorColumnSelectorFactory columnSelectorFactory ) { - ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); - VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - if (capabilities == null || capabilities.isNumeric()) { + final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new FloatLastVectorAggregator(timeSelector, valueSelector); } else { - return NumericNilVectorAggregator.floatNilVectorAggregator(); + return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultFloatValue())); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java index 29d9ad2a06ea..b08d8d386461 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; @@ -42,6 +43,7 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorValueSelector; @@ -136,14 +138,13 @@ public VectorAggregator factorizeVector( VectorColumnSelectorFactory columnSelectorFactory ) { - ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); - VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - if (capabilities == null || capabilities.isNumeric()) { + final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new LongLastVectorAggregator(timeSelector, valueSelector); } else { - return NumericNilVectorAggregator.longNilVectorAggregator(); + return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultLongValue())); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java index f6ca2a09d3b4..234f03be3f9b 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java @@ -42,9 +42,11 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -156,16 +158,25 @@ public boolean canVectorize(ColumnInspector columnInspector) public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) { - ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); + final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); + VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + selectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.STRING + ); + return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes); + } VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName); - VectorValueSelector timeSelector = selectorFactory.makeValueSelector( - timeColumn); if (capabilities != null) { return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes); } else { return new StringLastVectorAggregator(null, vSelector, maxStringBytes); } - } @Override diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java index 148f4f95937a..e3b461da5687 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java @@ -23,35 +23,45 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; +import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Answers; import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final String[] VALUES = new String[]{"a", "b", null, "c"}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; - private long[] times = {2436, 6879, 7888, 8224}; - private long[] timesSame = {2436, 2436}; - private SerializablePairLongString[] pairs = { + private final long[] times = {2436, 6879, 7888, 8224}; + private final long[] timesSame = {2436, 2436}; + private final SerializablePairLongString[] pairs = { new SerializablePairLongString(2345001L, "first"), new SerializablePairLongString(2345100L, "notFirst") }; @@ -69,8 +79,10 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest private StringFirstVectorAggregator targetWithPairs; private StringFirstAggregatorFactory stringFirstAggregatorFactory; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private StringFirstAggregatorFactory stringFirstAggregatorFactory1; + private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonStringValueSelector; @Before public void setup() @@ -78,19 +90,189 @@ public void setup() byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getObjectVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); - Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector(); - Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector(); + + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length)) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + }; + + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return VALUES; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + timesSame.length, + 0, + timesSame.length + )) + { + @Override + public long[] getLongVector() + { + return timesSame; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + }; + selectorForPairs = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 2; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return NULLS; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonStringValueSelector; + } + return null; + } + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities(); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } + return null; + } + }; + target = new StringFirstVectorAggregator(timeSelector, selector, 10); targetWithPairs = new StringFirstVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); clearBufferForPositions(0, 0); - Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME); - Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL); stringFirstAggregatorFactory = new StringFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10); - + stringFirstAggregatorFactory1 = new StringFirstAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10); } @Test @@ -129,6 +311,19 @@ public void aggregate() Assert.assertEquals(VALUES[0], result.rhs); } + @Test + public void testStringEarliestOnNonStringColumns() + { + Assert.assertTrue(stringFirstAggregatorFactory1.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = stringFirstAggregatorFactory1.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(StringFirstVectorAggregator.class, vectorAggregator.getClass()); + vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length); + Pair result = (Pair) vectorAggregator.get(buf, 0); + Assert.assertEquals(times[0], result.lhs.longValue()); + Assert.assertEquals(STRING_VALUES[0], result.rhs); + } + @Test public void aggregateBatchWithoutRows() { diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java index 6e5c0275107b..f144552d57e9 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java @@ -23,74 +23,245 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; +import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Answers; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final String[] VALUES = new String[]{"a", "b", null, "c"}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; - private long[] times = {2436, 6879, 7888, 8224}; - private long[] timesSame = {2436, 2436}; - private SerializablePairLongString[] pairs = { + private final long[] times = {2436, 6879, 7888, 8224}; + private final long[] timesSame = {2436, 2436}; + private final SerializablePairLongString[] pairs = { new SerializablePairLongString(2345100L, "last"), new SerializablePairLongString(2345001L, "notLast") }; - @Mock private VectorObjectSelector selector; - @Mock - private VectorObjectSelector selectorForPairs; - @Mock private BaseLongVectorValueSelector timeSelector; - @Mock - private BaseLongVectorValueSelector timeSelectorForPairs; + private VectorValueSelector nonStringValueSelector; private ByteBuffer buf; private StringLastVectorAggregator target; private StringLastVectorAggregator targetWithPairs; private StringLastAggregatorFactory stringLastAggregatorFactory; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private StringLastAggregatorFactory stringLastAggregatorFactory1; + private VectorColumnSelectorFactory selectorFactory; + @Before public void setup() { byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getObjectVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); - Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector(); - Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector(); + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length)) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return NULLS; + } + }; + nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return NULLS; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return VALUES; + } + + @Override + public int getMaxVectorSize() + { + return 0; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + times.length, + 0, + times.length + )) + { + @Override + public long[] getLongVector() + { + return timesSame; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return new boolean[0]; + } + }; + VectorObjectSelector selectorForPairs = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 2; + } + + @Override + public int getCurrentVectorSize() + { + return 2; + } + }; + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(LONG_VALUES.length, 0, LONG_VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonStringValueSelector; + } + return null; + } + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities(); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } + return null; + } + }; + target = new StringLastVectorAggregator(timeSelector, selector, 10); targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); clearBufferForPositions(0, 0); - Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME); - Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL); stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10); - + stringLastAggregatorFactory1 = new StringLastAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10); } @Test @@ -112,6 +283,19 @@ public void testFactory() Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass()); } + @Test + public void testStringLastOnNonStringColumns() + { + Assert.assertTrue(stringLastAggregatorFactory1.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = stringLastAggregatorFactory1.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass()); + vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length); + Pair result = (Pair) vectorAggregator.get(buf, 0); + Assert.assertEquals(times[3], result.lhs.longValue()); + Assert.assertEquals(STRING_VALUES[3], result.rhs); + } + @Test public void initValueShouldBeMinDate() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java index b339e5f4b7de..1eb6d58bcb87 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java @@ -632,7 +632,6 @@ public void testGroupByDimAndTimeAndDimOrderByDimAndTimeDim() @Test public void testEarliestByLatestByWithExpression() { - cannotVectorize(); testBuilder() .sql("SELECT\n" + " channel\n"