Skip to content

Commit

Permalink
add native 'array contains element' filter (apache#15366)
Browse files Browse the repository at this point in the history
* add native arrayContainsElement filter to use array column element indexes
  • Loading branch information
clintropolis authored Nov 29, 2023
1 parent 0a56c87 commit 64fcb32
Show file tree
Hide file tree
Showing 24 changed files with 2,338 additions and 950 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,24 @@ public class SqlBenchmark
"SELECT APPROX_COUNT_DISTINCT_BUILTIN(dimZipf) FROM foo",
"SELECT APPROX_COUNT_DISTINCT_DS_HLL(dimZipf) FROM foo",
"SELECT APPROX_COUNT_DISTINCT_DS_HLL_UTF8(dimZipf) FROM foo",
"SELECT APPROX_COUNT_DISTINCT_DS_THETA(dimZipf) FROM foo"
"SELECT APPROX_COUNT_DISTINCT_DS_THETA(dimZipf) FROM foo",
// 32: LATEST aggregator long
"SELECT LATEST(long1) FROM foo",
// 33: LATEST aggregator double
"SELECT LATEST(double4) FROM foo",
// 34: LATEST aggregator double
"SELECT LATEST(float3) FROM foo",
// 35: LATEST aggregator double
"SELECT LATEST(float3), LATEST(long1), LATEST(double4) FROM foo",
// 36,37: filter numeric nulls
"SELECT SUM(long5) FROM foo WHERE long5 IS NOT NULL",
"SELECT string2, SUM(long5) FROM foo WHERE long5 IS NOT NULL GROUP BY 1",
// 38: EARLIEST aggregator long
"SELECT EARLIEST(long1) FROM foo",
// 39: EARLIEST aggregator double
"SELECT EARLIEST(double4) FROM foo",
// 40: EARLIEST aggregator float
"SELECT EARLIEST(float3) FROM foo"
);

@Param({"5000000"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
Expand All @@ -31,10 +32,12 @@
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
import org.apache.druid.segment.generator.SegmentGenerator;
import org.apache.druid.segment.transform.TransformSpec;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.server.security.AuthConfig;
Expand Down Expand Up @@ -197,23 +200,8 @@ public String getFormatString()
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long4), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 37: time shift + expr agg (group by), uniform distribution high cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long5), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 38: LATEST aggregator long
"SELECT LATEST(long1) FROM foo",
// 39: LATEST aggregator double
"SELECT LATEST(double4) FROM foo",
// 40: LATEST aggregator double
"SELECT LATEST(float3) FROM foo",
// 41: LATEST aggregator double
"SELECT LATEST(float3), LATEST(long1), LATEST(double4) FROM foo",
// 42,43: filter numeric nulls
"SELECT SUM(long5) FROM foo WHERE long5 IS NOT NULL",
"SELECT string2, SUM(long5) FROM foo WHERE long5 IS NOT NULL GROUP BY 1",
// 44: EARLIEST aggregator long
"SELECT EARLIEST(long1) FROM foo",
// 45: EARLIEST aggregator double
"SELECT EARLIEST(double4) FROM foo",
// 46: EARLIEST aggregator float
"SELECT EARLIEST(float3) FROM foo"
// 38: array filtering
"SELECT string1, long1 FROM foo WHERE ARRAY_CONTAINS(\"multi-string3\", 100) GROUP BY 1,2"
);

@Param({"5000000"})
Expand All @@ -225,6 +213,12 @@ public String getFormatString()
})
private String vectorize;

@Param({
"explicit",
"auto"
})
private String schema;

@Param({
// non-expression reference
"0",
Expand Down Expand Up @@ -266,16 +260,7 @@ public String getFormatString()
"35",
"36",
"37",
"38",
"39",
"40",
"41",
"42",
"43",
"44",
"45",
"46",
"47"
"38"
})
private String query;

Expand All @@ -300,8 +285,21 @@ public void setup()
final PlannerConfig plannerConfig = new PlannerConfig();

