From 4123f2ca900761d4f5f6ed193d9fb378f0defbda Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 30 Sep 2024 02:20:30 -0700 Subject: [PATCH] add multi-value string object vector matcher and expression vector object selectors (#17162) (#17165) --- ...tiValueStringObjectVectorValueMatcher.java | 128 ++++++++++++++++++ ...torValueMatcherColumnProcessorFactory.java | 3 + .../segment/filter/ExpressionFilter.java | 6 +- ...ExpressionMultiValueDimensionSelector.java | 10 ++ ...nVectorMultiValueStringObjectSelector.java | 81 +++++++++++ .../virtual/ExpressionVectorSelectors.java | 18 ++- .../virtual/ExpressionVirtualColumn.java | 2 +- .../druid/segment/filter/BaseFilterTest.java | 6 +- .../segment/filter/EqualityFilterTests.java | 94 +++++++++++++ 9 files changed, 341 insertions(+), 7 deletions(-) create mode 100644 processing/src/main/java/org/apache/druid/query/filter/vector/MultiValueStringObjectVectorValueMatcher.java create mode 100644 processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorMultiValueStringObjectSelector.java diff --git a/processing/src/main/java/org/apache/druid/query/filter/vector/MultiValueStringObjectVectorValueMatcher.java b/processing/src/main/java/org/apache/druid/query/filter/vector/MultiValueStringObjectVectorValueMatcher.java new file mode 100644 index 000000000000..3e6c32efee9c --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/filter/vector/MultiValueStringObjectVectorValueMatcher.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.filter.vector; + +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExpressionType; +import org.apache.druid.query.filter.DruidObjectPredicate; +import org.apache.druid.query.filter.DruidPredicateFactory; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.vector.VectorObjectSelector; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +public class MultiValueStringObjectVectorValueMatcher implements VectorValueMatcherFactory +{ + protected final VectorObjectSelector selector; + + public MultiValueStringObjectVectorValueMatcher(final VectorObjectSelector selector) + { + this.selector = selector; + } + + @Override + public VectorValueMatcher makeMatcher(@Nullable String value) + { + return new BaseVectorValueMatcher(selector) + { + final VectorMatch match = VectorMatch.wrap(new int[selector.getMaxVectorSize()]); + + @Override + public ReadableVectorMatch match(final ReadableVectorMatch mask, boolean includeUnknown) + { + final Object[] vector = selector.getObjectVector(); + final int[] selection = match.getSelection(); + + int numRows = 0; + + for (int i = 0; i < mask.getSelectionSize(); i++) { + final int rowNum = mask.getSelection()[i]; + final Object val = vector[rowNum]; + if (val instanceof List) { + for (Object o : (List) val) { + if ((o == null && includeUnknown) || Objects.equals(value, o)) { + selection[numRows++] = rowNum; + break; + } + } + } else { + if ((val == null && includeUnknown) || Objects.equals(value, val)) { + selection[numRows++] = rowNum; + } + } + } + + match.setSelectionSize(numRows); + return match; + } + }; + } + + @Override + public VectorValueMatcher makeMatcher(Object matchValue, ColumnType matchValueType) + { + final ExprEval eval = ExprEval.ofType(ExpressionType.fromColumnType(matchValueType), matchValue); + final ExprEval castForComparison = ExprEval.castForEqualityComparison(eval, ExpressionType.STRING); + if (castForComparison == null || castForComparison.asString() == null) { + return VectorValueMatcher.allFalseObjectMatcher(selector); + } + return makeMatcher(castForComparison.asString()); + } + + @Override + public VectorValueMatcher makeMatcher(DruidPredicateFactory predicateFactory) + { + final DruidObjectPredicate predicate = predicateFactory.makeStringPredicate(); + + return new BaseVectorValueMatcher(selector) + { + final VectorMatch match = VectorMatch.wrap(new int[selector.getMaxVectorSize()]); + + @Override + public ReadableVectorMatch match(final ReadableVectorMatch mask, boolean includeUnknown) + { + final Object[] vector = selector.getObjectVector(); + final int[] selection = match.getSelection(); + + int numRows = 0; + + for (int i = 0; i < mask.getSelectionSize(); i++) { + final int rowNum = mask.getSelection()[i]; + Object val = vector[rowNum]; + if (val instanceof List) { + for (Object o : (List) val) { + if (predicate.apply((String) o).matches(includeUnknown)) { + selection[numRows++] = rowNum; + break; + } + } + } else if (predicate.apply((String) val).matches(includeUnknown)) { + selection[numRows++] = rowNum; + } + } + + match.setSelectionSize(numRows); + return match; + } + }; + } +} diff --git a/processing/src/main/java/org/apache/druid/query/filter/vector/VectorValueMatcherColumnProcessorFactory.java b/processing/src/main/java/org/apache/druid/query/filter/vector/VectorValueMatcherColumnProcessorFactory.java index 0d16ee24230b..d8cdd509e71a 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/vector/VectorValueMatcherColumnProcessorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/filter/vector/VectorValueMatcherColumnProcessorFactory.java @@ -99,6 +99,9 @@ public VectorValueMatcherFactory makeObjectProcessor( ) { if (capabilities.is(ValueType.STRING)) { + if (capabilities.hasMultipleValues().isTrue()) { + return new MultiValueStringObjectVectorValueMatcher(selector); + } return new StringObjectVectorValueMatcher(selector); } return new ObjectVectorValueMatcher(selector); diff --git a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java index ccd53e96bc29..dcd818489ca4 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java @@ -121,18 +121,18 @@ public VectorValueMatcher makeVectorMatcher(VectorColumnSelectorFactory factory) case STRING: return VectorValueMatcherColumnProcessorFactory.instance().makeObjectProcessor( ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities(), - ExpressionVectorSelectors.makeVectorObjectSelector(factory, theExpr) + ExpressionVectorSelectors.makeVectorObjectSelector(factory, theExpr, null) ).makeMatcher(predicateFactory); case ARRAY: return VectorValueMatcherColumnProcessorFactory.instance().makeObjectProcessor( ColumnCapabilitiesImpl.createDefault().setType(ExpressionType.toColumnType(outputType)).setHasNulls(true), - ExpressionVectorSelectors.makeVectorObjectSelector(factory, theExpr) + ExpressionVectorSelectors.makeVectorObjectSelector(factory, theExpr, null) ).makeMatcher(predicateFactory); default: if (ExpressionType.NESTED_DATA.equals(outputType)) { return VectorValueMatcherColumnProcessorFactory.instance().makeObjectProcessor( ColumnCapabilitiesImpl.createDefault().setType(ExpressionType.toColumnType(outputType)).setHasNulls(true), - ExpressionVectorSelectors.makeVectorObjectSelector(factory, theExpr) + ExpressionVectorSelectors.makeVectorObjectSelector(factory, theExpr, null) ).makeMatcher(predicateFactory); } throw new UOE("Vectorized expression matchers not implemented for type: [%s]", outputType); diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionMultiValueDimensionSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionMultiValueDimensionSelector.java index 031bc46cc1de..dd70b3566e16 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionMultiValueDimensionSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionMultiValueDimensionSelector.java @@ -76,10 +76,14 @@ String getValue(ExprEval evaluated) return evaluated.asString(); } + @Nullable List getArrayAsList(ExprEval evaluated) { assert evaluated.isArray(); //noinspection ConstantConditions + if (evaluated.asArray() == null) { + return null; + } return Arrays.stream(evaluated.asArray()) .map(Evals::asString) .collect(Collectors.toList()); @@ -133,6 +137,9 @@ public boolean matches(boolean includeUnknown) ExprEval evaluated = getEvaluated(); if (evaluated.isArray()) { List array = getArrayAsList(evaluated); + if (array == null) { + return includeUnknown || value == null; + } return array.stream().anyMatch(x -> (includeUnknown && x == null) || Objects.equals(x, value)); } final String rowValue = getValue(evaluated); @@ -159,6 +166,9 @@ public boolean matches(boolean includeUnknown) final DruidObjectPredicate predicate = predicateFactory.makeStringPredicate(); if (evaluated.isArray()) { List array = getArrayAsList(evaluated); + if (array == null) { + return predicate.apply(null).matches(includeUnknown); + } return array.stream().anyMatch(x -> predicate.apply(x).matches(includeUnknown)); } final String rowValue = getValue(evaluated); diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorMultiValueStringObjectSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorMultiValueStringObjectSelector.java new file mode 100644 index 000000000000..8c2ce93dba70 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorMultiValueStringObjectSelector.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.virtual; + +import com.google.common.base.Preconditions; +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.vector.ExprVectorProcessor; +import org.apache.druid.segment.vector.ReadableVectorInspector; +import org.apache.druid.segment.vector.VectorObjectSelector; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; + +import java.util.Arrays; + +public class ExpressionVectorMultiValueStringObjectSelector implements VectorObjectSelector +{ + private final Expr.VectorInputBinding bindings; + private final ExprVectorProcessor processor; + + @MonotonicNonNull + private Object[] cached; + private int currentId = ReadableVectorInspector.NULL_ID; + + public ExpressionVectorMultiValueStringObjectSelector( + ExprVectorProcessor processor, + Expr.VectorInputBinding bindings + ) + { + this.processor = Preconditions.checkNotNull(processor, "processor"); + this.bindings = Preconditions.checkNotNull(bindings, "bindings"); + this.cached = new Object[bindings.getMaxVectorSize()]; + } + + @Override + public Object[] getObjectVector() + { + if (bindings.getCurrentVectorId() != currentId) { + currentId = bindings.getCurrentVectorId(); + final Object[] tmp = processor.evalVector(bindings).getObjectVector(); + for (int i = 0; i < bindings.getCurrentVectorSize(); i++) { + Object[] tmpi = (Object[]) tmp[i]; + if (tmpi == null) { + cached[i] = null; + } else if (tmpi.length == 1) { + cached[i] = tmpi[0]; + } else { + cached[i] = Arrays.asList(tmpi); + } + } + } + return cached; + } + + @Override + public int getMaxVectorSize() + { + return bindings.getMaxVectorSize(); + } + + @Override + public int getCurrentVectorSize() + { + return bindings.getCurrentVectorSize(); + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java index 8578be228d59..c0776a2356c5 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java @@ -21,6 +21,7 @@ import com.google.common.base.Preconditions; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprType; import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.math.expr.InputBindings; @@ -33,6 +34,8 @@ import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.Types; +import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.vector.ConstantVectorSelectors; import org.apache.druid.segment.vector.ReadableVectorInspector; import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; @@ -94,21 +97,32 @@ public static VectorValueSelector makeVectorValueSelector( public static VectorObjectSelector makeVectorObjectSelector( VectorColumnSelectorFactory factory, - Expr expression + Expr expression, + @Nullable ColumnType outputTypeHint ) { final ExpressionPlan plan = ExpressionPlanner.plan(factory, expression); Preconditions.checkArgument(plan.is(ExpressionPlan.Trait.VECTORIZABLE)); if (plan.isConstant()) { + final ExprEval eval = plan.getExpression().eval(InputBindings.nilBindings()); + if (Types.is(outputTypeHint, ValueType.STRING) && eval.type().isArray()) { + return ConstantVectorSelectors.vectorObjectSelector( + factory.getReadableVectorInspector(), + ExpressionSelectors.coerceEvalToObjectOrList(eval) + ); + } return ConstantVectorSelectors.vectorObjectSelector( factory.getReadableVectorInspector(), - plan.getExpression().eval(InputBindings.nilBindings()).valueOrDefault() + eval.valueOrDefault() ); } final Expr.VectorInputBinding bindings = createVectorBindings(plan.getAnalysis(), factory); final ExprVectorProcessor processor = plan.getExpression().asVectorProcessor(bindings); + if (Types.is(outputTypeHint, ValueType.STRING) && processor.getOutputType().isArray()) { + return new ExpressionVectorMultiValueStringObjectSelector(processor, bindings); + } return new ExpressionVectorObjectSelector(processor, bindings); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java index e6f4d57e1d57..8bb62128f9da 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java @@ -238,7 +238,7 @@ public VectorObjectSelector makeVectorObjectSelector(String columnName, VectorCo return factory.makeObjectSelector(parsedExpression.get().getBindingIfIdentifier()); } - return ExpressionVectorSelectors.makeVectorObjectSelector(factory, parsedExpression.get()); + return ExpressionVectorSelectors.makeVectorObjectSelector(factory, parsedExpression.get(), expression.outputType); } @Nullable diff --git a/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java b/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java index 2e3ae0b633ff..ed4ff921a9c5 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java @@ -163,7 +163,11 @@ public abstract class BaseFilterTest extends InitializedNullHandlingTest new NestedFieldVirtualColumn("nested", "$.l0", "nested.l0", ColumnType.LONG), new NestedFieldVirtualColumn("nested", "$.arrayLong", "nested.arrayLong", ColumnType.LONG_ARRAY), new NestedFieldVirtualColumn("nested", "$.arrayDouble", "nested.arrayDouble", ColumnType.DOUBLE_ARRAY), - new NestedFieldVirtualColumn("nested", "$.arrayString", "nested.arrayString", ColumnType.STRING_ARRAY) + new NestedFieldVirtualColumn("nested", "$.arrayString", "nested.arrayString", ColumnType.STRING_ARRAY), + new ExpressionVirtualColumn("arrayLongAsMvd", "array_to_mv(arrayLong)", ColumnType.STRING, TestExprMacroTable.INSTANCE), + new ExpressionVirtualColumn("arrayDoubleAsMvd", "array_to_mv(arrayDouble)", ColumnType.STRING, TestExprMacroTable.INSTANCE), + new ExpressionVirtualColumn("arrayStringAsMvd", "array_to_mv(arrayString)", ColumnType.STRING, TestExprMacroTable.INSTANCE), + new ExpressionVirtualColumn("arrayConstantAsMvd", "array_to_mv(array(1,2,3))", ColumnType.STRING, TestExprMacroTable.INSTANCE) ) ); diff --git a/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java b/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java index 0259a2c34327..9e29ff266b1e 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java @@ -1629,6 +1629,100 @@ public void testNestedColumnEquality() : ImmutableList.of("0", "1", "2", "3", "4", "5") ); } + + @Test + public void testArraysAsMvds() + { + Assume.assumeTrue(canTestArrayColumns()); + /* + dim0 .. arrayString arrayLong arrayDouble + "0", .. ["a", "b", "c"], [1L, 2L, 3L], [1.1, 2.2, 3.3] + "1", .. [], [], [1.1, 2.2, 3.3] + "2", .. null, [1L, 2L, 3L], [null] + "3", .. ["a", "b", "c"], null, [] + "4", .. ["c", "d"], [null], [-1.1, -333.3] + "5", .. [null], [123L, 345L], null + */ + + assertFilterMatches( + new EqualityFilter( + "arrayStringAsMvd", + ColumnType.STRING, + "b", + null + ), + ImmutableList.of("0", "3") + ); + assertFilterMatches( + NotDimFilter.of( + new EqualityFilter( + "arrayStringAsMvd", + ColumnType.STRING, + "b", + null + ) + ), + NullHandling.sqlCompatible() + ? ImmutableList.of("1", "4") + : ImmutableList.of("1", "2", "4", "5") + ); + + assertFilterMatches( + new EqualityFilter( + "arrayLongAsMvd", + ColumnType.STRING, + "2", + null + ), + ImmutableList.of("0", "2") + ); + assertFilterMatches( + NotDimFilter.of( + new EqualityFilter( + "arrayLongAsMvd", + ColumnType.STRING, + "2", + null + ) + ), + NullHandling.sqlCompatible() + ? ImmutableList.of("1", "5") + : ImmutableList.of("1", "3", "4", "5") + ); + + assertFilterMatches( + new EqualityFilter( + "arrayDoubleAsMvd", + ColumnType.STRING, + "3.3", + null + ), + ImmutableList.of("0", "1") + ); + assertFilterMatches( + NotDimFilter.of( + new EqualityFilter( + "arrayDoubleAsMvd", + ColumnType.STRING, + "3.3", + null + ) + ), + NullHandling.sqlCompatible() + ? ImmutableList.of("3", "4") + : ImmutableList.of("2", "3", "4", "5") + ); + + assertFilterMatches( + new EqualityFilter( + "arrayConstantAsMvd", + ColumnType.STRING, + "3", + null + ), + ImmutableList.of("0", "1", "2", "3", "4", "5") + ); + } } public static class EqualityFilterNonParameterizedTests extends InitializedNullHandlingTest