Skip to content

Commit

Permalink
Enable rollup on multi-value dimensions for compaction with MSQ engine (
Browse files Browse the repository at this point in the history
apache#16937)

Currently compaction with the MSQ engine doesn't work for rollup on multi-value dimensions (MVDs), because grouping on an MVD unnests its values by default; for instance, grouping on `[s1,s2]` with aggregate `a` results in two rows: `<s1,a>` and `<s2,a>`. 

This change enables rollup on MVDs (without unnest) by converting MVDs to Arrays before rollup using virtual columns, and then converting them back to MVDs using post aggregators. If segment schema is available to the compaction task (when it ends up downloading segments to get existing dimensions/metrics/granularity), it selectively does the MVD-Array conversion only for known multi-valued columns; else it conservatively performs this conversion for all `string` columns.
  • Loading branch information
gargvishesh authored Sep 4, 2024
1 parent 76b8c20 commit e28424e
Show file tree
Hide file tree
Showing 11 changed files with 461 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,21 @@
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.expression.TimestampFloorExprMacro;
import org.apache.druid.query.expression.TimestampParseExprMacro;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupByQueryConfig;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.indexing.CombinedDataSchema;
import org.apache.druid.segment.indexing.DataSchema;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.server.coordinator.CompactionConfigValidationResult;
Expand All @@ -82,6 +84,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -92,12 +95,14 @@ public class MSQCompactionRunner implements CompactionRunner
private static final Granularity DEFAULT_SEGMENT_GRANULARITY = Granularities.ALL;

private final ObjectMapper jsonMapper;
private final ExprMacroTable exprMacroTable;
private final Injector injector;
// Needed as output column name while grouping in the scenario of:
// a) no query granularity -- to specify an output name for the time dimension column since __time is a reserved name.
// b) custom query granularity -- to create a virtual column containing the rounded-off row timestamp.
// In both cases, the new column is converted back to __time later using columnMappings.
public static final String TIME_VIRTUAL_COLUMN = "__vTime";
public static final String ARRAY_VIRTUAL_COLUMN_PREFIX = "__vArray_";

@JsonIgnore
private final CurrentSubTaskHolder currentSubTaskHolder = new CurrentSubTaskHolder(
Expand All @@ -108,9 +113,14 @@ public class MSQCompactionRunner implements CompactionRunner


@JsonCreator
public MSQCompactionRunner(@JacksonInject ObjectMapper jsonMapper, @JacksonInject Injector injector)
public MSQCompactionRunner(
@JacksonInject final ObjectMapper jsonMapper,
@JacksonInject final ExprMacroTable exprMacroTable,
@JacksonInject final Injector injector
)
{
this.jsonMapper = jsonMapper;
this.exprMacroTable = exprMacroTable;
this.injector = injector;
}

Expand Down Expand Up @@ -192,11 +202,12 @@ public List<MSQControllerTask> createMsqControllerTasks(
Query<?> query;
Interval interval = intervalDataSchema.getKey();
DataSchema dataSchema = intervalDataSchema.getValue();
Map<String, VirtualColumn> inputColToVirtualCol = getVirtualColumns(dataSchema, interval);

if (isGroupBy(dataSchema)) {
query = buildGroupByQuery(compactionTask, interval, dataSchema);
query = buildGroupByQuery(compactionTask, interval, dataSchema, inputColToVirtualCol);
} else {
query = buildScanQuery(compactionTask, interval, dataSchema);
query = buildScanQuery(compactionTask, interval, dataSchema, inputColToVirtualCol);
}
QueryContext compactionTaskContext = new QueryContext(compactionTask.getContext());

Expand Down Expand Up @@ -308,7 +319,10 @@ private static RowSignature getRowSignature(DataSchema dataSchema)
return rowSignatureBuilder.build();
}

private static List<DimensionSpec> getAggregateDimensions(DataSchema dataSchema)
private static List<DimensionSpec> getAggregateDimensions(
DataSchema dataSchema,
Map<String, VirtualColumn> inputColToVirtualCol
)
{
List<DimensionSpec> dimensionSpecs = new ArrayList<>();

Expand All @@ -319,14 +333,22 @@ private static List<DimensionSpec> getAggregateDimensions(DataSchema dataSchema)
// The changed granularity would result in a new virtual column that needs to be aggregated upon.
dimensionSpecs.add(new DefaultDimensionSpec(TIME_VIRTUAL_COLUMN, TIME_VIRTUAL_COLUMN, ColumnType.LONG));
}

dimensionSpecs.addAll(dataSchema.getDimensionsSpec().getDimensions().stream()
.map(dim -> new DefaultDimensionSpec(
dim.getName(),
dim.getName(),
dim.getColumnType()
))
.collect(Collectors.toList()));
// If virtual columns are created from dimensions, replace dimension columns names with virtual column names.
dimensionSpecs.addAll(
dataSchema.getDimensionsSpec().getDimensions().stream()
.map(dim -> {
String dimension = dim.getName();
ColumnType colType = dim.getColumnType();
if (inputColToVirtualCol.containsKey(dim.getName())) {
VirtualColumn virtualColumn = inputColToVirtualCol.get(dimension);
dimension = virtualColumn.getOutputName();
if (virtualColumn instanceof ExpressionVirtualColumn) {
colType = ((ExpressionVirtualColumn) virtualColumn).getOutputType();
}
}
return new DefaultDimensionSpec(dimension, dimension, colType);
})
.collect(Collectors.toList()));
return dimensionSpecs;
}

Expand Down Expand Up @@ -365,13 +387,19 @@ private static List<OrderByColumnSpec> getOrderBySpec(PartitionsSpec partitionSp
return Collections.emptyList();
}

private static Query<?> buildScanQuery(CompactionTask compactionTask, Interval interval, DataSchema dataSchema)
private static Query<?> buildScanQuery(
CompactionTask compactionTask,
Interval interval,
DataSchema dataSchema,
Map<String, VirtualColumn> inputColToVirtualCol
)
{
RowSignature rowSignature = getRowSignature(dataSchema);
VirtualColumns virtualColumns = VirtualColumns.create(new ArrayList<>(inputColToVirtualCol.values()));
Druids.ScanQueryBuilder scanQueryBuilder = new Druids.ScanQueryBuilder()
.dataSource(dataSchema.getDataSource())
.columns(rowSignature.getColumnNames())
.virtualColumns(getVirtualColumns(dataSchema, interval))
.virtualColumns(virtualColumns)
.columnTypes(rowSignature.getColumnTypes())
.intervals(new MultipleIntervalSegmentSpec(Collections.singletonList(interval)))
.filters(dataSchema.getTransformSpec().getFilter())
Expand Down Expand Up @@ -416,51 +444,115 @@ private static boolean isQueryGranularityEmptyOrNone(DataSchema dataSchema)
}

/**
* Creates a virtual timestamp column to create a new __time field according to the provided queryGranularity, as
* queryGranularity field itself is mandated to be ALL in MSQControllerTask.
* Conditionally creates below virtual columns
* <ul>
* <li>timestamp column (for custom queryGranularity): converts __time field in line with the provided
* queryGranularity, since the queryGranularity field itself in MSQControllerTask is mandated to be ALL.</li>
* <li>mv_to_array columns (for group-by queries): temporary columns that convert MVD columns to array to enable
* grouping on them without unnesting.</li>
* </ul>
*/
private static VirtualColumns getVirtualColumns(DataSchema dataSchema, Interval interval)
private Map<String, VirtualColumn> getVirtualColumns(DataSchema dataSchema, Interval interval)
{
if (isQueryGranularityEmptyOrNone(dataSchema)) {
return VirtualColumns.EMPTY;
Map<String, VirtualColumn> inputColToVirtualCol = new HashMap<>();
if (!isQueryGranularityEmptyOrNone(dataSchema)) {
// Round-off time field according to provided queryGranularity
String timeVirtualColumnExpr;
if (dataSchema.getGranularitySpec()
.getQueryGranularity()
.equals(Granularities.ALL)) {
// For ALL query granularity, all records in a segment are assigned the interval start timestamp of the segment.
// It's the same behaviour in native compaction.
timeVirtualColumnExpr = StringUtils.format("timestamp_parse('%s')", interval.getStart());
} else {
PeriodGranularity periodQueryGranularity = (PeriodGranularity) dataSchema.getGranularitySpec()
.getQueryGranularity();
// Round off the __time column according to the required granularity.
timeVirtualColumnExpr =
StringUtils.format(
"timestamp_floor(\"%s\", '%s')",
ColumnHolder.TIME_COLUMN_NAME,
periodQueryGranularity.getPeriod().toString()
);
}
inputColToVirtualCol.put(ColumnHolder.TIME_COLUMN_NAME, new ExpressionVirtualColumn(
TIME_VIRTUAL_COLUMN,
timeVirtualColumnExpr,
ColumnType.LONG,
exprMacroTable
));
}
String virtualColumnExpr;
if (dataSchema.getGranularitySpec()
.getQueryGranularity()
.equals(Granularities.ALL)) {
// For ALL query granularity, all records in a segment are assigned the interval start timestamp of the segment.
// It's the same behaviour in native compaction.
virtualColumnExpr = StringUtils.format("timestamp_parse('%s')", interval.getStart());
} else {
PeriodGranularity periodQueryGranularity = (PeriodGranularity) dataSchema.getGranularitySpec()
.getQueryGranularity();
// Round off the __time column according to the required granularity.
virtualColumnExpr =
StringUtils.format(
"timestamp_floor(\"%s\", '%s')",
ColumnHolder.TIME_COLUMN_NAME,
periodQueryGranularity.getPeriod().toString()
);
if (isGroupBy(dataSchema)) {
// Convert MVDs to arrays for grouping to avoid unnest, assuming all string cols to be MVDs.
Set<String> multiValuedColumns = dataSchema.getDimensionsSpec()
.getDimensions()
.stream()
.filter(dim -> dim.getColumnType().equals(ColumnType.STRING))
.map(DimensionSchema::getName)
.collect(Collectors.toSet());
if (dataSchema instanceof CombinedDataSchema &&
((CombinedDataSchema) dataSchema).getMultiValuedDimensions() != null) {
// Filter actual MVDs from schema info.
Set<String> multiValuedColumnsFromSchema =
((CombinedDataSchema) dataSchema).getMultiValuedDimensions();
multiValuedColumns = multiValuedColumns.stream()
.filter(multiValuedColumnsFromSchema::contains)
.collect(Collectors.toSet());
}

for (String dim : multiValuedColumns) {
String virtualColumnExpr = StringUtils.format("mv_to_array(\"%s\")", dim);
inputColToVirtualCol.put(
dim,
new ExpressionVirtualColumn(
ARRAY_VIRTUAL_COLUMN_PREFIX + dim,
virtualColumnExpr,
ColumnType.STRING_ARRAY,
exprMacroTable
)
);
}
}
return VirtualColumns.create(new ExpressionVirtualColumn(
TIME_VIRTUAL_COLUMN,
virtualColumnExpr,
ColumnType.LONG,
new ExprMacroTable(ImmutableList.of(new TimestampFloorExprMacro(), new TimestampParseExprMacro()))
));
return inputColToVirtualCol;
}

private static Query<?> buildGroupByQuery(CompactionTask compactionTask, Interval interval, DataSchema dataSchema)
private Query<?> buildGroupByQuery(
CompactionTask compactionTask,
Interval interval,
DataSchema dataSchema,
Map<String, VirtualColumn> inputColToVirtualCol
)
{
DimFilter dimFilter = dataSchema.getTransformSpec().getFilter();

VirtualColumns virtualColumns = VirtualColumns.create(new ArrayList<>(inputColToVirtualCol.values()));

// Convert MVDs converted to arrays back to MVDs, with the same name as the input column.
// This is safe since input column names no longer exist at the post-aggregation stage.
List<PostAggregator> postAggregators =
inputColToVirtualCol.entrySet()
.stream()
.filter(entry -> !entry.getKey().equals(ColumnHolder.TIME_COLUMN_NAME))
.map(
entry ->
new ExpressionPostAggregator(
entry.getKey(),
StringUtils.format("array_to_mv(\"%s\")", entry.getValue().getOutputName()),
null,
ColumnType.STRING,
exprMacroTable
)
)
.collect(Collectors.toList());

GroupByQuery.Builder builder = new GroupByQuery.Builder()
.setDataSource(new TableDataSource(compactionTask.getDataSource()))
.setVirtualColumns(getVirtualColumns(dataSchema, interval))
.setVirtualColumns(virtualColumns)
.setDimFilter(dimFilter)
.setGranularity(new AllGranularity())
.setDimensions(getAggregateDimensions(dataSchema))
.setDimensions(getAggregateDimensions(dataSchema, inputColToVirtualCol))
.setAggregatorSpecs(Arrays.asList(dataSchema.getAggregators()))
.setPostAggregatorSpecs(postAggregators)
.setContext(compactionTask.getContext())
.setInterval(interval);

Expand Down
Loading

0 comments on commit e28424e

Please sign in to comment.