From 7f757e33f06ff3d6df4a3f200c095ffb3dfcd7ab Mon Sep 17 00:00:00 2001 From: Benedict Jin Date: Wed, 13 Sep 2023 21:12:35 +0800 Subject: [PATCH 01/10] Fix the created property in DOAP RDF file (#14971) --- doap_Druid.rdf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doap_Druid.rdf b/doap_Druid.rdf index bb95885517e8..4ac5f1ea999a 100644 --- a/doap_Druid.rdf +++ b/doap_Druid.rdf @@ -22,7 +22,7 @@ limitations under the License. --> - 2023-09-08 + 2012-10-23 Apache Druid From bf99d2c7b2e4436751c1d42c0e07b3e3ed36562e Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:15:14 -0700 Subject: [PATCH 02/10] Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly (#14924) * Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly * Fixing a failed test * Updating numericNilAgg * Moving to use default values in case of nil agg * Adding the same for first agg * Fixing a test * fixing vectorized string agg for last/first with cast if numeric * Updating tests to remove mockito and cover the case of string first/last on non string columns * Updating a test to vectorize * Addressing review comments: Name change to NilVectorAggregator and using static variables now * fixing intellij inspections --- .../any/DoubleAnyAggregatorFactory.java | 2 +- .../any/FloatAnyAggregatorFactory.java | 2 +- .../any/LongAnyAggregatorFactory.java | 2 +- ...gregator.java => NilVectorAggregator.java} | 26 +- .../first/DoubleFirstAggregatorFactory.java | 4 +- .../first/FloatFirstAggregatorFactory.java | 4 +- .../first/LongFirstAggregatorFactory.java | 4 +- .../first/StringFirstAggregatorFactory.java | 14 ++ .../last/DoubleLastAggregatorFactory.java | 14 +- .../last/FloatLastAggregatorFactory.java | 16 +- .../last/LongLastAggregatorFactory.java | 15 +- .../last/StringLastAggregatorFactory.java | 19 +- 
.../StringFirstVectorAggregatorTest.java | 227 +++++++++++++++-- .../last/StringLastVectorAggregatorTest.java | 230 ++++++++++++++++-- .../sql/calcite/CalciteSimpleQueryTest.java | 1 - 15 files changed, 497 insertions(+), 83 deletions(-) rename processing/src/main/java/org/apache/druid/query/aggregation/any/{NumericNilVectorAggregator.java => NilVectorAggregator.java} (66%) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java index 86f85455a6da..0a51e5633947 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java @@ -123,7 +123,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact if (capabilities == null || capabilities.isNumeric()) { return new DoubleAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); } else { - return NumericNilVectorAggregator.doubleNilVectorAggregator(); + return NilVectorAggregator.doubleNilVectorAggregator(); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java index 35495c9e30eb..a9ee3519b9e5 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java @@ -120,7 +120,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact if (capabilities == null || capabilities.isNumeric()) { return new FloatAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); } else { - return NumericNilVectorAggregator.floatNilVectorAggregator(); + return 
NilVectorAggregator.floatNilVectorAggregator(); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java index 9af417a600db..9b220337e353 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java @@ -119,7 +119,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact if (capabilities == null || capabilities.isNumeric()) { return new LongAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); } else { - return NumericNilVectorAggregator.longNilVectorAggregator(); + return NilVectorAggregator.longNilVectorAggregator(); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/NumericNilVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java similarity index 66% rename from processing/src/main/java/org/apache/druid/query/aggregation/any/NumericNilVectorAggregator.java rename to processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java index 3034524c1853..ac6c5c7a75e4 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/NumericNilVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/NilVectorAggregator.java @@ -19,6 +19,7 @@ package org.apache.druid.query.aggregation.any; +import org.apache.druid.collections.SerializablePair; import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.VectorAggregator; @@ -28,24 +29,28 @@ /** * A vector aggregator that returns the default numeric value. 
*/ -public class NumericNilVectorAggregator implements VectorAggregator +public class NilVectorAggregator implements VectorAggregator { - private static final NumericNilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( + private static final NilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator( NullHandling.defaultDoubleValue() ); - private static final NumericNilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( + private static final NilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator( NullHandling.defaultFloatValue() ); - private static final NumericNilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( + private static final NilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator( NullHandling.defaultLongValue() ); + public static final SerializablePair DOUBLE_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultDoubleValue()); + public static final SerializablePair LONG_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultLongValue()); + public static final SerializablePair FLOAT_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultFloatValue()); + /** * @return A vectorized aggregator that returns the default double value. */ - public static NumericNilVectorAggregator doubleNilVectorAggregator() + public static NilVectorAggregator doubleNilVectorAggregator() { return DOUBLE_NIL_VECTOR_AGGREGATOR; } @@ -53,7 +58,7 @@ public static NumericNilVectorAggregator doubleNilVectorAggregator() /** * @return A vectorized aggregator that returns the default float value. */ - public static NumericNilVectorAggregator floatNilVectorAggregator() + public static NilVectorAggregator floatNilVectorAggregator() { return FLOAT_NIL_VECTOR_AGGREGATOR; } @@ -61,7 +66,7 @@ public static NumericNilVectorAggregator floatNilVectorAggregator() /** * @return A vectorized aggregator that returns the default long value. 
*/ - public static NumericNilVectorAggregator longNilVectorAggregator() + public static NilVectorAggregator longNilVectorAggregator() { return LONG_NIL_VECTOR_AGGREGATOR; } @@ -69,7 +74,12 @@ public static NumericNilVectorAggregator longNilVectorAggregator() @Nullable private final Object returnValue; - private NumericNilVectorAggregator(@Nullable Object returnValue) + public static NilVectorAggregator of(Object returnValue) + { + return new NilVectorAggregator(returnValue); + } + + private NilVectorAggregator(@Nullable Object returnValue) { this.returnValue = returnValue; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java index 7cee6f5ca64d..4e9ecd2523b4 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/DoubleFirstAggregatorFactory.java @@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseDoubleColumnValueSelector; @@ -149,7 +149,7 @@ public VectorAggregator factorizeVector( timeColumn); return new DoubleFirstVectorAggregator(timeSelector, valueSelector); } - return NumericNilVectorAggregator.doubleNilVectorAggregator(); + return NilVectorAggregator.of(NilVectorAggregator.DOUBLE_NIL_PAIR); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java 
b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java index 68826bc2c0a1..ee28cd35dc6a 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/FloatFirstAggregatorFactory.java @@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseFloatColumnValueSelector; @@ -138,7 +138,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory columnSelect VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new FloatFirstVectorAggregator(timeSelector, valueSelector); } - return NumericNilVectorAggregator.floatNilVectorAggregator(); + return NilVectorAggregator.of(NilVectorAggregator.FLOAT_NIL_PAIR); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java index 729a1bef26e2..c8aee33f511e 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/LongFirstAggregatorFactory.java @@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import 
org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseLongColumnValueSelector; @@ -138,7 +138,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory columnSelect timeColumn); return new LongFirstVectorAggregator(timeSelector, valueSelector); } - return NumericNilVectorAggregator.longNilVectorAggregator(); + return NilVectorAggregator.of(NilVectorAggregator.LONG_NIL_PAIR); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java index 193acaf25a99..1bc625258593 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java @@ -42,11 +42,14 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; + import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -188,6 +191,17 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact { final VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); 
ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + selectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.STRING + ); + return new StringFirstVectorAggregator(timeSelector, objectSelector, maxStringBytes); + } if (capabilities != null) { if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) { // Case 1: Single value string with dimension selector diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java index d3770fb3ae47..4b5f965a854e 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/DoubleLastAggregatorFactory.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory; import 
org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; @@ -42,6 +43,7 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorValueSelector; @@ -125,14 +127,12 @@ public VectorAggregator factorizeVector( ) { ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); - VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - if (capabilities == null || capabilities.isNumeric()) { + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new DoubleLastVectorAggregator(timeSelector, valueSelector); } else { - return NumericNilVectorAggregator.doubleNilVectorAggregator(); + return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultDoubleValue())); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java index dff50b095bcf..bc0ee23a08bc 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/FloatLastAggregatorFactory.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import 
org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; @@ -42,6 +43,7 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorValueSelector; @@ -136,15 +138,13 @@ public VectorAggregator factorizeVector( VectorColumnSelectorFactory columnSelectorFactory ) { - ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); - VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - if (capabilities == null || capabilities.isNumeric()) { + final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new FloatLastVectorAggregator(timeSelector, valueSelector); } else { - return 
NumericNilVectorAggregator.floatNilVectorAggregator(); + return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultFloatValue())); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java index 29d9ad2a06ea..b08d8d386461 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/LongLastAggregatorFactory.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.collections.SerializablePair; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.VectorAggregator; -import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; +import org.apache.druid.query.aggregation.any.NilVectorAggregator; import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; @@ -42,6 +43,7 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorValueSelector; @@ -136,14 +138,13 @@ public VectorAggregator factorizeVector( VectorColumnSelectorFactory 
columnSelectorFactory ) { - ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); - VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); - VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( - timeColumn); - if (capabilities == null || capabilities.isNumeric()) { + final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); + VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); return new LongLastVectorAggregator(timeSelector, valueSelector); } else { - return NumericNilVectorAggregator.longNilVectorAggregator(); + return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultLongValue())); } } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java index f6ca2a09d3b4..234f03be3f9b 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java @@ -42,9 +42,11 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.segment.virtual.ExpressionVectorSelectors; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -156,16 +158,25 @@ public boolean canVectorize(ColumnInspector 
columnInspector) public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) { - ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); + final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); + VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); + if (Types.isNumeric(capabilities)) { + VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName); + VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( + selectorFactory.getReadableVectorInspector(), + fieldName, + valueSelector, + capabilities.toColumnType(), + ColumnType.STRING + ); + return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes); + } VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName); - VectorValueSelector timeSelector = selectorFactory.makeValueSelector( - timeColumn); if (capabilities != null) { return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes); } else { return new StringLastVectorAggregator(null, vSelector, maxStringBytes); } - } @Override diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java index 148f4f95937a..e3b461da5687 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstVectorAggregatorTest.java @@ -23,35 +23,45 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import 
org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; +import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Answers; import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final String[] VALUES = new String[]{"a", "b", null, "c"}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; - private long[] times = {2436, 6879, 7888, 8224}; - private long[] timesSame = {2436, 2436}; - private 
SerializablePairLongString[] pairs = { + private final long[] times = {2436, 6879, 7888, 8224}; + private final long[] timesSame = {2436, 2436}; + private final SerializablePairLongString[] pairs = { new SerializablePairLongString(2345001L, "first"), new SerializablePairLongString(2345100L, "notFirst") }; @@ -69,8 +79,10 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest private StringFirstVectorAggregator targetWithPairs; private StringFirstAggregatorFactory stringFirstAggregatorFactory; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private StringFirstAggregatorFactory stringFirstAggregatorFactory1; + private VectorColumnSelectorFactory selectorFactory; + private VectorValueSelector nonStringValueSelector; @Before public void setup() @@ -78,19 +90,189 @@ public void setup() byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getObjectVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); - Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector(); - Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector(); + + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length)) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return null; + } + }; + + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return VALUES; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + timesSame.length, + 0, + timesSame.length + )) + { + @Override + public long[] getLongVector() + { + return timesSame; + } + + @Nullable + @Override + public boolean[] 
getNullVector() + { + return null; + } + }; + selectorForPairs = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 2; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + + nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return NULLS; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonStringValueSelector; + } + return null; + } + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return 
ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities(); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } + return null; + } + }; + target = new StringFirstVectorAggregator(timeSelector, selector, 10); targetWithPairs = new StringFirstVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); clearBufferForPositions(0, 0); - Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME); - Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL); stringFirstAggregatorFactory = new StringFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10); - + stringFirstAggregatorFactory1 = new StringFirstAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10); } @Test @@ -129,6 +311,19 @@ public void aggregate() Assert.assertEquals(VALUES[0], result.rhs); } + @Test + public void testStringEarliestOnNonStringColumns() + { + Assert.assertTrue(stringFirstAggregatorFactory1.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = stringFirstAggregatorFactory1.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(StringFirstVectorAggregator.class, vectorAggregator.getClass()); + vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length); + Pair result = (Pair) vectorAggregator.get(buf, 0); + Assert.assertEquals(times[0], result.lhs.longValue()); + Assert.assertEquals(STRING_VALUES[0], result.rhs); + } + @Test public void aggregateBatchWithoutRows() { diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java index 6e5c0275107b..f144552d57e9 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java +++ 
b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java @@ -23,74 +23,245 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.vector.BaseLongVectorValueSelector; +import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; +import org.apache.druid.segment.vector.NoFilterVectorOffset; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; +import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Answers; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -@RunWith(MockitoJUnitRunner.class) public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest { private static final double EPSILON = 1e-5; private static final String[] VALUES = new String[]{"a", "b", null, "c"}; + private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; + private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"}; + private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; + private static final double[] DOUBLE_VALUES = 
new double[]{1.0, 2.0, 3.0, 4.0}; private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; + private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String TIME_COL = "__time"; - private long[] times = {2436, 6879, 7888, 8224}; - private long[] timesSame = {2436, 2436}; - private SerializablePairLongString[] pairs = { + private final long[] times = {2436, 6879, 7888, 8224}; + private final long[] timesSame = {2436, 2436}; + private final SerializablePairLongString[] pairs = { new SerializablePairLongString(2345100L, "last"), new SerializablePairLongString(2345001L, "notLast") }; - @Mock private VectorObjectSelector selector; - @Mock - private VectorObjectSelector selectorForPairs; - @Mock private BaseLongVectorValueSelector timeSelector; - @Mock - private BaseLongVectorValueSelector timeSelectorForPairs; + private VectorValueSelector nonStringValueSelector; private ByteBuffer buf; private StringLastVectorAggregator target; private StringLastVectorAggregator targetWithPairs; private StringLastAggregatorFactory stringLastAggregatorFactory; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private StringLastAggregatorFactory stringLastAggregatorFactory1; + private VectorColumnSelectorFactory selectorFactory; + @Before public void setup() { byte[] randomBytes = new byte[1024]; ThreadLocalRandom.current().nextBytes(randomBytes); buf = ByteBuffer.wrap(randomBytes); - Mockito.doReturn(VALUES).when(selector).getObjectVector(); - Mockito.doReturn(times).when(timeSelector).getLongVector(); - Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector(); - Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector(); + timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length)) + { + @Override + public long[] getLongVector() + { + return times; + } + + @Nullable + @Override + 
public boolean[] getNullVector() + { + return NULLS; + } + }; + nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + LONG_VALUES.length, + 0, + LONG_VALUES.length + )) + { + @Override + public long[] getLongVector() + { + return LONG_VALUES; + } + + @Override + public float[] getFloatVector() + { + return FLOAT_VALUES; + } + + @Override + public double[] getDoubleVector() + { + return DOUBLE_VALUES; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return NULLS; + } + + @Override + public int getMaxVectorSize() + { + return 4; + } + + @Override + public int getCurrentVectorSize() + { + return 4; + } + }; + selector = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return VALUES; + } + + @Override + public int getMaxVectorSize() + { + return 0; + } + + @Override + public int getCurrentVectorSize() + { + return 0; + } + }; + BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset( + times.length, + 0, + times.length + )) + { + @Override + public long[] getLongVector() + { + return timesSame; + } + + @Nullable + @Override + public boolean[] getNullVector() + { + return new boolean[0]; + } + }; + VectorObjectSelector selectorForPairs = new VectorObjectSelector() + { + @Override + public Object[] getObjectVector() + { + return pairs; + } + + @Override + public int getMaxVectorSize() + { + return 2; + } + + @Override + public int getCurrentVectorSize() + { + return 2; + } + }; + selectorFactory = new VectorColumnSelectorFactory() + { + @Override + public ReadableVectorInspector getReadableVectorInspector() + { + return new NoFilterVectorOffset(LONG_VALUES.length, 0, LONG_VALUES.length); + } + + @Override + public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec 
dimensionSpec) + { + return null; + } + + @Override + public VectorValueSelector makeValueSelector(String column) + { + if (TIME_COL.equals(column)) { + return timeSelector; + } else if (FIELD_NAME_LONG.equals(column)) { + return nonStringValueSelector; + } + return null; + } + + @Override + public VectorObjectSelector makeObjectSelector(String column) + { + if (FIELD_NAME.equals(column)) { + return selector; + } else { + return null; + } + } + + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (FIELD_NAME.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities(); + } else if (FIELD_NAME_LONG.equals(column)) { + return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG); + } + return null; + } + }; + target = new StringLastVectorAggregator(timeSelector, selector, 10); targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); clearBufferForPositions(0, 0); - Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME); - Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL); stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10); - + stringLastAggregatorFactory1 = new StringLastAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10); } @Test @@ -112,6 +283,19 @@ public void testFactory() Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass()); } + @Test + public void testStringLastOnNonStringColumns() + { + Assert.assertTrue(stringLastAggregatorFactory1.canVectorize(selectorFactory)); + VectorAggregator vectorAggregator = stringLastAggregatorFactory1.factorizeVector(selectorFactory); + Assert.assertNotNull(vectorAggregator); + Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass()); + vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length); + Pair result = (Pair) 
vectorAggregator.get(buf, 0); + Assert.assertEquals(times[3], result.lhs.longValue()); + Assert.assertEquals(STRING_VALUES[3], result.rhs); + } + @Test public void initValueShouldBeMinDate() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java index b339e5f4b7de..1eb6d58bcb87 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSimpleQueryTest.java @@ -632,7 +632,6 @@ public void testGroupByDimAndTimeAndDimOrderByDimAndTimeDim() @Test public void testEarliestByLatestByWithExpression() { - cannotVectorize(); testBuilder() .sql("SELECT\n" + " channel\n" From 5c42ac8c4dbbc8fa34ab3910145847cea2e49ec1 Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Wed, 13 Sep 2023 17:37:26 -0700 Subject: [PATCH 03/10] =?UTF-8?q?Fix=20for=20latest=20agg=20to=20handle=20?= =?UTF-8?q?nulls=20in=20time=20column.=20Also=20adding=20optimi=E2=80=A6?= =?UTF-8?q?=20(#14911)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix for latest agg to handle nulls in time column. 
Also adding optimization for dictionary encoded string columns * One minor fix * Adding more tests for the new class * Changing the init to a putInt --- ...eStringFirstDimensionVectorAggregator.java | 2 +- ...leStringLastDimensionVectorAggregator.java | 124 ++++++++++++++++++ .../last/StringLastAggregatorFactory.java | 22 +++- .../last/StringLastVectorAggregator.java | 20 ++- .../last/StringLastVectorAggregatorTest.java | 109 ++++++++++++++- 5 files changed, 261 insertions(+), 16 deletions(-) create mode 100644 processing/src/main/java/org/apache/druid/query/aggregation/last/SingleStringLastDimensionVectorAggregator.java diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/SingleStringFirstDimensionVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/SingleStringFirstDimensionVectorAggregator.java index 22fa50ea4623..119e13464a09 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/SingleStringFirstDimensionVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/SingleStringFirstDimensionVectorAggregator.java @@ -57,7 +57,7 @@ public void init(ByteBuffer buf, int position) position + NumericFirstVectorAggregator.NULL_OFFSET, useDefault ? 
NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE ); - buf.putLong(position + NumericFirstVectorAggregator.VALUE_OFFSET, 0); + buf.putInt(position + NumericFirstVectorAggregator.VALUE_OFFSET, 0); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/SingleStringLastDimensionVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/SingleStringLastDimensionVectorAggregator.java new file mode 100644 index 000000000000..6b39088faa2e --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/SingleStringLastDimensionVectorAggregator.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+public class SingleStringLastDimensionVectorAggregator implements VectorAggregator
+{
+  private final VectorValueSelector timeSelector;
+  private final SingleValueDimensionVectorSelector valueDimensionVectorSelector;
+  private long lastTime; // scratch copy of the timestamp read from the buffer position being updated
+  private final int maxStringBytes;
+  private final boolean useDefault = NullHandling.replaceWithDefault();
+
+  public SingleStringLastDimensionVectorAggregator(
+      VectorValueSelector timeSelector,
+      SingleValueDimensionVectorSelector valueDimensionVectorSelector,
+      int maxStringBytes
+  )
+  {
+    this.timeSelector = timeSelector;
+    this.valueDimensionVectorSelector = valueDimensionVectorSelector;
+    this.maxStringBytes = maxStringBytes;
+    this.lastTime = Long.MIN_VALUE;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    buf.putLong(position, Long.MIN_VALUE);
+    buf.put(
+        position + NumericLastVectorAggregator.NULL_OFFSET,
+        useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE
+    );
+    buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, 0);
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    final long[] timeVector = timeSelector.getLongVector();
+    final boolean[] nullTimeVector = timeSelector.getNullVector();
+    final int[] valueVector = valueDimensionVectorSelector.getRowVector();
+    lastTime = buf.getLong(position);
+    int index;
+
+    long latestTime;
+    for (index = endRow - 1; index >= startRow; index--) {
+      if (nullTimeVector != null && nullTimeVector[index]) {
+        continue;
+      }
+      latestTime = timeVector[index];
+      if (latestTime > lastTime) {
+        lastTime = latestTime;
+        buf.putLong(position, lastTime);
+        buf.put(position + NumericLastVectorAggregator.NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
+        buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, valueVector[index]);
+      }
+    }
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable int[] rows, int positionOffset)
+  {
+    final long[] timeVector = timeSelector.getLongVector();
+    final boolean[] nullTimeVector = timeSelector.getNullVector();
+    final int[] values = valueDimensionVectorSelector.getRowVector();
+    for (int i = numRows - 1; i >= 0; i--) {
+      int row = rows == null ? i : rows[i]; // selector vectors are indexed by row; positions[] by i
+      if (nullTimeVector != null && nullTimeVector[row]) {
+        continue;
+      }
+      int position = positions[i] + positionOffset;
+      lastTime = buf.getLong(position);
+      if (timeVector[row] > lastTime) {
+        lastTime = timeVector[row];
+        buf.putLong(position, lastTime);
+        buf.put(position + NumericLastVectorAggregator.NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
+        buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, values[row]);
+      }
+    }
+  }
+
+  @Nullable
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    int index = buf.getInt(position + NumericLastVectorAggregator.VALUE_OFFSET); // id stored by aggregate()
+    long latest = buf.getLong(position);
+    String strValue = valueDimensionVectorSelector.lookupName(index);
+    return new SerializablePairLongString(latest, StringUtils.chop(strValue, maxStringBytes));
+  }
+
+  @Override
+  public void close()
+  {
+    // nothing to close
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
index 234f03be3f9b..909b7d4971eb 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
@@ -35,6 +35,7 @@
 import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
 import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
 import org.apache.druid.query.cache.CacheKeyBuilder;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
 import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
@@ -43,6 +44,8 @@
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.Types;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
 import
org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorValueSelector; @@ -160,6 +163,7 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); + if (Types.isNumeric(capabilities)) { VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName); VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( @@ -171,6 +175,18 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFact ); return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes); } + + if (capabilities != null) { + if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) { + if (!capabilities.hasMultipleValues().isTrue()) { + SingleValueDimensionVectorSelector sSelector = selectorFactory.makeSingleValueDimensionSelector( + DefaultDimensionSpec.of( + fieldName)); + return new SingleStringLastDimensionVectorAggregator(timeSelector, sSelector, maxStringBytes); + } + } + } + VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName); if (capabilities != null) { return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes); @@ -296,9 +312,9 @@ public boolean equals(Object o) } StringLastAggregatorFactory that = (StringLastAggregatorFactory) o; return maxStringBytes == that.maxStringBytes && - Objects.equals(fieldName, that.fieldName) && - Objects.equals(timeColumn, that.timeColumn) && - Objects.equals(name, that.name); + Objects.equals(fieldName, that.fieldName) && + Objects.equals(timeColumn, that.timeColumn) && + Objects.equals(name, that.name); } @Override diff --git 
a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
index a18a1d4c9631..00e70c78098e 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
@@ -64,8 +64,9 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
     if (timeSelector == null) {
       return;
     }
