Skip to content

Commit

Permalink
Introduce natural comparator for types that don't have a StringCompar…
Browse files Browse the repository at this point in the history
…ator (apache#15145)

Fixes a bug when executing queries with the ordering of arrays
  • Loading branch information
LakshSingla authored Oct 16, 2023
1 parent 4b0d1b3 commit dc8d219
Show file tree
Hide file tree
Showing 13 changed files with 380 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,17 @@ private static void validateQuery(final GroupByQuery query)
}
}

/**
* Only allow ordering the queries from the MSQ engine, ignoring the comparator that is set in the query. This
* function checks if it is safe to do so, which is the case if the natural comparator is used for the dimension.
* Since MSQ executes the queries planned by the SQL layer, this is a sanity check as we always add the natural
* comparator for the dimensions there
*/
private static boolean isNaturalComparator(final ValueType type, final StringComparator comparator)
{
if (StringComparators.NATURAL.equals(comparator)) {
return true;
}
return ((type == ValueType.STRING && StringComparators.LEXICOGRAPHIC.equals(comparator))
|| (type.isNumeric() && StringComparators.NUMERIC.equals(comparator)))
&& !type.isArray();
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

import javax.annotation.Nonnull;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -2534,7 +2533,6 @@ public void testUnionAllUsingUnionDataSource()
.verifyResults();
}

@Nonnull
private List<Object[]> expectedMultiValueFooRowsGroup()
{
ArrayList<Object[]> expected = new ArrayList<>();
Expand All @@ -2553,7 +2551,6 @@ private List<Object[]> expectedMultiValueFooRowsGroup()
return expected;
}

@Nonnull
private List<Object[]> expectedMultiValueFooRowsGroupByList()
{
ArrayList<Object[]> expected = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public StringComparatorModule()
new NamedType(StringComparators.AlphanumericComparator.class, StringComparators.ALPHANUMERIC_NAME),
new NamedType(StringComparators.StrlenComparator.class, StringComparators.STRLEN_NAME),
new NamedType(StringComparators.NumericComparator.class, StringComparators.NUMERIC_NAME),
new NamedType(StringComparators.VersionComparator.class, StringComparators.VERSION_NAME)
new NamedType(StringComparators.VersionComparator.class, StringComparators.VERSION_NAME),
new NamedType(StringComparators.NaturalComparator.class, StringComparators.NATURAL_NAME)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ public BoundDimFilter(
boolean orderingIsAlphanumeric = this.ordering.equals(StringComparators.ALPHANUMERIC);
Preconditions.checkState(
alphaNumeric == orderingIsAlphanumeric,
"mismatch between alphanumeric and ordering property");
"mismatch between alphanumeric and ordering property"
);
}
}
this.extractionFn = extractionFn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,11 @@ public static Grouper.BufferComparator makeNullHandlingBufferComparatorForNumeri

