Skip to content

Commit

Permalink
SortMerge join support for IS NOT DISTINCT FROM. (apache#16003)
Browse files Browse the repository at this point in the history
* SortMerge join support for IS NOT DISTINCT FROM.

The patch adds a "requiredNonNullKeyParts" field to the sortMerge
processor, which has the list of key parts that must be nonnull for
an equijoin condition to match. Conditions with SQL "=" are present in
the list; conditions with SQL "IS NOT DISTINCT FROM" are absent from
the list.

* Fix test.

* Update javadoc.
  • Loading branch information
gianm authored Mar 19, 2024
1 parent fa8e511 commit c96b215
Show file tree
Hide file tree
Showing 10 changed files with 576 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,15 @@ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgo
/**
* Checks if the sortMerge algorithm can execute a particular join condition.
*
* Two checks:
* (1) join condition on two tables "table1" and "table2" is of the form
* One check: join condition on two tables "table1" and "table2" is of the form
* table1.columnA = table2.columnA && table1.columnB = table2.columnB && ....
*
* (2) join condition uses equals, not IS NOT DISTINCT FROM [sortMerge processor does not currently implement
* IS NOT DISTINCT FROM]
*/
private static boolean canUseSortMergeJoin(JoinConditionAnalysis joinConditionAnalysis)
{
return joinConditionAnalysis
.getEquiConditions()
.stream()
.allMatch(equality -> equality.getLeftExpr().isIdentifier() && !equality.isIncludeNull());
.allMatch(equality -> equality.getLeftExpr().isIdentifier());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor<Object>
FrameWriterFactory frameWriterFactory,
String rightPrefix,
List<List<KeyColumn>> keyColumns,
int[] requiredNonNullKeyParts,
JoinType joinType,
long maxBufferedBytes
)
Expand All @@ -148,8 +149,8 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor<Object>
this.rightPrefix = rightPrefix;
this.joinType = joinType;
this.trackers = ImmutableList.of(
new Tracker(left, keyColumns.get(LEFT), maxBufferedBytes),
new Tracker(right, keyColumns.get(RIGHT), maxBufferedBytes)
new Tracker(left, keyColumns.get(LEFT), requiredNonNullKeyParts, maxBufferedBytes),
new Tracker(right, keyColumns.get(RIGHT), requiredNonNullKeyParts, maxBufferedBytes)
);
this.maxBufferedBytes = maxBufferedBytes;
}
Expand Down Expand Up @@ -195,7 +196,7 @@ public ReturnOrAwait<Object> runIncrementally(IntSet readableInputs) throws IOEx

// Two rows match if the keys compare equal _and_ every key part listed in requiredNonNullKeyParts is nonnull.
// (x JOIN y ON x.a = y.a does not match rows where "x.a" is null; IS NOT DISTINCT FROM conditions may match nulls.)
final boolean marksMatch = markCmp == 0 && trackers.get(LEFT).hasCompletelyNonNullMark();
final boolean marksMatch = markCmp == 0 && trackers.get(LEFT).markHasRequiredNonNullKeyParts();

