Skip to content

Commit

Permalink
Updating plans when using joins with unnest on the left (#15075)
Browse files Browse the repository at this point in the history
* Updating plans when using joins with unnest on the left

* Correcting segment map function for hashJoin

* The changes done here are not reflected into MSQ yet so these tests might not run in MSQ

* native tests

* Self joins with unnest data source

* Making this pass

* Addressing comments by adding explanation and new test
  • Loading branch information
somu-imply authored Oct 7, 2023
1 parent f943997 commit 57ab8e1
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 15 deletions.
64 changes: 50 additions & 14 deletions processing/src/main/java/org/apache/druid/query/JoinDataSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,25 @@ private Function<SegmentReference, SegmentReference> createSegmentMapFunctionInt
.orElse(null)
)
);

final Function<SegmentReference, SegmentReference> baseMapFn;
// A join data source is not concrete
// And isConcrete() of an unnest datasource delegates to its base
// Hence, in the case of a Join -> Unnest -> Join
// if we just use isConcrete on the left
// the segment map function for the unnest would never get called
// This requires us to delegate to the segmentMapFunction of the left
// only when it is not a JoinDataSource
if (left instanceof JoinDataSource) {
baseMapFn = Function.identity();
} else {
baseMapFn = left.createSegmentMapFunction(
query,
cpuTimeAccumulator
);
}
return baseSegment ->
new HashJoinSegment(
baseSegment,
baseMapFn.apply(baseSegment),
baseFilterToUse,
GuavaUtils.firstNonNull(clausesToUse, ImmutableList.of()),
joinFilterPreAnalysis
Expand All @@ -501,18 +516,39 @@ private static Triple<DataSource, DimFilter, List<PreJoinableClause>> flattenJoi
DimFilter currentDimFilter = null;
final List<PreJoinableClause> preJoinableClauses = new ArrayList<>();

while (current instanceof JoinDataSource) {
final JoinDataSource joinDataSource = (JoinDataSource) current;
current = joinDataSource.getLeft();
currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter());
preJoinableClauses.add(
new PreJoinableClause(
joinDataSource.getRightPrefix(),
joinDataSource.getRight(),
joinDataSource.getJoinType(),
joinDataSource.getConditionAnalysis()
)
);
// There can be queries like
// Join of Unnest of Join of Unnest of Filter
// so these checks need to be ORed
// to get the base
// This method is called to get the analysis for the join data source
// Since the analysis of an UnnestDS or FilteredDS always delegates to its base
// To obtain the base data source underneath a Join
// we also iterate through the base of the FilterDS and UnnestDS in its path
// the base of which can be a concrete data source
// This also means that an addition of a new datasource
// Will need an instanceof check here
// A future work should look into if the flattenJoin
// can be refactored to omit these instanceof checks
while (current instanceof JoinDataSource || current instanceof UnnestDataSource || current instanceof FilteredDataSource) {
if (current instanceof JoinDataSource) {
final JoinDataSource joinDataSource = (JoinDataSource) current;
current = joinDataSource.getLeft();
currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter());
preJoinableClauses.add(
new PreJoinableClause(
joinDataSource.getRightPrefix(),
joinDataSource.getRight(),
joinDataSource.getJoinType(),
joinDataSource.getConditionAnalysis()
)
);
} else if (current instanceof UnnestDataSource) {
final UnnestDataSource unnestDataSource = (UnnestDataSource) current;
current = unnestDataSource.getBase();
} else {
final FilteredDataSource filteredDataSource = (FilteredDataSource) current;
current = filteredDataSource.getBase();
}
}

// Join clauses were added in the order we saw them while traversing down, but we need to apply them in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.join.NoopJoinableFactory;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.easymock.Mock;
import org.junit.Assert;
import org.junit.Rule;
Expand Down Expand Up @@ -433,6 +436,51 @@ public void test_computeJoinDataSourceCacheKey_keyChangesWithPrefix()
Assert.assertFalse(Arrays.equals(cacheKey1, cacheKey2));
}

@Test
public void testGetAnalysisWithUnnestDS()
{
  // Left side of the join is an unnest wrapped around a plain table.
  // getAnalysis() is expected to drill through the UnnestDataSource and
  // report "table1" as the base data source.
  final UnnestDataSource leftSide = UnnestDataSource.create(
      new TableDataSource("table1"),
      new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
      null
  );
  final JoinDataSource joinDataSource = JoinDataSource.create(
      leftSide,
      new TableDataSource("table2"),
      "j.",
      "x == \"j.x\"",
      JoinType.LEFT,
      null,
      ExprMacroTable.nil(),
      null
  );

  final DataSourceAnalysis dataSourceAnalysis = joinDataSource.getAnalysis();
  Assert.assertEquals(
      "table1",
      dataSourceAnalysis.getBaseDataSource().getTableNames().iterator().next()
  );
}

@Test
public void testGetAnalysisWithFilteredDS()
{
  // Left side of the join is Unnest -> Filtered -> Table. getAnalysis()
  // must walk through both wrappers and report "table1" as the base.
  final FilteredDataSource filteredTable = FilteredDataSource.create(
      new TableDataSource("table1"),
      TrueDimFilter.instance()
  );
  final UnnestDataSource leftSide = UnnestDataSource.create(
      filteredTable,
      new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
      null
  );
  final JoinDataSource joinDataSource = JoinDataSource.create(
      leftSide,
      new TableDataSource("table2"),
      "j.",
      "x == \"j.x\"",
      JoinType.LEFT,
      null,
      ExprMacroTable.nil(),
      null
  );

  final DataSourceAnalysis dataSourceAnalysis = joinDataSource.getAnalysis();
  Assert.assertEquals(
      "table1",
      dataSourceAnalysis.getBaseDataSource().getTableNames().iterator().next()
  );
}

@Test
public void test_computeJoinDataSourceCacheKey_keyChangesWithBaseFilter()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static boolean isScanOrMapping(final DruidRel<?> druidRel, final boolean
*/
public static boolean isScanOrProject(final DruidRel<?> druidRel, final boolean canBeJoinOrUnion)
{
if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel
if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel || druidRel instanceof DruidCorrelateUnnestRel
|| druidRel instanceof DruidUnionDataSourceRel))) {
final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery();
final PartialDruidQuery.Stage stage = partialQuery.stage();
Expand Down
Loading

0 comments on commit 57ab8e1

Please sign in to comment.