Skip to content

Commit

Permalink
Join context hints (apache#17541)
Browse files Browse the repository at this point in the history
* join hints draft

* join algo

* propagate join hints

* review comments

* Use direct hints instead

* Add tests

* Pass preferred algo through pre join clause

* Refactors

* Fix tests

* Revert test changes

* Fix serialization

* Fix tests

* Fix test

* Fix test

* Fix test for sql compat mode

* Increase coverage

* Refactored hint class

---------

Co-authored-by: sreemanamala <[email protected]>
  • Loading branch information
adarshsanjeev and sreemanamala authored Dec 17, 2024
1 parent de9da37 commit bb4416a
Show file tree
Hide file tree
Showing 38 changed files with 2,133 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
Expand Down Expand Up @@ -390,7 +391,8 @@ public void testSubqueryWithNestedGroupBy()
JoinType.INNER,
null,
TestExprMacroTable.INSTANCE,
null
null,
JoinAlgorithm.BROADCAST
)
)
.intervals(querySegmentSpec(Intervals.ETERNITY))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;

import javax.annotation.Nullable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.Query;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.druid.query.DataSource;
import org.apache.druid.query.FilteredDataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.QueryContext;
Expand All @@ -65,8 +66,6 @@
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.joda.time.Interval;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -212,18 +211,19 @@ public static DataSourcePlan forDataSource(
broadcast
);
} else if (dataSource instanceof JoinDataSource) {
final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
JoinDataSource joinDataSource = (JoinDataSource) dataSource;
final JoinAlgorithm preferredJoinAlgorithm = joinDataSource.getJoinAlgorithm();
final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm(
preferredJoinAlgorithm,
((JoinDataSource) dataSource)
joinDataSource
);

switch (deducedJoinAlgorithm) {
case BROADCAST:
return forBroadcastHashJoin(
queryKitSpec,
queryContext,
(JoinDataSource) dataSource,
joinDataSource,
querySegmentSpec,
filter,
filterFields,
Expand All @@ -234,7 +234,7 @@ public static DataSourcePlan forDataSource(
case SORT_MERGE:
return forSortMergeJoin(
queryKitSpec,
(JoinDataSource) dataSource,
joinDataSource,
querySegmentSpec,
minStageNumber,
broadcast
Expand Down Expand Up @@ -615,7 +615,8 @@ private static DataSourcePlan forBroadcastHashJoin(
clause.getJoinType(),
// First JoinDataSource (i == 0) involves the base table, so we need to propagate the base table filter.
i == 0 ? analysis.getJoinBaseTableFilter().orElse(null) : null,
dataSource.getJoinableFactoryWrapper()
dataSource.getJoinableFactoryWrapper(),
clause.getJoinAlgorithm()
);
inputSpecs.addAll(clausePlan.getInputSpecs());
clausePlan.getBroadcastInputs().intStream().forEach(n -> broadcastInputs.add(n + shift));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.OrderBy;
import org.apache.druid.query.Query;
Expand Down Expand Up @@ -81,7 +82,6 @@
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.ColumnMapping;
import org.apache.druid.sql.calcite.planner.ColumnMappings;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.hamcrest.CoreMatchers;
Expand Down Expand Up @@ -1043,7 +1043,9 @@ private void testJoin(String contextName, Map<String, Object> context, final Joi
DruidExpression.ofColumn(ColumnType.FLOAT, "m1"),
DruidExpression.ofColumn(ColumnType.FLOAT, "j0.m1")
),
JoinType.INNER
JoinType.INNER,
null,
joinAlgorithm
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
Expand Down Expand Up @@ -2523,7 +2525,9 @@ public void testJoinUsesDifferentAlgorithm(String contextName, Map<String, Objec
),
"j0.",
"(CAST(floor(100), 'DOUBLE') == \"j0.d0\")",
JoinType.LEFT
JoinType.LEFT,
null,
JoinAlgorithm.SORT_MERGE
)
)
.setAggregatorSpecs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
Expand All @@ -50,7 +51,6 @@
import org.apache.druid.segment.TestIndex;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.easymock.EasyMock;
Expand Down Expand Up @@ -190,6 +190,7 @@ public void testBuildTableAndInlineData() throws IOException
JoinConditionAnalysis.forExpression("x == \"j.x\"", "j.", ExprMacroTable.nil()),
JoinType.INNER,
null,
null,
null
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

import com.google.common.collect.ImmutableMap;
import org.apache.druid.msq.sql.MSQTaskSqlEngine;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.CalciteJoinQueryTest;
import org.apache.druid.sql.calcite.QueryTestBuilder;
import org.apache.druid.sql.calcite.SqlTestFrameworkConfig;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import java.util.Map;

Expand Down
Loading

0 comments on commit bb4416a

Please sign in to comment.