private static boolean isPrimitiveComparable(boolean pushLimitDown, @Nullable StringComparator stringComparator)
{
return !pushLimitDown || stringComparator == null || stringComparator.equals(StringComparators.NUMERIC);
return !pushLimitDown
|| stringComparator == null
|| stringComparator.equals(StringComparators.NUMERIC)
// NATURAL isn't set for numeric types, however if it is, then that would mean that we are ordering the
// numeric type with its natural comparator (which is NUMERIC)
|| stringComparator.equals(StringComparators.NATURAL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,8 @@ private static int compareDimsInRowsWithAggs(
final StringComparator comparator = comparators.get(i);

final ColumnType fieldType = fieldTypes.get(i);
if (fieldType.isNumeric() && comparator.equals(StringComparators.NUMERIC)) {
if (fieldType.isNumeric()
&& (comparator.equals(StringComparators.NUMERIC) || comparator.equals(StringComparators.NATURAL))) {
// use natural comparison
if (fieldType.is(ValueType.DOUBLE)) {
// sometimes doubles can become floats making the round trip from serde, make sure to coerce them both
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public static StringComparator fromString(String type)
return StringComparators.NUMERIC;
case StringComparators.VERSION_NAME:
return StringComparators.VERSION;
case StringComparators.NATURAL_NAME:
return StringComparators.NATURAL;
default:
throw new IAE("Unknown string comparator[%s]", type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.google.common.collect.Ordering;
import com.google.common.primitives.Ints;
import org.apache.druid.common.guava.GuavaUtils;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.maven.artifact.versioning.DefaultArtifactVersion;

import java.math.BigDecimal;
Expand All @@ -34,25 +36,28 @@ public class StringComparators
public static final String NUMERIC_NAME = "numeric";
public static final String STRLEN_NAME = "strlen";
public static final String VERSION_NAME = "version";
public static final String NATURAL_NAME = "natural";

public static final StringComparator LEXICOGRAPHIC = new LexicographicComparator();
public static final StringComparator ALPHANUMERIC = new AlphanumericComparator();
public static final StringComparator NUMERIC = new NumericComparator();
public static final StringComparator STRLEN = new StrlenComparator();
public static final StringComparator VERSION = new VersionComparator();
public static final StringComparator NATURAL = new NaturalComparator();

public static final int LEXICOGRAPHIC_CACHE_ID = 0x01;
public static final int ALPHANUMERIC_CACHE_ID = 0x02;
public static final int NUMERIC_CACHE_ID = 0x03;
public static final int STRLEN_CACHE_ID = 0x04;
public static final int VERSION_CACHE_ID = 0x05;
public static final int NATURAL_CACHE_ID = 0x06;

/**
* Comparison using the natural comparator of {@link String}.
*
* Note that this is not equivalent to comparing UTF-8 byte arrays; see javadocs for
* {@link org.apache.druid.java.util.common.StringUtils#compareUnicode(String, String)} and
* {@link org.apache.druid.java.util.common.StringUtils#compareUtf8UsingJavaStringOrdering(byte[], byte[])}.
* {@link StringUtils#compareUnicode(String, String)} and
* {@link StringUtils#compareUtf8UsingJavaStringOrdering(byte[], byte[])}.
*/
public static class LexicographicComparator extends StringComparator
{
Expand Down Expand Up @@ -492,4 +497,51 @@ public byte[] getCacheKey()
return new byte[]{(byte) VERSION_CACHE_ID};
}
}

/**
* NaturalComparator refers to the natural ordering of the type that it refers.
*
* For example, if the type is Long, the natural ordering would be numeric
* if the type is an array, the natural ordering would be lexicographic comparison of the natural ordering of the
* elements in the arrays.
*
* It is a sigil value for the dimension that we can handle in the execution layer, and don't need the comparator for.
* It is also a placeholder for dimensions that we don't have a comparator for (like arrays), but is a required for
* planning
*/
public static class NaturalComparator extends StringComparator
{
@Override
public int compare(String o1, String o2)
{
throw DruidException.defensive("compare() should not be called for the NaturalComparator");
}

@Override
public String toString()
{
return StringComparators.NATURAL_NAME;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
return o != null && getClass() == o.getClass();
}

@Override
public int hashCode()
{
return 0;
}

@Override
public byte[] getCacheKey()
{
return new byte[]{(byte) NATURAL_CACHE_ID};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.apache.druid.error.DruidException;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -33,6 +34,8 @@

public class StringComparatorsTest
{
private static final ObjectMapper JSON_MAPPER = new DefaultObjectMapper();

private void commonTest(StringComparator comparator)
{
// equality test
Expand Down Expand Up @@ -156,65 +159,83 @@ public void testVersionComparator()
Assert.assertTrue(StringComparators.VERSION.compare("1.0-SNAPSHOT", "1.0-Final") < 0);
}

@Test
public void testNaturalComparator()
{
Assert.assertThrows(DruidException.class, () -> StringComparators.NATURAL.compare("str1", "str2"));
}

@Test
public void testLexicographicComparatorSerdeTest() throws IOException
{
ObjectMapper jsonMapper = new DefaultObjectMapper();
String expectJsonSpec = "{\"type\":\"lexicographic\"}";

String jsonSpec = jsonMapper.writeValueAsString(StringComparators.LEXICOGRAPHIC);
String jsonSpec = JSON_MAPPER.writeValueAsString(StringComparators.LEXICOGRAPHIC);
Assert.assertEquals(expectJsonSpec, jsonSpec);
Assert.assertEquals(StringComparators.LEXICOGRAPHIC, jsonMapper.readValue(expectJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.LEXICOGRAPHIC, JSON_MAPPER.readValue(expectJsonSpec, StringComparator.class));

String makeFromJsonSpec = "\"lexicographic\"";
Assert.assertEquals(
StringComparators.LEXICOGRAPHIC,
jsonMapper.readValue(makeFromJsonSpec, StringComparator.class)
JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class)
);
}

@Test
public void testAlphanumericComparatorSerdeTest() throws IOException
{
ObjectMapper jsonMapper = new DefaultObjectMapper();
String expectJsonSpec = "{\"type\":\"alphanumeric\"}";

String jsonSpec = jsonMapper.writeValueAsString(StringComparators.ALPHANUMERIC);
String jsonSpec = JSON_MAPPER.writeValueAsString(StringComparators.ALPHANUMERIC);
Assert.assertEquals(expectJsonSpec, jsonSpec);
Assert.assertEquals(StringComparators.ALPHANUMERIC, jsonMapper.readValue(expectJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.ALPHANUMERIC, JSON_MAPPER.readValue(expectJsonSpec, StringComparator.class));

String makeFromJsonSpec = "\"alphanumeric\"";
Assert.assertEquals(StringComparators.ALPHANUMERIC, jsonMapper.readValue(makeFromJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.ALPHANUMERIC, JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class));
}

@Test
public void testStrlenComparatorSerdeTest() throws IOException
{
ObjectMapper jsonMapper = new DefaultObjectMapper();
String expectJsonSpec = "{\"type\":\"strlen\"}";

String jsonSpec = jsonMapper.writeValueAsString(StringComparators.STRLEN);
String jsonSpec = JSON_MAPPER.writeValueAsString(StringComparators.STRLEN);
Assert.assertEquals(expectJsonSpec, jsonSpec);
Assert.assertEquals(StringComparators.STRLEN, jsonMapper.readValue(expectJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.STRLEN, JSON_MAPPER.readValue(expectJsonSpec, StringComparator.class));

String makeFromJsonSpec = "\"strlen\"";
Assert.assertEquals(StringComparators.STRLEN, jsonMapper.readValue(makeFromJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.STRLEN, JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class));
}

@Test
public void testNumericComparatorSerdeTest() throws IOException
{
ObjectMapper jsonMapper = new DefaultObjectMapper();
String expectJsonSpec = "{\"type\":\"numeric\"}";

String jsonSpec = jsonMapper.writeValueAsString(StringComparators.NUMERIC);
String jsonSpec = JSON_MAPPER.writeValueAsString(StringComparators.NUMERIC);
Assert.assertEquals(expectJsonSpec, jsonSpec);
Assert.assertEquals(StringComparators.NUMERIC, jsonMapper.readValue(expectJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.NUMERIC, JSON_MAPPER.readValue(expectJsonSpec, StringComparator.class));

String makeFromJsonSpec = "\"numeric\"";
Assert.assertEquals(StringComparators.NUMERIC, jsonMapper.readValue(makeFromJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.NUMERIC, JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class));

makeFromJsonSpec = "\"NuMeRiC\"";
Assert.assertEquals(StringComparators.NUMERIC, jsonMapper.readValue(makeFromJsonSpec, StringComparator.class));
Assert.assertEquals(StringComparators.NUMERIC, JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class));
}

@Test
public void testNaturalComparatorSerdeTest() throws IOException
{
String expectJsonSpec = "{\"type\":\"natural\"}";

String jsonSpec = JSON_MAPPER.writeValueAsString(StringComparators.NATURAL);
Assert.assertEquals(expectJsonSpec, jsonSpec);
Assert.assertEquals(StringComparators.NATURAL, JSON_MAPPER.readValue(expectJsonSpec, StringComparator.class));

String makeFromJsonSpec = "\"natural\"";
Assert.assertEquals(StringComparators.NATURAL, JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class));

makeFromJsonSpec = "\"NaTuRaL\"";
Assert.assertEquals(StringComparators.NATURAL, JSON_MAPPER.readValue(makeFromJsonSpec, StringComparator.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public static BoundDimFilter lessThanOrEqualTo(final BoundRefKey boundRefKey, fi
public static BoundDimFilter interval(final BoundRefKey boundRefKey, final Interval interval)
{
if (!boundRefKey.getComparator().equals(StringComparators.NUMERIC)) {
// Interval comparison only works with NUMERIC comparator.
// Interval comparison only works with NUMERIC comparator
throw new ISE("Comparator must be NUMERIC but was[%s]", boundRefKey.getComparator());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.calcite.util.TimestampString;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.math.expr.ExpressionProcessingConfig;
Expand Down Expand Up @@ -208,20 +207,26 @@ public static boolean isLongType(SqlTypeName sqlTypeName)
SqlTypeName.INT_TYPES.contains(sqlTypeName);
}

/**
* Returns the natural StringComparator associated with the RelDataType
*/
public static StringComparator getStringComparatorForRelDataType(RelDataType dataType)
{
final ColumnType valueType = getColumnTypeForRelDataType(dataType);
return getStringComparatorForValueType(valueType);
}

/**
* Returns the natural StringComparator associated with the given ColumnType
*/
public static StringComparator getStringComparatorForValueType(ColumnType valueType)
{
if (valueType.isNumeric()) {
return StringComparators.NUMERIC;
} else if (valueType.is(ValueType.STRING)) {
return StringComparators.LEXICOGRAPHIC;
} else {
throw new ISE("Unrecognized valueType[%s]", valueType);
return StringComparators.NATURAL;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package org.apache.druid.sql.calcite.planner;

import com.google.common.collect.ImmutableSortedSet;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -52,4 +54,18 @@ public void testFindUnusedPrefix()
Assert.assertEquals("x", Calcites.findUnusedPrefixForDigits("x", ImmutableSortedSet.of("foo", "xa", "_x")));
Assert.assertEquals("__x", Calcites.findUnusedPrefixForDigits("x", ImmutableSortedSet.of("foo", "x1a", "_x90")));
}

@Test
public void testGetStringComparatorForColumnType()
{
Assert.assertEquals(StringComparators.LEXICOGRAPHIC, Calcites.getStringComparatorForValueType(ColumnType.STRING));
Assert.assertEquals(StringComparators.NUMERIC, Calcites.getStringComparatorForValueType(ColumnType.LONG));
Assert.assertEquals(StringComparators.NUMERIC, Calcites.getStringComparatorForValueType(ColumnType.FLOAT));
Assert.assertEquals(StringComparators.NUMERIC, Calcites.getStringComparatorForValueType(ColumnType.DOUBLE));
Assert.assertEquals(StringComparators.NATURAL, Calcites.getStringComparatorForValueType(ColumnType.STRING_ARRAY));
Assert.assertEquals(StringComparators.NATURAL, Calcites.getStringComparatorForValueType(ColumnType.LONG_ARRAY));
Assert.assertEquals(StringComparators.NATURAL, Calcites.getStringComparatorForValueType(ColumnType.DOUBLE_ARRAY));
Assert.assertEquals(StringComparators.NATURAL, Calcites.getStringComparatorForValueType(ColumnType.NESTED_DATA));
Assert.assertEquals(StringComparators.NATURAL, Calcites.getStringComparatorForValueType(ColumnType.UNKNOWN_COMPLEX));
}
}

0 comments on commit dc8d219

Please sign in to comment.