final SegmentGenerator segmentGenerator = closer.register(new SegmentGenerator());
log.info("Starting benchmark setup using cacheDir[%s], rows[%,d].", segmentGenerator.getCacheDir(), rowsPerSegment);
final QueryableIndex index = segmentGenerator.generate(dataSegment, schemaInfo, Granularities.NONE, rowsPerSegment);
log.info("Starting benchmark setup using cacheDir[%s], rows[%,d], schema[%s].", segmentGenerator.getCacheDir(), rowsPerSegment, schema);
final QueryableIndex index;
if ("auto".equals(schema)) {
index = segmentGenerator.generate(
dataSegment,
schemaInfo,
DimensionsSpec.builder().useSchemaDiscovery(true).build(),
TransformSpec.NONE,
IndexSpec.DEFAULT,
Granularities.NONE,
rowsPerSegment
);
} else {
index = segmentGenerator.generate(dataSegment, schemaInfo, Granularities.NONE, rowsPerSegment);
}

final QueryRunnerFactoryConglomerate conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(
closer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.druid.data.input.InputStats;
import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.data.input.SplitHintSpec;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.SplittableInputSource;
import org.apache.druid.guice.IndexingServiceInputSourceModule;
import org.apache.druid.java.util.common.CloseableIterators;
Expand Down Expand Up @@ -179,7 +180,10 @@ public boolean hasNext()
public InputRow next()
{
rowCount++;
return generator.nextRow();
return MapInputRowParser.parse(
inputRowSchema,
generator.nextRaw(inputRowSchema.getTimestampSpec().getTimestampColumn())
);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.InputRowSchema;
import org.apache.druid.data.input.InputSourceReader;
import org.apache.druid.data.input.InputSplit;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.guice.IndexingServiceInputSourceModule;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.DateTimes;
Expand Down Expand Up @@ -128,11 +132,20 @@ public void testReader() throws IOException
timestampIncrement
);

InputSourceReader reader = inputSource.fixedFormatReader(null, null);
InputRowSchema rowSchema = new InputRowSchema(
new TimestampSpec(null, null, null),
DimensionsSpec.builder().useSchemaDiscovery(true).build(),
null
);

InputSourceReader reader = inputSource.fixedFormatReader(
rowSchema,
null
);
CloseableIterator<InputRow> iterator = reader.read();

InputRow first = iterator.next();
InputRow generatorFirst = generator.nextRow();
InputRow generatorFirst = MapInputRowParser.parse(rowSchema, generator.nextRaw(rowSchema.getTimestampSpec().getTimestampColumn()));
Assert.assertEquals(generatorFirst, first);
Assert.assertTrue(iterator.hasNext());
int i;
Expand All @@ -157,7 +170,7 @@ public void testSplits()
);

Assert.assertEquals(2, inputSource.estimateNumSplits(null, null));
Assert.assertEquals(false, inputSource.needsFormat());
Assert.assertFalse(inputSource.needsFormat());
Assert.assertEquals(2, inputSource.createSplits(null, null).count());
Assert.assertEquals(
new Long(2048L),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,13 @@ public static Number computeNumber(@Nullable String value)
@Nullable
public static ExprEval<?> castForEqualityComparison(ExprEval<?> valueToCompare, ExpressionType typeToCompareWith)
{
if (valueToCompare.isArray() && !typeToCompareWith.isArray()) {
final Object[] array = valueToCompare.asArray();
// cannot cast array to scalar if array length is greater than 1
if (array != null && array.length > 1) {
return null;
}
}
ExprEval<?> cast = valueToCompare.castTo(typeToCompareWith);
if (ExpressionType.LONG.equals(typeToCompareWith) && valueToCompare.asDouble() != cast.asDouble()) {
// make sure the DOUBLE value when cast to LONG is the same before and after the cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3327,11 +3327,11 @@ public void validateArguments(List<Expr> args)
@Override
public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
ExpressionType type = ExpressionType.LONG;
ExpressionType type = null;
for (Expr arg : args) {
type = ExpressionTypeConversion.function(type, arg.getOutputType(inspector));
type = ExpressionTypeConversion.leastRestrictiveType(type, arg.getOutputType(inspector));
}
return ExpressionType.asArrayType(type);
return type == null ? null : ExpressionTypeFactory.getInstance().ofArray(type);
}

/**
Expand Down
Loading

0 comments on commit 64fcb32

Please sign in to comment.