// If marked keys are equal on both sides ("marksMatch"), at least one side needs to have a complete set of rows
// for the marked key. Check if this is true, otherwise call nextAwait to read more data.
Expand Down Expand Up @@ -446,7 +447,7 @@ private boolean allTrackersAreAtEnd()
/**
* Compares the marked rows of the two {@link #trackers}. This method returns 0 if both sides are null, even
* though this is not considered a match by join semantics. Therefore, it is important to also check
* {@link Tracker#hasCompletelyNonNullMark()}.
* {@link Tracker#markHasRequiredNonNullKeyParts()}.
*
* @return negative if {@link #LEFT} key is earlier, positive if {@link #RIGHT} key is earlier, zero if the keys
* are the same. Returns zero even if a key component is null, even though this is not considered a match by
Expand Down Expand Up @@ -549,6 +550,7 @@ private static class Tracker
private final List<FrameHolder> holders = new ArrayList<>();
private final ReadableInput input;
private final List<KeyColumn> keyColumns;
private final int[] requiredNonNullKeyParts;
private final long maxBytesBuffered;

// markFrame and markRow are the first frame and row with the current key.
Expand All @@ -561,10 +563,16 @@ private static class Tracker
// done indicates that no more data is available in the channel.
private boolean done;

public Tracker(ReadableInput input, List<KeyColumn> keyColumns, long maxBytesBuffered)
public Tracker(
final ReadableInput input,
final List<KeyColumn> keyColumns,
final int[] requiredNonNullKeyParts,
final long maxBytesBuffered
)
{
this.input = input;
this.keyColumns = keyColumns;
this.requiredNonNullKeyParts = requiredNonNullKeyParts;
this.maxBytesBuffered = maxBytesBuffered;
}

Expand Down Expand Up @@ -686,9 +694,9 @@ public boolean hasMark()
/**
* Whether this tracker has a marked row whose key parts listed in {@link #requiredNonNullKeyParts} are all nonnull.
*/
public boolean hasCompletelyNonNullMark()
public boolean markHasRequiredNonNullKeyParts()
{
return hasMark() && holders.get(markFrame).comparisonWidget.isCompletelyNonNullKey(markRow);
return hasMark() && holders.get(markFrame).comparisonWidget.hasNonNullKeyParts(markRow, requiredNonNullKeyParts);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.processor.FrameProcessor;
Expand Down Expand Up @@ -142,6 +144,7 @@ public ProcessorsAndChannels<Object, Long> makeProcessors(

// Compute key columns.
final List<List<KeyColumn>> keyColumns = toKeyColumns(condition);
final int[] requiredNonNullKeyParts = toRequiredNonNullKeyParts(condition);

// Stitch up the inputs and validate each input channel signature.
// If validateInputFrameSignatures fails, it's a precondition violation: this class somehow got bad inputs.
Expand Down Expand Up @@ -180,6 +183,7 @@ public ProcessorsAndChannels<Object, Long> makeProcessors(
stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()),
rightPrefix,
keyColumns,
requiredNonNullKeyParts,
joinType,
frameContext.memoryParameters().getSortMergeJoinMemory()
);
Expand Down Expand Up @@ -217,6 +221,27 @@ public static List<List<KeyColumn>> toKeyColumns(final JoinConditionAnalysis con
return retVal;
}

/**
 * Returns the positions, within the equi-condition list of a {@link JoinConditionAnalysis}, of the key parts that
 * must be nonnull. A position is included exactly when {@link Equality#isIncludeNull()} is false for the
 * equality at that position; positions are emitted in increasing order.
 *
 * The condition must have been validated by {@link #validateCondition(JoinConditionAnalysis)}.
 */
public static int[] toRequiredNonNullKeyParts(final JoinConditionAnalysis condition)
{
  final List<Equality> equalities = condition.getEquiConditions();
  final IntList nonNullParts = new IntArrayList(equalities.size());

  // Walk the equalities in order, tracking the index of each one so it can be recorded as a key part number.
  int position = 0;
  for (final Equality equality : equalities) {
    if (!equality.isIncludeNull()) {
      nonNullParts.add(position);
    }
    position++;
  }

  return nonNullParts.toIntArray();
}

/**
* Validates that a join condition can be handled by this processor. Returns the condition if it can be handled.
* Throws {@link IllegalArgumentException} if the condition cannot be handled.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.msq.querykit.common;

import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for the static join-condition helpers on {@link SortMergeJoinFrameProcessorFactory}:
 * {@code validateCondition}, {@code toKeyColumns} and {@code toRequiredNonNullKeyParts}.
 *
 * Every condition is written as an expression string in which right-hand-side columns carry the
 * {@link #RIGHT_PREFIX} prefix.
 */
public class SortMergeJoinFrameProcessorFactoryTest
{
  /**
   * Prefix passed to {@link JoinConditionAnalysis#forExpression} for right-hand-side columns.
   */
  private static final String RIGHT_PREFIX = "j.";

  @Test
  public void test_validateCondition()
  {
    // Conditions expected to be accepted: validateCondition returns a non-null value for each.
    Assert.assertNotNull(
        SortMergeJoinFrameProcessorFactory.validateCondition(analyze("1"))
    );

    Assert.assertNotNull(
        SortMergeJoinFrameProcessorFactory.validateCondition(analyze("x == \"j.y\""))
    );

    Assert.assertNotNull(
        SortMergeJoinFrameProcessorFactory.validateCondition(analyze("x == \"j.y\" && a == \"j.b\""))
    );

    Assert.assertNotNull(
        SortMergeJoinFrameProcessorFactory.validateCondition(analyze("notdistinctfrom(x, \"j.y\") && a == \"j.b\""))
    );

    // Conditions expected to be rejected. In the first, neither operand carries the "j." prefix.
    Assert.assertThrows(
        IllegalArgumentException.class,
        () -> SortMergeJoinFrameProcessorFactory.validateCondition(analyze("x == y"))
    );

    // Here the left operand is an arithmetic expression rather than a bare column.
    Assert.assertThrows(
        IllegalArgumentException.class,
        () -> SortMergeJoinFrameProcessorFactory.validateCondition(analyze("x + 1 == \"j.y\""))
    );
  }

  @Test
  public void test_toKeyColumns()
  {
    // Single equality: one key column per side.
    Assert.assertEquals(
        ImmutableList.of(
            ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)),
            ImmutableList.of(new KeyColumn("y", KeyOrder.ASCENDING))
        ),
        SortMergeJoinFrameProcessorFactory.toKeyColumns(analyze("x == \"j.y\""))
    );

    // Constant condition: no key columns on either side.
    Assert.assertEquals(
        ImmutableList.of(
            ImmutableList.of(),
            ImmutableList.of()
        ),
        SortMergeJoinFrameProcessorFactory.toKeyColumns(analyze("1"))
    );

    // Two-part keys: the expected key columns are the same whichever mix of "==" and notdistinctfrom is used.
    final String[] twoPartConditions = {
        "x == \"j.y\" && a == \"j.b\"",
        "x == \"j.y\" && notdistinctfrom(a, \"j.b\")",
        "notdistinctfrom(x, \"j.y\") && a == \"j.b\"",
        "notdistinctfrom(x, \"j.y\") && notdistinctfrom(a, \"j.b\")"
    };

    for (final String expression : twoPartConditions) {
      Assert.assertEquals(
          expression,
          ImmutableList.of(
              ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING), new KeyColumn("a", KeyOrder.ASCENDING)),
              ImmutableList.of(new KeyColumn("y", KeyOrder.ASCENDING), new KeyColumn("b", KeyOrder.ASCENDING))
          ),
          SortMergeJoinFrameProcessorFactory.toKeyColumns(analyze(expression))
      );
    }
  }

  @Test
  public void test_toRequiredNonNullKeyParts()
  {
    // Constant condition: nothing is required to be nonnull.
    Assert.assertArrayEquals(
        new int[0],
        SortMergeJoinFrameProcessorFactory.toRequiredNonNullKeyParts(analyze("1"))
    );

    Assert.assertArrayEquals(
        new int[]{0},
        SortMergeJoinFrameProcessorFactory.toRequiredNonNullKeyParts(analyze("x == \"j.y\""))
    );

    Assert.assertArrayEquals(
        new int[]{0, 1},
        SortMergeJoinFrameProcessorFactory.toRequiredNonNullKeyParts(analyze("x == \"j.y\" && a == \"j.b\""))
    );

    // Only the "==" part is expected in the result; the notdistinctfrom part is omitted.
    Assert.assertArrayEquals(
        new int[]{0},
        SortMergeJoinFrameProcessorFactory.toRequiredNonNullKeyParts(
            analyze("x == \"j.y\" && notdistinctfrom(a, \"j.b\")")
        )
    );

    Assert.assertArrayEquals(
        new int[]{1},
        SortMergeJoinFrameProcessorFactory.toRequiredNonNullKeyParts(
            analyze("notdistinctfrom(x, \"j.y\") && a == \"j.b\"")
        )
    );

    Assert.assertArrayEquals(
        new int[0],
        SortMergeJoinFrameProcessorFactory.toRequiredNonNullKeyParts(
            analyze("notdistinctfrom(x, \"j.y\") && notdistinctfrom(a, \"j.b\")")
        )
    );
  }

  /**
   * Builds a {@link JoinConditionAnalysis} for the given expression, using {@link #RIGHT_PREFIX} as the
   * right-hand prefix and an empty macro table.
   *
   * @param expression join condition expression string
   * @return the analysis returned by {@link JoinConditionAnalysis#forExpression}
   */
  private static JoinConditionAnalysis analyze(final String expression)
  {
    return JoinConditionAnalysis.forExpression(expression, RIGHT_PREFIX, ExprMacroTable.nil());
  }
}
Loading

0 comments on commit c96b215

Please sign in to comment.