From 990ef0f87708c8e3e338b8f0148b0d6d7b6f18c9 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 28 Feb 2024 08:51:00 -0600 Subject: [PATCH] JNI bindings for distinct_hash_join (#15019) Adds Java bindings to the distinct hash join functionality added in #14990. Authors: - Jason Lowe (https://github.com/jlowe) Approvers: - Jim Brennan (https://github.com/jbrennan333) - Nghia Truong (https://github.com/ttnghia) --- java/src/main/java/ai/rapids/cudf/Table.java | 105 +++++++++++++++-- java/src/main/native/src/TableJni.cpp | 28 ++++- .../test/java/ai/rapids/cudf/TableTest.java | 111 +++++++++++++++++- 3 files changed, 231 insertions(+), 13 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 1356c93c64d..c562e08b4c8 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -626,6 +626,9 @@ private static native long[] leftHashJoinGatherMapsWithCount(long leftTable, lon private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long[] innerDistinctJoinGatherMaps(long leftKeys, long rightKeys, + boolean compareNullsEqual) throws CudfException; + private static native long innerJoinRowCount(long table, long hashJoin) throws CudfException; private static native long[] innerHashJoinGatherMaps(long table, long hashJoin) throws CudfException; @@ -2920,7 +2923,9 @@ private static GatherMap[] buildJoinGatherMaps(long[] gatherMapData) { * the table argument represents the key columns from the right table. Two {@link GatherMap} * instances will be returned that can be used to gather the left and right tables, * respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightKeys join key columns from the right table * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps @@ -2956,7 +2961,9 @@ public long leftJoinRowCount(HashJoin rightHash) { * the {@link HashJoin} argument has been constructed from the key columns from the right table. * Two {@link GatherMap} instances will be returned that can be used to gather the left and right * tables, respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightHash hash table built from join key columns from the right table * @return left and right table gather maps */ @@ -2975,11 +2982,15 @@ public GatherMap[] leftJoinGatherMaps(HashJoin rightHash) { * the {@link HashJoin} argument has been constructed from the key columns from the right table. * Two {@link GatherMap} instances will be returned that can be used to gather the left and right * tables, respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing an output row count that was previously computed from * {@link #leftJoinRowCount(HashJoin)}. + * * WARNING: Passing a row count that is smaller than the actual row count will result * in undefined behavior. + * * @param rightHash hash table built from join key columns from the right table * @param outputRowCount number of output rows in the join result * @return left and right table gather maps @@ -3013,7 +3024,9 @@ public long conditionalLeftJoinRowCount(Table rightTable, CompiledExpression con * the columns from the left table, and the table argument represents the columns from the * right table. Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join * @return left and right table gather maps @@ -3032,11 +3045,15 @@ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, * the columns from the left table, and the table argument represents the columns from the * right table. Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing an output row count that was previously computed from * {@link #conditionalLeftJoinRowCount(Table, CompiledExpression)}. + * * WARNING: Passing a row count that is smaller than the actual row count will result * in undefined behavior. + * * @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join * @param outputRowCount number of output rows in the join result @@ -3085,7 +3102,9 @@ public static MixedJoinSize mixedLeftJoinSize(Table leftKeys, Table rightKeys, * assumed to be a logical AND of the equality condition and inequality condition. * Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3112,10 +3131,13 @@ public static GatherMap[] mixedLeftJoinGatherMaps(Table leftKeys, Table rightKey * assumed to be a logical AND of the equality condition and inequality condition. * Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the left join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing the size result from * {@link #mixedLeftJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} * when the output size was computed previously. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3145,14 +3167,16 @@ public static GatherMap[] mixedLeftJoinGatherMaps(Table leftKeys, Table rightKey * the table argument represents the key columns from the right table. Two {@link GatherMap} * instances will be returned that can be used to gather the left and right tables, * respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightKeys join key columns from the right table * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqual) { if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightKeys.getNumberOfColumns()); } long[] gatherMapData = @@ -3160,6 +3184,30 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of an inner equi-join between + * two tables where the right table is guaranteed to not contain any duplicated join keys. It is + * assumed this table instance holds the key columns from the left table, and the table argument + * represents the key columns from the right table. Two {@link GatherMap} instances will be + * returned that can be used to gather the left and right tables, respectively, to produce the + * result of the inner join. + * + * It is the responsibility of the caller to close the resulting gather map instances. + * + * @param rightKeys join key columns from the right table + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] innerDistinctJoinGatherMaps(Table rightKeys, boolean compareNullsEqual) { + if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightKeys.getNumberOfColumns()); + } + long[] gatherMapData = + innerDistinctJoinGatherMaps(getNativeView(), rightKeys.getNativeView(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the number of rows resulting from an inner equi-join between two tables. * @param otherHash hash table built from join key columns from the other table @@ -3167,7 +3215,7 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua */ public long innerJoinRowCount(HashJoin otherHash) { if (getNumberOfColumns() != otherHash.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "otherKeys: " + otherHash.getNumberOfColumns()); } return innerJoinRowCount(getNativeView(), otherHash.getNativeView()); @@ -3179,13 +3227,15 @@ public long innerJoinRowCount(HashJoin otherHash) { * the {@link HashJoin} argument has been constructed from the key columns from the right table. * Two {@link GatherMap} instances will be returned that can be used to gather the left and right * tables, respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightHash hash table built from join key columns from the right table * @return left and right table gather maps */ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash) { if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } long[] gatherMapData = innerHashJoinGatherMaps(getNativeView(), rightHash.getNativeView()); @@ -3198,18 +3248,22 @@ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash) { * the {@link HashJoin} argument has been constructed from the key columns from the right table. * Two {@link GatherMap} instances will be returned that can be used to gather the left and right * tables, respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing an output row count that was previously computed from * {@link #innerJoinRowCount(HashJoin)}. + * * WARNING: Passing a row count that is smaller than the actual row count will result * in undefined behavior. + * * @param rightHash hash table built from join key columns from the right table * @param outputRowCount number of output rows in the join result * @return left and right table gather maps */ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash, long outputRowCount) { if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } long[] gatherMapData = innerHashJoinGatherMapsWithCount(getNativeView(), @@ -3237,7 +3291,9 @@ public long conditionalInnerJoinRowCount(Table rightTable, * the columns from the left table, and the table argument represents the columns from the * right table. Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join * @return left and right table gather maps @@ -3256,11 +3312,15 @@ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, * the columns from the left table, and the table argument represents the columns from the * right table. Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing an output row count that was previously computed from * {@link #conditionalInnerJoinRowCount(Table, CompiledExpression)}. + * * WARNING: Passing a row count that is smaller than the actual row count will result * in undefined behavior. + * * @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join * @param outputRowCount number of output rows in the join result @@ -3309,7 +3369,9 @@ public static MixedJoinSize mixedInnerJoinSize(Table leftKeys, Table rightKeys, * assumed to be a logical AND of the equality condition and inequality condition. * Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3336,10 +3398,13 @@ public static GatherMap[] mixedInnerJoinGatherMaps(Table leftKeys, Table rightKe * assumed to be a logical AND of the equality condition and inequality condition. * Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the inner join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing the size result from * {@link #mixedInnerJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} * when the output size was computed previously. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3369,14 +3434,16 @@ public static GatherMap[] mixedInnerJoinGatherMaps(Table leftKeys, Table rightKe * the table argument represents the key columns from the right table. Two {@link GatherMap} * instances will be returned that can be used to gather the left and right tables, * respectively, to produce the result of the full join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightKeys join key columns from the right table * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual) { if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightKeys.getNumberOfColumns()); } long[] gatherMapData = @@ -3396,7 +3463,7 @@ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual */ public long fullJoinRowCount(HashJoin rightHash) { if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } return fullJoinRowCount(getNativeView(), rightHash.getNativeView()); @@ -3408,13 +3475,15 @@ public long fullJoinRowCount(HashJoin rightHash) { * the {@link HashJoin} argument has been constructed from the key columns from the right table. * Two {@link GatherMap} instances will be returned that can be used to gather the left and right * tables, respectively, to produce the result of the full join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightHash hash table built from join key columns from the right table * @return left and right table gather maps */ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash) { if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } long[] gatherMapData = fullHashJoinGatherMaps(getNativeView(), rightHash.getNativeView()); @@ -3427,7 +3496,9 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash) { * the {@link HashJoin} argument has been constructed from the key columns from the right table. * Two {@link GatherMap} instances will be returned that can be used to gather the left and right * tables, respectively, to produce the result of the full join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing an output row count that was previously computed from * {@link #fullJoinRowCount(HashJoin)}. * WARNING: Passing a row count that is smaller than the actual row count will result @@ -3438,7 +3509,7 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash) { */ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash, long outputRowCount) { if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } long[] gatherMapData = fullHashJoinGatherMapsWithCount(getNativeView(), @@ -3452,7 +3523,9 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash, long outputRowCount) { * the columns from the left table, and the table argument represents the columns from the * right table. Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the full join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join * @return left and right table gather maps @@ -3471,7 +3544,9 @@ public GatherMap[] conditionalFullJoinGatherMaps(Table rightTable, * assumed to be a logical AND of the equality condition and inequality condition. * Two {@link GatherMap} instances will be returned that can be used to gather * the left and right tables, respectively, to produce the result of the full join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3512,7 +3587,7 @@ private static GatherMap buildSemiJoinGatherMap(long[] gatherMapData) { */ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqual) { if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightKeys.getNumberOfColumns()); } long[] gatherMapData = @@ -3612,7 +3687,9 @@ public static MixedJoinSize mixedLeftSemiJoinSize(Table leftKeys, Table rightKey * assumed to be a logical AND of the equality condition and inequality condition. * A {@link GatherMap} instance will be returned that can be used to gather * the left table to produce the result of the left semi join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3639,10 +3716,13 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe * assumed to be a logical AND of the equality condition and inequality condition. * A {@link GatherMap} instance will be returned that can be used to gather * the left table to produce the result of the left semi join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing the size result from * {@link #mixedLeftSemiJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} * when the output size was computed previously. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3679,7 +3759,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe */ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqual) { if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { - throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightKeys.getNumberOfColumns()); } long[] gatherMapData = @@ -3779,7 +3859,9 @@ public static MixedJoinSize mixedLeftAntiJoinSize(Table leftKeys, Table rightKey * assumed to be a logical AND of the equality condition and inequality condition. * A {@link GatherMap} instance will be returned that can be used to gather * the left table to produce the result of the left anti join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition @@ -3806,10 +3888,13 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe * assumed to be a logical AND of the equality condition and inequality condition. * A {@link GatherMap} instance will be returned that can be used to gather * the left table to produce the result of the left anti join. + * * It is the responsibility of the caller to close the resulting gather map instances. + * * This interface allows passing the size result from * {@link #mixedLeftAntiJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} * when the output size was computed previously. + * * @param leftKeys the left table's key columns for the equality condition * @param rightKeys the right table's key columns for the equality condition * @param leftConditional the left table's columns needed to evaluate the inequality condition diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 8585761788e..84f1174fd3f 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -702,9 +702,9 @@ jlongArray gather_maps_to_java(JNIEnv *env, jlongArray gather_map_to_java(JNIEnv *env, std::unique_ptr> map) { // release the underlying device buffer to Java - auto gather_map_buffer = std::make_unique(map->release()); cudf::jni::native_jlongArray result(env, 3); - result[0] = static_cast(gather_map_buffer->size()); + result[0] = static_cast(map->size() * sizeof(cudf::size_type)); + auto gather_map_buffer = std::make_unique(map->release()); result[1] = ptr_as_jlong(gather_map_buffer->data()); result[2] = release_as_jlong(gather_map_buffer); return result.get_jArray(); @@ -2557,6 +2557,30 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerDistinctJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { + return cudf::jni::join_gather_maps( + env, j_left_keys, j_right_keys, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, cudf::null_equality nulleq) { + auto has_nulls = cudf::has_nested_nulls(left) || cudf::has_nested_nulls(right) ? + cudf::nullable_join::YES : + cudf::nullable_join::NO; + std::pair>, + std::unique_ptr>> + maps; + if (cudf::detail::has_nested_columns(right)) { + cudf::distinct_hash_join hash(right, left, has_nulls, nulleq); + maps = hash.inner_join(); + } else { + cudf::distinct_hash_join hash(right, left, has_nulls, nulleq); + maps = hash.inner_join(); + } + // Unique join returns {right map, left map} but all the other joins + // return {left map, right map}. Swap here to make it consistent. + return std::make_pair(std::move(maps.second), std::move(maps.first)); + }); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index efdb6f4bb1b..6f0b2b51f4c 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -33,7 +33,6 @@ import com.google.common.base.Charsets; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import org.apache.avro.SchemaBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.parquet.hadoop.ParquetFileReader; @@ -2104,6 +2103,116 @@ void testInnerJoinGatherMapsNulls() { } } + private void checkInnerDistinctJoin(Table leftKeys, Table rightKeys, Table expected, + boolean compareNullsEqual) { + GatherMap[] maps = leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + + @Test + void testInnerDistinctJoinGatherMaps() { + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8, 6).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9, 10) // left + .column(2, 0, 1, 3, 0) // right + .build()) { + checkInnerDistinctJoin(leftKeys, rightKeys, expected, false); + } + } + + @Test + void testInnerDistinctJoinGatherMapsWithNested() { + StructType structType = new StructType(false, + new BasicType(false, DType.STRING), + new BasicType(false, DType.INT32)); + StructData[] leftData = new StructData[]{ + new StructData("abc", 1), + new StructData("xyz", 1), + new StructData("abc", 2), + new StructData("xyz", 2), + new StructData("abc", 1), + new StructData("abc", 3), + new StructData("xyz", 3) + }; + StructData[] rightData = new StructData[]{ + new StructData("abc", 1), + new StructData("xyz", 4), + new StructData("xyz", 2), + new StructData("abc", -1), + }; + try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build(); + Table rightKeys = new Table.TestBuilder().column(structType, rightData).build(); + Table expected = new Table.TestBuilder() + .column(0, 3, 4) + .column(0, 2, 0) + .build()) { + checkInnerDistinctJoin(leftKeys, rightKeys, expected, false); + } + } + + @Test + void testInnerDistinctJoinGatherMapsNullsEqual() { + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9) // left + .column(1, 0, 0, 2) // right + .build()) { + checkInnerDistinctJoin(leftKeys, rightKeys, expected, true); + } + } + + @Test + void testInnerDistinctJoinGatherMapsWithNestedNullsEqual() { + StructType structType = new StructType(true, + new BasicType(true, DType.STRING), + new BasicType(true, DType.INT32)); + StructData[] leftData = new StructData[]{ + new StructData("abc", 1), + null, + new StructData("xyz", 1), + new StructData("abc", 2), + new StructData("xyz", null), + null, + new StructData("abc", 1), + new StructData("abc", 3), + new StructData("xyz", 3), + new StructData(null, null), + new StructData(null, 1) + }; + StructData[] rightData = new StructData[]{ + null, + new StructData("abc", 1), + new StructData("xyz", 4), + new StructData("xyz", 2), + new StructData(null, null), + new StructData(null, 2), + new StructData(null, 1), + new StructData("xyz", null), + new StructData("abc", null), + new StructData("abc", -1) + }; + try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build(); + Table rightKeys = new Table.TestBuilder().column(structType, rightData).build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 4, 5, 6, 9, 10) + .column(1, 0, 7, 0, 1, 4, 6) + .build()) { + checkInnerDistinctJoin(leftKeys, rightKeys, expected, true); + } + } + @Test void testInnerHashJoinGatherMaps() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build();