-    long[] times = timeSelector.getLongVector();
-    Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+    final long[] times = timeSelector.getLongVector();
+    final boolean[] nullTimeVector = timeSelector.getNullVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
 
     lastTime = buf.getLong(position);
     int index;
@@ -76,6 +77,9 @@ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
       if (times[i] <= lastTime) {
         continue;
       }
+      if (nullTimeVector != null && nullTimeVector[i]) {
+        continue;
+      }
       index = i;
       final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]);
       if (foldNeeded) {
@@ -127,22 +131,24 @@ public void aggregate(
     if (timeSelector == null) {
       return;
     }
-    long[] timeVector = timeSelector.getLongVector();
-    Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+    final long[] timeVector = timeSelector.getLongVector();
+    final boolean[] nullTimeVector = timeSelector.getNullVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
 
     // iterate once over the object vector to find first non null element and
     // determine if the type is Pair or not
     boolean foldNeeded = false;
     for (Object obj : objectsWhichMightBeStrings) {
-      if (obj == null) {
-        continue;
-      } else {
+      if (obj != null) {
         foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj);
         break;
       }
     }
 
     for (int i = 0; i < numRows; i++) {
       int position = positions[i] + positionOffset;
       int row = rows == null ? i : rows[i];
+      if (nullTimeVector != null && nullTimeVector[row]) { // null flags are indexed by row, like timeVector
+        continue;
+      }
       long lastTime = buf.getLong(position);
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
index f144552d57e9..da79faae3c15 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
@@ -23,7 +23,9 @@
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.query.aggregation.SerializablePairLongString;
 import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.segment.IdLookup;
 import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
 import org.apache.druid.segment.column.ColumnType;
@@ -49,11 +51,13 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
 {
   private static final double EPSILON = 1e-5;
   private static final String[] VALUES = new String[]{"a", "b", null, "c"};
+  private static final int[] DICT_VALUES = new int[]{1, 2, 0, 3};
   private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
   private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
   private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
   private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
   private static final boolean[] NULLS = new boolean[]{false, false, true, false};
+  private static final boolean[]
NULLS1 = new boolean[]{false, false}; private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; private static final String FIELD_NAME_LONG = "LONG_NAME"; @@ -74,6 +78,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest private StringLastAggregatorFactory stringLastAggregatorFactory; private StringLastAggregatorFactory stringLastAggregatorFactory1; + private SingleStringLastDimensionVectorAggregator targetSingleDim; private VectorColumnSelectorFactory selectorFactory; @@ -96,7 +101,7 @@ public long[] getLongVector() @Override public boolean[] getNullVector() { - return NULLS; + return null; } }; nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( @@ -163,9 +168,9 @@ public int getCurrentVectorSize() } }; BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset( - times.length, + timesSame.length, 0, - times.length + timesSame.length )) { @Override @@ -178,7 +183,7 @@ public long[] getLongVector() @Override public boolean[] getNullVector() { - return new boolean[0]; + return NULLS1; } }; VectorObjectSelector selectorForPairs = new VectorObjectSelector() @@ -212,7 +217,61 @@ public ReadableVectorInspector getReadableVectorInspector() @Override public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) { - return null; + return new SingleValueDimensionVectorSelector() + { + @Override + public int[] getRowVector() + { + return DICT_VALUES; + } + + @Override + public int getValueCardinality() + { + return DICT_VALUES.length; + } + + @Nullable + @Override + public String lookupName(int id) + { + switch (id) { + case 1: + return "a"; + case 2: + return "b"; + case 3: + return "c"; + default: + return null; + } + } + + @Override + public boolean nameLookupPossibleInAdvance() + { + return false; + } + + @Nullable + @Override + public IdLookup idLookup() + { + return null; + } + 
+ @Override + public int getMaxVectorSize() + { + return DICT_VALUES.length; + } + + @Override + public int getCurrentVectorSize() + { + return DICT_VALUES.length; + } + }; } @Override @@ -257,6 +316,8 @@ public ColumnCapabilities getColumnCapabilities(String column) target = new StringLastVectorAggregator(timeSelector, selector, 10); targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); + targetSingleDim = new SingleStringLastDimensionVectorAggregator(timeSelector, selectorFactory.makeSingleValueDimensionSelector( + DefaultDimensionSpec.of(FIELD_NAME)), 10); clearBufferForPositions(0, 0); @@ -361,6 +422,44 @@ public void aggregateBatchWithRows() } } + @Test + public void aggregateSingleDim() + { + targetSingleDim.aggregate(buf, 0, 0, VALUES.length); + Pair result = (Pair) targetSingleDim.get(buf, 0); + Assert.assertEquals(times[3], result.lhs.longValue()); + Assert.assertEquals(VALUES[3], result.rhs); + } + + @Test + public void aggregateBatchWithoutRowsSingleDim() + { + int[] positions = new int[]{0, 43, 70}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + targetSingleDim.aggregate(buf, 3, positions, null, positionOffset); + for (int i = 0; i < positions.length; i++) { + Pair result = (Pair) targetSingleDim.get(buf, positions[i] + positionOffset); + Assert.assertEquals(times[i], result.lhs.longValue()); + Assert.assertEquals(VALUES[i], result.rhs); + } + } + + @Test + public void aggregateBatchWithRowsSingleDim() + { + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + targetSingleDim.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + Pair result = (Pair) targetSingleDim.get(buf, positions[i] + positionOffset); + Assert.assertEquals(times[rows[i]], result.lhs.longValue()); + Assert.assertEquals(VALUES[rows[i]], result.rhs); + } + } + 
private void clearBufferForPositions(int offset, int... positions) { for (int position : positions) { From 7bbefd57413ce4fcaa291af1af75e49079db3980 Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Wed, 13 Sep 2023 22:11:36 -0700 Subject: [PATCH 04/10] Updating version in from.ftl (#14982) --- sql/src/main/codegen/includes/from.ftl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/src/main/codegen/includes/from.ftl b/sql/src/main/codegen/includes/from.ftl index 9e9d3fcbe032..ae6d03b841cb 100644 --- a/sql/src/main/codegen/includes/from.ftl +++ b/sql/src/main/codegen/includes/from.ftl @@ -18,7 +18,7 @@ */ /* - * Druid note: this file is copied from core/src/main/codegen/templates/Parser.jj in Calcite 1.34.0, with changes to + * Druid note: this file is copied from core/src/main/codegen/templates/Parser.jj in Calcite 1.35.0, with changes to * to add two elements of Druid syntax to the FROM clause: * * id [ () ] From 0e3df2d2e9eb3388d43ac722a76de53a6a263ee8 Mon Sep 17 00:00:00 2001 From: AmatyaAvadhanula Date: Thu, 14 Sep 2023 14:58:02 +0530 Subject: [PATCH 05/10] Clean up stale locks if segment allocation fails (#14966) * Clean up stale locks if segment allocation fails due to an exception --- .../druid/indexing/overlord/TaskLockbox.java | 33 ++++- .../indexing/overlord/TaskLockboxTest.java | 137 ++++++++++++++++++ 2 files changed, 167 insertions(+), 3 deletions(-) diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/overlord/TaskLockbox.java b/indexing-service/src/main/java/org/apache/druid/indexing/overlord/TaskLockbox.java index ae8c3313ec70..a89816888887 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/overlord/TaskLockbox.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/overlord/TaskLockbox.java @@ -494,9 +494,12 @@ public List allocateSegments( allocateSegmentIds(dataSource, interval, skipSegmentLineageCheck, holderList.getPending()); 
holderList.getPending().forEach(holder -> acquireTaskLock(holder, false)); } - holderList.getPending().forEach(holder -> addTaskAndPersistLocks(holder, isTimeChunkLock)); } + catch (Exception e) { + holderList.clearStaleLocks(this); + throw e; + } finally { giant.unlock(); } @@ -711,7 +714,8 @@ private TaskLockPosse createNewTaskLockPosse(LockRequest request) * for the given requests. Updates the holder with the allocated segment if * the allocation succeeds, otherwise marks it as failed. */ - private void allocateSegmentIds( + @VisibleForTesting + void allocateSegmentIds( String dataSource, Interval interval, boolean skipSegmentLineageCheck, @@ -1598,6 +1602,28 @@ Set getPending() return pending; } + /** + * When task locks are acquired in an attempt to allocate segments, a new lock posse might be created. + * However, the posse is associated with the task only after all the segment allocations have succeeded. + * If there is an exception, unlock all such unassociated locks. + */ + void clearStaleLocks(TaskLockbox taskLockbox) + { + all + .stream() + .filter(holder -> holder.acquiredLock != null + && holder.taskLockPosse != null + && !holder.taskLockPosse.containsTask(holder.task)) + .forEach(holder -> { + holder.taskLockPosse.addTask(holder.task); + taskLockbox.unlock( + holder.task, + holder.acquiredLock.getInterval(), + holder.acquiredLock instanceof SegmentLock ? ((SegmentLock) holder.acquiredLock).getPartitionId() : null + ); + log.info("Cleared stale lock[%s] for task[%s]", holder.acquiredLock, holder.task.getId()); + }); + } List getResults() { @@ -1608,7 +1634,8 @@ List getResults() /** * Contains the task, request, lock and final result for a segment allocation.
*/ - private static class SegmentAllocationHolder + @VisibleForTesting + static class SegmentAllocationHolder { final AllocationHolderList list; diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLockboxTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLockboxTest.java index 6b9b4cd213d5..ceb1657f68eb 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLockboxTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLockboxTest.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.jsontype.NamedType; import com.fasterxml.jackson.databind.module.SimpleModule; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import org.apache.druid.indexer.TaskStatus; @@ -34,6 +35,8 @@ import org.apache.druid.indexing.common.TaskLockType; import org.apache.druid.indexing.common.TaskToolbox; import org.apache.druid.indexing.common.TimeChunkLock; +import org.apache.druid.indexing.common.actions.SegmentAllocateAction; +import org.apache.druid.indexing.common.actions.SegmentAllocateRequest; import org.apache.druid.indexing.common.actions.TaskActionClient; import org.apache.druid.indexing.common.config.TaskConfig; import org.apache.druid.indexing.common.config.TaskStorageConfig; @@ -46,6 +49,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.metadata.DerbyMetadataStorageActionHandlerFactory; @@ -1727,6 +1731,117 @@ public void testConflictsWithOverlappingSharedLocks() 
validator.expectActiveLocks(conflictingLock, floorLock); } + @Test + public void testDoNotCleanUsedLockAfterSegmentAllocationFailure() + { + final Task task = NoopTask.create(); + final Interval theInterval = Intervals.of("2023/2024"); + taskStorage.insert(task, TaskStatus.running(task.getId())); + + final TaskLockbox testLockbox = new SegmentAllocationFailingTaskLockbox(taskStorage, metadataStorageCoordinator); + testLockbox.add(task); + final LockResult lockResult = testLockbox.tryLock(task, new TimeChunkLockRequest( + TaskLockType.SHARED, + task, + theInterval, + null + )); + Assert.assertTrue(lockResult.isOk()); + + SegmentAllocateRequest request = new SegmentAllocateRequest( + task, + new SegmentAllocateAction( + task.getDataSource(), + DateTimes.of("2023-01-01"), + Granularities.NONE, + Granularities.YEAR, + task.getId(), + null, + false, + null, + null, + TaskLockType.SHARED + ), + 90 + ); + + try { + testLockbox.allocateSegments( + ImmutableList.of(request), + "DS", + theInterval, + false, + LockGranularity.TIME_CHUNK + ); + } + catch (Exception e) { + // do nothing + } + Assert.assertFalse(testLockbox.getAllLocks().isEmpty()); + Assert.assertEquals( + lockResult.getTaskLock(), + testLockbox.getOnlyTaskLockPosseContainingInterval(task, theInterval).get(0).getTaskLock() + ); + } + + @Test + public void testCleanUpLocksAfterSegmentAllocationFailure() + { + final Task task = NoopTask.create(); + taskStorage.insert(task, TaskStatus.running(task.getId())); + + final TaskLockbox testLockbox = new SegmentAllocationFailingTaskLockbox(taskStorage, metadataStorageCoordinator); + testLockbox.add(task); + + SegmentAllocateRequest request0 = new SegmentAllocateRequest( + task, + new SegmentAllocateAction( + task.getDataSource(), + DateTimes.of("2023-01-01"), + Granularities.NONE, + Granularities.YEAR, + task.getId(), + null, + false, + null, + null, + TaskLockType.SHARED + ), + 90 + ); + + SegmentAllocateRequest request1 = new SegmentAllocateRequest( + task, + new 
SegmentAllocateAction( + task.getDataSource(), + DateTimes.of("2023-01-01"), + Granularities.NONE, + Granularities.MONTH, + task.getId(), + null, + false, + null, + null, + TaskLockType.SHARED + ), + 90 + ); + + try { + testLockbox.allocateSegments( + ImmutableList.of(request0, request1), + "DS", + Intervals.of("2023/2024"), + false, + LockGranularity.TIME_CHUNK + ); + } + catch (Exception e) { + // do nothing + } + Assert.assertTrue(testLockbox.getAllLocks().isEmpty()); + } + private class TaskLockboxValidator { @@ -1953,4 +2068,26 @@ protected TaskLockPosse verifyAndCreateOrFindLockPosse(Task task, TaskLock taskL .contains("FailingLockAcquisition") ? null : super.verifyAndCreateOrFindLockPosse(task, taskLock); } } + + private static class SegmentAllocationFailingTaskLockbox extends TaskLockbox + { + public SegmentAllocationFailingTaskLockbox( + TaskStorage taskStorage, + IndexerMetadataStorageCoordinator metadataStorageCoordinator + ) + { + super(taskStorage, metadataStorageCoordinator); + } + + @Override + void allocateSegmentIds( + String dataSource, + Interval interval, + boolean skipSegmentLineageCheck, + Collection holders + ) + { + throw new RuntimeException("This lockbox cannot allocate segemnts."); + } + } } From 3ae5e978012e7a553e6ad3007eb210ab8da5b6fc Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 14 Sep 2023 09:19:09 -0700 Subject: [PATCH 06/10] Add IS [NOT] TRUE, IS [NOT] FALSE native functions. (#14977) They are not quite the same as "x == true", "x != true", etc. These functions never return null, even when "x" itself is null. 
--- .../org/apache/druid/math/expr/Function.java | 155 ++++++++++++++++++ .../org/apache/druid/math/expr/EvalTest.java | 120 ++++++++++++++ .../UnarySuffixOperatorConversion.java | 64 -------- .../calcite/planner/DruidOperatorTable.java | 21 +-- .../druid/sql/calcite/CalciteQueryTest.java | 2 +- 5 files changed, 280 insertions(+), 82 deletions(-) delete mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/expression/UnarySuffixOperatorConversion.java diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 632425d1c751..406ffac1ea7d 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -27,6 +27,7 @@ import org.apache.druid.java.util.common.UOE; import org.apache.druid.math.expr.vector.CastToTypeVectorProcessor; import org.apache.druid.math.expr.vector.ExprVectorProcessor; +import org.apache.druid.math.expr.vector.VectorComparisonProcessors; import org.apache.druid.math.expr.vector.VectorMathProcessors; import org.apache.druid.math.expr.vector.VectorProcessors; import org.apache.druid.math.expr.vector.VectorStringProcessors; @@ -2224,6 +2225,160 @@ public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspe } } + /** + * SQL function "IS NOT FALSE". Different from "IS TRUE" in that it returns true for NULL as well. 
+ */ + class IsNotFalseFunc implements Function + { + @Override + public String name() + { + return "notfalse"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arg = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(arg.value() == null || arg.asBoolean()); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 1); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + } + + /** + * SQL function "IS NOT TRUE". Different from "IS FALSE" in that it returns true for NULL as well. + */ + class IsNotTrueFunc implements Function + { + @Override + public String name() + { + return "nottrue"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arg = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(arg.value() == null || !arg.asBoolean()); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 1); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + } + + /** + * SQL function "IS FALSE". 
+ */ + class IsFalseFunc implements Function + { + @Override + public String name() + { + return "isfalse"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arg = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(arg.value() != null && !arg.asBoolean()); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 1); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + + @Override + public boolean canVectorize(Expr.InputBindingInspector inspector, List args) + { + final Expr expr = args.get(0); + return inspector.areNumeric(expr) && expr.canVectorize(inspector); + } + + @Override + public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspector inspector, List args) + { + return VectorComparisonProcessors.lessThanOrEqual(inspector, args.get(0), ExprEval.of(0L).toExpr()); + } + } + + /** + * SQL function "IS TRUE". 
+ */ + class IsTrueFunc implements Function + { + @Override + public String name() + { + return "istrue"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arg = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(arg.asBoolean()); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 1); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + + @Override + public boolean canVectorize(Expr.InputBindingInspector inspector, List args) + { + final Expr expr = args.get(0); + return inspector.areNumeric(expr) && expr.canVectorize(inspector); + } + + @Override + public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspector inspector, List args) + { + return VectorComparisonProcessors.greaterThan(inspector, args.get(0), ExprEval.of(0L).toExpr()); + } + } + class IsNullFunc implements Function { @Override diff --git a/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java b/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java index 97a91cb39698..c49959aa1ae3 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -99,6 +99,16 @@ public void testDoubleEval() Assert.assertTrue(evalDouble("2.0 == 2.0", bindings) > 0.0); Assert.assertTrue(evalDouble("2.0 != 1.0", bindings) > 0.0); + Assert.assertEquals(0L, evalLong("istrue(0.0)", bindings)); + Assert.assertEquals(1L, evalLong("isfalse(0.0)", bindings)); + Assert.assertEquals(1L, evalLong("nottrue(0.0)", bindings)); + Assert.assertEquals(0L, evalLong("notfalse(0.0)", bindings)); + + Assert.assertEquals(1L, evalLong("istrue(1.0)", bindings)); + Assert.assertEquals(0L, evalLong("isfalse(1.0)", bindings)); + Assert.assertEquals(0L, evalLong("nottrue(1.0)", bindings)); + 
Assert.assertEquals(1L, evalLong("notfalse(1.0)", bindings)); + Assert.assertTrue(evalDouble("!-1.0", bindings) > 0.0); Assert.assertTrue(evalDouble("!0.0", bindings) > 0.0); Assert.assertFalse(evalDouble("!2.0", bindings) > 0.0); @@ -121,6 +131,16 @@ public void testDoubleEval() Assert.assertEquals(1L, evalLong("2.0 == 2.0", bindings)); Assert.assertEquals(1L, evalLong("2.0 != 1.0", bindings)); + Assert.assertEquals(0L, evalLong("istrue(0.0)", bindings)); + Assert.assertEquals(1L, evalLong("isfalse(0.0)", bindings)); + Assert.assertEquals(1L, evalLong("nottrue(0.0)", bindings)); + Assert.assertEquals(0L, evalLong("notfalse(0.0)", bindings)); + + Assert.assertEquals(1L, evalLong("istrue(1.0)", bindings)); + Assert.assertEquals(0L, evalLong("isfalse(1.0)", bindings)); + Assert.assertEquals(0L, evalLong("nottrue(1.0)", bindings)); + Assert.assertEquals(1L, evalLong("notfalse(1.0)", bindings)); + Assert.assertEquals(1L, evalLong("!-1.0", bindings)); Assert.assertEquals(1L, evalLong("!0.0", bindings)); Assert.assertEquals(0L, evalLong("!2.0", bindings)); @@ -201,6 +221,106 @@ public void testLongEval() assertEquals("x", eval("nvl(if(x == 9223372036854775806, '', 'x'), 'NULL')", bindings).asString()); } + @Test + public void testIsFalse() + { + assertEquals( + 0L, + new Function.IsFalseFunc() + .apply(ImmutableList.of(new NullLongExpr()), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 1L, + new Function.IsFalseFunc() + .apply(ImmutableList.of(new LongExpr(0L)), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 0L, + new Function.IsFalseFunc() + .apply(ImmutableList.of(new LongExpr(1L)), InputBindings.nilBindings()) + .value() + ); + } + + @Test + public void testIsTrue() + { + assertEquals( + 0L, + new Function.IsTrueFunc() + .apply(ImmutableList.of(new NullLongExpr()), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 0L, + new Function.IsTrueFunc() + .apply(ImmutableList.of(new LongExpr(0L)), InputBindings.nilBindings()) 
+ .value() + ); + + assertEquals( + 1L, + new Function.IsTrueFunc() + .apply(ImmutableList.of(new LongExpr(1L)), InputBindings.nilBindings()) + .value() + ); + } + + @Test + public void testIsNotFalse() + { + assertEquals( + 1L, + new Function.IsNotFalseFunc() + .apply(ImmutableList.of(new NullLongExpr()), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 0L, + new Function.IsNotFalseFunc() + .apply(ImmutableList.of(new LongExpr(0L)), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 1L, + new Function.IsNotFalseFunc() + .apply(ImmutableList.of(new LongExpr(1L)), InputBindings.nilBindings()) + .value() + ); + } + + @Test + public void testIsNotTrue() + { + assertEquals( + 1L, + new Function.IsNotTrueFunc() + .apply(ImmutableList.of(new NullLongExpr()), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 1L, + new Function.IsNotTrueFunc() + .apply(ImmutableList.of(new LongExpr(0L)), InputBindings.nilBindings()) + .value() + ); + + assertEquals( + 0L, + new Function.IsNotTrueFunc() + .apply(ImmutableList.of(new LongExpr(1L)), InputBindings.nilBindings()) + .value() + ); + } + @Test public void testArrayToScalar() { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/UnarySuffixOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/UnarySuffixOperatorConversion.java deleted file mode 100644 index 8a38fbe9c936..000000000000 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/UnarySuffixOperatorConversion.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.sql.calcite.expression; - -import com.google.common.collect.Iterables; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlOperator; -import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.planner.PlannerContext; - -public class UnarySuffixOperatorConversion implements SqlOperatorConversion -{ - private final SqlOperator operator; - private final String druidOperator; - - public UnarySuffixOperatorConversion(final SqlOperator operator, final String druidOperator) - { - this.operator = operator; - this.druidOperator = druidOperator; - } - - @Override - public SqlOperator calciteOperator() - { - return operator; - } - - @Override - public DruidExpression toDruidExpression( - final PlannerContext plannerContext, - final RowSignature rowSignature, - final RexNode rexNode - ) - { - return OperatorConversions.convertCallBuilder( - plannerContext, - rowSignature, - rexNode, - operands -> StringUtils.format( - "(%s %s)", - Iterables.getOnlyElement(operands).getExpression(), - druidOperator - ) - ); - } -} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index e8ab0cc71d84..16748b0b6ab5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -54,7 +54,6 @@ import 
org.apache.druid.sql.calcite.expression.SqlOperatorConversion; import org.apache.druid.sql.calcite.expression.UnaryFunctionOperatorConversion; import org.apache.druid.sql.calcite.expression.UnaryPrefixOperatorConversion; -import org.apache.druid.sql.calcite.expression.UnarySuffixOperatorConversion; import org.apache.druid.sql.calcite.expression.WindowSqlAggregate; import org.apache.druid.sql.calcite.expression.builtin.ArrayAppendOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.ArrayConcatOperatorConversion; @@ -372,22 +371,10 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new UnaryPrefixOperatorConversion(SqlStdOperatorTable.UNARY_MINUS, "-")) .add(new UnaryFunctionOperatorConversion(SqlStdOperatorTable.IS_NULL, "isnull")) .add(new UnaryFunctionOperatorConversion(SqlStdOperatorTable.IS_NOT_NULL, "notnull")) - .add(new UnarySuffixOperatorConversion( - SqlStdOperatorTable.IS_FALSE, - "<= 0" - )) // Matches Evals.asBoolean - .add(new UnarySuffixOperatorConversion( - SqlStdOperatorTable.IS_NOT_TRUE, - "<= 0" - )) // Matches Evals.asBoolean - .add(new UnarySuffixOperatorConversion( - SqlStdOperatorTable.IS_TRUE, - "> 0" - )) // Matches Evals.asBoolean - .add(new UnarySuffixOperatorConversion( - SqlStdOperatorTable.IS_NOT_FALSE, - "> 0" - )) // Matches Evals.asBoolean + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_FALSE, "isfalse")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_TRUE, "istrue")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_NOT_FALSE, "notfalse")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_NOT_TRUE, "nottrue")) .add(new BinaryOperatorConversion(SqlStdOperatorTable.MULTIPLY, "*")) .add(new BinaryOperatorConversion(SqlStdOperatorTable.MOD, "%")) .add(new BinaryOperatorConversion(SqlStdOperatorTable.DIVIDE, "/")) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 582408003719..499dd8faeb8e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -2461,7 +2461,7 @@ public void testExactCountDistinctWithFilter() "v0", NullHandling.replaceWithDefault() ? "(\"cnt\" == 1)" - : "((\"cnt\" == 1) > 0)", + : "istrue((\"cnt\" == 1))", ColumnType.LONG )) .setDimensions(dimensions( From 279b3818f0e02f6a2ff389532f1e6118598a7c24 Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Thu, 14 Sep 2023 21:24:14 -0700 Subject: [PATCH 07/10] Make Unnest work with nullif operator (#14993) This is due to the recursive filter creation in the unnest storage adapter not performing correctly when the list of child filters is empty. This PR addresses the issue. --- .../druid/segment/UnnestStorageAdapter.java | 7 ++- .../sql/calcite/CalciteArraysQueryTest.java | 48 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java index 83694d9618d5..02f8c0064aa2 100644 --- a/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java @@ -478,13 +478,16 @@ private List recursiveRewriteOnUnnestFilters( for (Filter filter : queryFilter.getFilters()) { if (filter.getRequiredColumns().contains(outputColumnName)) { if (filter instanceof AndFilter) { - preFilterList.add(new AndFilter(recursiveRewriteOnUnnestFilters( + List andChildFilters = recursiveRewriteOnUnnestFilters( (BooleanFilter) filter, inputColumn, inputColumnCapabilites, filterSplitter, isTopLevelAndFilter - ))); + ); + if (!andChildFilters.isEmpty()) { + preFilterList.add(new AndFilter(andChildFilters)); + } } else if (filter
instanceof OrFilter) { // in case of Or Fiters, we set isTopLevelAndFilter to false that prevents pushing down any child filters to base List orChildFilters = recursiveRewriteOnUnnestFilters( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index 4842044892dc..df4e9b62cc9e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -4807,4 +4807,52 @@ public void testUnnestWithGroupByHavingWithWhereOnUnnestCol() ) ); } + + @Test + public void testUnnestVirtualWithColumnsAndNullIf() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "select c,m2 from druid.foo, unnest(ARRAY[\"m1\", \"m2\"]) as u(c) where NULLIF(c,m2) IS NULL", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "array(\"m1\",\"m2\")", ColumnType.FLOAT_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .filters( + useDefault ? 
expressionFilter("(\"j0.unnest\" == \"m2\")") : + or( + expressionFilter("(\"j0.unnest\" == \"m2\")"), + and( + isNull("j0.unnest"), + not(expressionFilter("(\"j0.unnest\" == \"m2\")")) + ) + )) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest", "m2")) + .build() + ), + ImmutableList.of( + new Object[]{1.0f, 1.0D}, + new Object[]{1.0f, 1.0D}, + new Object[]{2.0f, 2.0D}, + new Object[]{2.0f, 2.0D}, + new Object[]{3.0f, 3.0D}, + new Object[]{3.0f, 3.0D}, + new Object[]{4.0f, 4.0D}, + new Object[]{4.0f, 4.0D}, + new Object[]{5.0f, 5.0D}, + new Object[]{5.0f, 5.0D}, + new Object[]{6.0f, 6.0D}, + new Object[]{6.0f, 6.0D} + ) + ); + } } From 0fc5d5405a750d346d4c7a5513d6987fb6666fff Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Fri, 15 Sep 2023 05:44:21 +0000 Subject: [PATCH 08/10] Tweak GHA runner label for MSQ (#14992) --- .github/labeler.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index 101fd25aa5a4..a9bfc45a86ec 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -46,9 +46,12 @@ - 'processing/src/main/java/org/apache/druid/java/util/emitter/**' - 'extensions-contrib/*-emitter/**' +'Area - MSQ': + - 'extensions-core/multi-stage-query/**' + 'Area - Querying': - 'sql/**' - - 'extensions-core/multi-stage-query/**' + - 'extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/**' 'Area - Segment Format and Ser/De': - 'processing/src/main/java/org/apache/druid/segment/**' @@ -62,6 +65,3 @@ 'Kubernetes': - 'extensions-contrib/kubernetes-overlord-extensions/**' - -'MSQ': - - 'extensions-core/multi-stage-query/**' From 39d95955f5508143b4fd2352baa3fe60a4920357 Mon Sep 17 00:00:00 2001 From: Rohan Garg <7731512+rohangarg@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:14:20 +0530 Subject: [PATCH 09/10] Do not eagerly close inner iterators in CloseableIterator#flatMap (#14986) --- .../data/input/s3/S3InputSourceTest.java | 17 
++++---- .../common/parsers/CloseableIterator.java | 42 ++++++++----------- .../impl/InputEntityIteratingReaderTest.java | 8 ++-- .../common/parsers/CloseableIteratorTest.java | 41 ++++++++++++++++++ 4 files changed, 73 insertions(+), 35 deletions(-) diff --git a/extensions-core/s3-extensions/src/test/java/org/apache/druid/data/input/s3/S3InputSourceTest.java b/extensions-core/s3-extensions/src/test/java/org/apache/druid/data/input/s3/S3InputSourceTest.java index fc538b682fc6..6b0bb537c7e1 100644 --- a/extensions-core/s3-extensions/src/test/java/org/apache/druid/data/input/s3/S3InputSourceTest.java +++ b/extensions-core/s3-extensions/src/test/java/org/apache/druid/data/input/s3/S3InputSourceTest.java @@ -1033,14 +1033,15 @@ public void testReaderRetriesOnSdkClientExceptionButNeverSucceedsThenThrows() th new CsvInputFormat(ImmutableList.of("time", "dim1", "dim2"), "|", false, null, 0), temporaryFolder.newFolder() ); - - final IllegalStateException e = Assert.assertThrows(IllegalStateException.class, reader::read); - MatcherAssert.assertThat(e.getCause(), CoreMatchers.instanceOf(IOException.class)); - MatcherAssert.assertThat(e.getCause().getCause(), CoreMatchers.instanceOf(SdkClientException.class)); - MatcherAssert.assertThat( - e.getCause().getCause().getMessage(), - CoreMatchers.startsWith("Data read has a different length than the expected") - ); + try (CloseableIterator readerIterator = reader.read()) { + final IllegalStateException e = Assert.assertThrows(IllegalStateException.class, readerIterator::hasNext); + MatcherAssert.assertThat(e.getCause(), CoreMatchers.instanceOf(IOException.class)); + MatcherAssert.assertThat(e.getCause().getCause(), CoreMatchers.instanceOf(SdkClientException.class)); + MatcherAssert.assertThat( + e.getCause().getCause().getMessage(), + CoreMatchers.startsWith("Data read has a different length than the expected") + ); + } EasyMock.verify(S3_CLIENT); } diff --git 
a/processing/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java b/processing/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java index af1baafe4198..7b81934367d1 100644 --- a/processing/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java +++ b/processing/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java @@ -19,7 +19,6 @@ package org.apache.druid.java.util.common.parsers; -import javax.annotation.Nullable; import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; @@ -62,37 +61,37 @@ public void close() throws IOException default CloseableIterator flatMap(Function> function) { - final CloseableIterator delegate = this; + final CloseableIterator outerIterator = this; return new CloseableIterator() { - CloseableIterator iterator = findNextIteratorIfNecessary(); + CloseableIterator currInnerIterator = null; - @Nullable - private CloseableIterator findNextIteratorIfNecessary() + private void findNextIteratorIfNecessary() { - while ((iterator == null || !iterator.hasNext()) && delegate.hasNext()) { - if (iterator != null) { + while ((currInnerIterator == null || !currInnerIterator.hasNext()) && outerIterator.hasNext()) { + if (currInnerIterator != null) { try { - iterator.close(); - iterator = null; + currInnerIterator.close(); + currInnerIterator = null; } catch (IOException e) { throw new UncheckedIOException(e); } } - iterator = function.apply(delegate.next()); - if (iterator.hasNext()) { - return iterator; + currInnerIterator = function.apply(outerIterator.next()); + if (currInnerIterator.hasNext()) { + return; } } - return null; } @Override public boolean hasNext() { - return iterator != null && iterator.hasNext(); + // closes the current iterator if it is finished, and opens a new non-empty iterator if possible + findNextIteratorIfNecessary(); + return currInnerIterator != null && currInnerIterator.hasNext(); } @Override @@ 
-101,21 +100,16 @@ public R next() if (!hasNext()) { throw new NoSuchElementException(); } - try { - return iterator.next(); - } - finally { - findNextIteratorIfNecessary(); - } + return currInnerIterator.next(); } @Override public void close() throws IOException { - delegate.close(); - if (iterator != null) { - iterator.close(); - iterator = null; + outerIterator.close(); + if (currInnerIterator != null) { + currInnerIterator.close(); + currInnerIterator = null; } } }; diff --git a/processing/src/test/java/org/apache/druid/data/input/impl/InputEntityIteratingReaderTest.java b/processing/src/test/java/org/apache/druid/data/input/impl/InputEntityIteratingReaderTest.java index 9e175b2a3dfe..a33899b25354 100644 --- a/processing/src/test/java/org/apache/druid/data/input/impl/InputEntityIteratingReaderTest.java +++ b/processing/src/test/java/org/apache/druid/data/input/impl/InputEntityIteratingReaderTest.java @@ -137,9 +137,11 @@ protected int getMaxRetries() ).iterator(), temporaryFolder.newFolder() ); - String expectedMessage = "Error occurred while trying to read uri: testscheme://some/path"; - Exception exception = Assert.assertThrows(RuntimeException.class, firehose::read); - Assert.assertTrue(exception.getMessage().contains(expectedMessage)); + try (CloseableIterator readIterator = firehose.read()) { + String expectedMessage = "Error occurred while trying to read uri: testscheme://some/path"; + Exception exception = Assert.assertThrows(RuntimeException.class, readIterator::hasNext); + Assert.assertTrue(exception.getMessage().contains(expectedMessage)); + } } } diff --git a/processing/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java b/processing/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java index be2d1d58bd5c..3f701ee92f6f 100644 --- a/processing/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java +++ 
b/processing/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java @@ -120,6 +120,47 @@ public void testFlatMapClosedEarly() throws IOException } } + @Test + public void testFlatMapInnerClose() throws IOException + { + List> innerIterators = new ArrayList<>(); + // the nested iterators is : [ [], [0], [0, 1] ] + try (final CloseTrackingCloseableIterator actual = new CloseTrackingCloseableIterator<>( + generateTestIterator(3) + .flatMap(list -> { + CloseTrackingCloseableIterator inner = + new CloseTrackingCloseableIterator<>(CloseableIterators.withEmptyBaggage(list.iterator())); + innerIterators.add(inner); + return inner; + }) + )) { + final Iterator expected = IntStream + .range(0, 3) + .flatMap(i -> IntStream.range(0, i)) + .iterator(); + + int iterCount = 0, innerIteratorIdx = 0; + while (actual.hasNext()) { + iterCount++; + if (iterCount == 1) { + Assert.assertEquals(2, innerIterators.size()); //empty iterator and single element iterator + innerIteratorIdx++; + } else if (iterCount == 2) { + Assert.assertEquals(3, innerIterators.size()); //empty iterator + single element iterator + double element iterator + innerIteratorIdx++; + } + Assert.assertEquals(expected.next(), actual.next()); // assert expected value to the iterator's value + for (int i = 0; i < innerIteratorIdx; i++) { + Assert.assertEquals(1, innerIterators.get(i).closeCount); // expect all previous iterators to be closed + } + // never expect the current iterator to be closed, even after doing the last next call on it + Assert.assertEquals(0, innerIterators.get(innerIteratorIdx).closeCount); + } + } + // check the last inner iterator is closed + Assert.assertEquals(1, innerIterators.get(2).closeCount); + } + private static CloseableIterator> generateTestIterator(int numIterates) { return new CloseableIterator>() From 973fbaf962fbf5d442c511d39ee077b4d2ff5053 Mon Sep 17 00:00:00 2001 From: Karan Kumar Date: Mon, 18 Sep 2023 01:41:58 +0530 Subject: [PATCH 10/10] Adding 
additional logging for taskIdReady in MSQ for debugging lock races. (#14998) --- .../java/org/apache/druid/msq/indexing/MSQControllerTask.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java index 5a7c0abbfda8..43967e7d748a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java @@ -43,6 +43,7 @@ import org.apache.druid.indexing.common.task.Tasks; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerContext; import org.apache.druid.msq.exec.ControllerImpl; @@ -69,6 +70,7 @@ public class MSQControllerTask extends AbstractTask implements ClientTaskQuery { public static final String TYPE = "query_controller"; public static final String DUMMY_DATASOURCE_FOR_SELECT = "__query_select"; + private static final Logger log = new Logger(MSQControllerTask.class); private final MSQSpec querySpec; @@ -204,7 +206,7 @@ public boolean isReady(TaskActionClient taskActionClient) throws Exception if (isIngestion(querySpec) && ((DataSourceMSQDestination) querySpec.getDestination()).isReplaceTimeChunks()) { final List intervals = ((DataSourceMSQDestination) querySpec.getDestination()).getReplaceTimeChunks(); - + log.debug("Task[%s] trying to acquire[%s] locks for intervals[%s] to become ready", getId(), TaskLockType.EXCLUSIVE, intervals); for (final Interval interval : intervals) { final TaskLock taskLock = taskActionClient.submit(new TimeChunkLockTryAcquireAction(TaskLockType.EXCLUSIVE, interval));