Skip to content

Commit

Permalink
Don't use ComplexMetricExtractor to fetch the class of the object in …
Browse files Browse the repository at this point in the history
…field readers (#16825)

This patch fixes queries like `SELECT COUNT(DISTINCT json_col) FROM foo`
  • Loading branch information
LakshSingla authored Aug 5, 2024
1 parent 0411c4e commit c84e689
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,34 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.impl.JsonInputFormat;
import org.apache.druid.data.input.impl.LocalInputSource;
import org.apache.druid.data.input.impl.systemfield.SystemFields;
import org.apache.druid.guice.BuiltInTypesModule;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.NestedDataTestUtils;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupByQueryConfig;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
Expand All @@ -44,8 +58,10 @@
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.ColumnMapping;
import org.apache.druid.sql.calcite.planner.ColumnMappings;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.timeline.SegmentId;
import org.apache.druid.utils.CompressionUtils;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -70,6 +86,7 @@ public class MSQComplexGroupByTest extends MSQTestBase
private String dataFileNameJsonString;
private String dataFileSignatureJsonString;
private DataSource dataFileExternalDataSource;
private File dataFile;

public static Collection<Object[]> data()
{
Expand All @@ -85,9 +102,9 @@ public static Collection<Object[]> data()
@BeforeEach
public void setup() throws IOException
{
File dataFile = newTempFile("dataFile");
dataFile = newTempFile("dataFile");
final InputStream resourceStream = this.getClass().getClassLoader()
.getResourceAsStream(NestedDataTestUtils.ALL_TYPES_TEST_DATA_FILE);
.getResourceAsStream(NestedDataTestUtils.ALL_TYPES_TEST_DATA_FILE);
final InputStream decompressing = CompressionUtils.decompress(
resourceStream,
"nested-all-types-test-data.json"
Expand Down Expand Up @@ -416,4 +433,185 @@ public void testSortingOnNestedData(String contextName, Map<String, Object> cont
))
.verifyResults();
}

@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testExactCountDistinctOnNestedData(String contextName, Map<String, Object> context)
{
Assumptions.assumeTrue(NullHandling.sqlCompatible());
RowSignature rowSignature = RowSignature.builder()
.add("distinct_obj", ColumnType.LONG)
.build();

Map<String, Object> modifiedContext = ImmutableMap.<String, Object>builder()
.putAll(context)
.put(PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false)
.build();

DimFilter innerFilter = NullHandling.replaceWithDefault()
? new SelectorDimFilter("d0", null, null)
: new NullFilter("d0", null);

testSelectQuery().setSql("SELECT\n"
+ " COUNT(DISTINCT obj) AS distinct_obj\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [" + dataFileNameJsonString + "],\"type\":\"local\"}',\n"
+ " '{\"type\": \"json\"}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"STRING\"}, {\"name\": \"obj\", \"type\": \"COMPLEX<json>\"}]'\n"
+ " )\n"
+ " )\n"
+ " ORDER BY 1")
.setQueryContext(ImmutableMap.of(PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false))
.setExpectedMSQSpec(
MSQSpec
.builder()
.query(
GroupByQuery
.builder()
.setDataSource(
new QueryDataSource(
GroupByQuery
.builder()
.setDataSource(dataFileExternalDataSource)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(
new DefaultDimensionSpec("obj", "d0", ColumnType.NESTED_DATA)
)
.setGranularity(Granularities.ALL)
.setContext(modifiedContext)
.build()
)
)
.setAggregatorSpecs(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
new NotDimFilter(innerFilter),
"a0"
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setGranularity(Granularities.ALL)
.setLimitSpec(new DefaultLimitSpec(
ImmutableList.of(
new OrderByColumnSpec(
"a0",
OrderByColumnSpec.Direction.ASCENDING,
StringComparators.NUMERIC
)
),
Integer.MAX_VALUE
))
.setContext(modifiedContext)
.build()
)
.columnMappings(new ColumnMappings(ImmutableList.of(
new ColumnMapping("a0", "distinct_obj")
)))
.tuningConfig(MSQTuningConfig.defaultConfig())
.destination(TaskReportMSQDestination.INSTANCE)
.build()
)
.setExpectedRowSignature(rowSignature)
.setQueryContext(modifiedContext)
.setExpectedResultRows(ImmutableList.of(
new Object[]{7L}
))
.verifyResults();
}

@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testExactCountDistinctOnNestedData2(String contextName, Map<String, Object> context)
{
Assumptions.assumeTrue(NullHandling.sqlCompatible());
RowSignature dataFileSignature = RowSignature.builder()
.add("timestamp", ColumnType.STRING)
.add("cObj", ColumnType.NESTED_DATA)
.build();
DataSource dataFileExternalDataSource2 = new ExternalDataSource(
new LocalInputSource(null, null, ImmutableList.of(dataFile), SystemFields.none()),
new JsonInputFormat(null, null, null, null, null),
dataFileSignature
);
RowSignature rowSignature = RowSignature.builder()
.add("distinct_obj", ColumnType.LONG)
.build();

Map<String, Object> modifiedContext = ImmutableMap.<String, Object>builder()
.putAll(context)
.put(PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false)
.build();

DimFilter innerFilter = NullHandling.replaceWithDefault()
? new SelectorDimFilter("d0", null, null)
: new NullFilter("d0", null);

testSelectQuery().setSql("SELECT\n"
+ " COUNT(DISTINCT cObj) AS distinct_obj\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [" + dataFileNameJsonString + "],\"type\":\"local\"}',\n"
+ " '{\"type\": \"json\"}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"STRING\"}, {\"name\": \"cObj\", \"type\": \"COMPLEX<json>\"}]'\n"
+ " )\n"
+ " )\n"
+ " ORDER BY 1")
.setQueryContext(ImmutableMap.of(PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false))
.setExpectedMSQSpec(
MSQSpec
.builder()
.query(
GroupByQuery
.builder()
.setDataSource(
new QueryDataSource(
GroupByQuery
.builder()
.setDataSource(dataFileExternalDataSource2)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(
new DefaultDimensionSpec("cObj", "d0", ColumnType.NESTED_DATA)
)
.setGranularity(Granularities.ALL)
.setContext(modifiedContext)
.build()
)
)
.setAggregatorSpecs(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
new NotDimFilter(innerFilter),
"a0"
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setGranularity(Granularities.ALL)
.setLimitSpec(new DefaultLimitSpec(
ImmutableList.of(
new OrderByColumnSpec(
"a0",
OrderByColumnSpec.Direction.ASCENDING,
StringComparators.NUMERIC
)
),
Integer.MAX_VALUE
))
.setContext(modifiedContext)
.build()
)
.columnMappings(new ColumnMappings(ImmutableList.of(
new ColumnMapping("a0", "distinct_obj")
)))
.tuningConfig(MSQTuningConfig.defaultConfig())
.destination(TaskReportMSQDestination.INSTANCE)
.build()
)
.setExpectedRowSignature(rowSignature)
.setQueryContext(modifiedContext)
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L}
))
.verifyResults();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,16 @@ private static class Selector<T> extends ObjectColumnSelector<T>
private final Memory memory;
private final ReadableFieldPointer fieldPointer;
private final ComplexMetricSerde serde;
@SuppressWarnings("rawtypes")
private final Class clazz;

private Selector(Memory memory, ReadableFieldPointer fieldPointer, ComplexMetricSerde serde)
{
this.memory = memory;
this.fieldPointer = fieldPointer;
this.serde = serde;
//noinspection deprecation
this.clazz = serde.getObjectStrategy().getClazz();
}

@Nullable
Expand All @@ -169,7 +173,8 @@ public T getObject()
@Override
public Class<T> classOfObject()
{
return serde.getExtractor().extractedClass();
//noinspection unchecked
return clazz;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ private static class ComplexFrameColumn extends ObjectColumnAccessorBase impleme
{
private final Frame frame;
private final ComplexMetricSerde serde;
private final Class<?> clazz;
private final Memory memory;
private final long startOfOffsetSection;
private final long startOfDataSection;
Expand All @@ -138,6 +139,8 @@ private ComplexFrameColumn(
{
this.frame = frame;
this.serde = serde;
//noinspection deprecation
this.clazz = serde.getObjectStrategy().getClazz();
this.memory = memory;
this.startOfOffsetSection = startOfOffsetSection;
this.startOfDataSection = startOfDataSection;
Expand All @@ -158,7 +161,7 @@ public Object getObject()
@Override
public Class<?> classOfObject()
{
return serde.getExtractor().extractedClass();
return clazz;
}

@Override
Expand Down

0 comments on commit c84e689

Please sign in to comment.