From 274a10b19e5ed9760d8805480fc884c00dd44e85 Mon Sep 17 00:00:00 2001 From: Vishesh Garg Date: Mon, 30 Oct 2023 14:39:15 +0530 Subject: [PATCH] Segregate advance and advanceUninterruptibly flow in postJoinCursor to allow for interrupts in advance (#15222) Currently advance function in postJoinCursor calls advanceUninterruptibly which in turn keeps calling baseCursor.advanceUninterruptibly until the post join condition matches, without checking for interrupts. This causes the CPU to hit 100% without getting a chance for query to be cancelled. With this change, the call flow of advance and advanceUninterruptibly is separated out so that they call baseCursor.advance and baseCursor.advanceUninterruptibly in them, respectively, giving a chance for interrupts in the former case between successive calls to baseCursor.advance. --- .../druid/segment/join/PostJoinCursor.java | 33 ++- .../segment/join/PostJoinCursorTest.java | 264 ++++++++++++++++++ 2 files changed, 292 insertions(+), 5 deletions(-) create mode 100644 processing/src/test/java/org/apache/druid/segment/join/PostJoinCursorTest.java diff --git a/processing/src/main/java/org/apache/druid/segment/join/PostJoinCursor.java b/processing/src/main/java/org/apache/druid/segment/join/PostJoinCursor.java index 57da128c73d74..26c119bd34f34 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/PostJoinCursor.java +++ b/processing/src/main/java/org/apache/druid/segment/join/PostJoinCursor.java @@ -19,7 +19,7 @@ package org.apache.druid.segment.join; -import org.apache.druid.query.BaseQuery; +import com.google.common.annotations.VisibleForTesting; import org.apache.druid.query.filter.Filter; import org.apache.druid.query.filter.ValueMatcher; import org.apache.druid.segment.ColumnSelectorFactory; @@ -39,7 +39,7 @@ public class PostJoinCursor implements Cursor private final ColumnSelectorFactory columnSelectorFactory; @Nullable - private final ValueMatcher valueMatcher; + private ValueMatcher valueMatcher; @Nullable private final Filter postJoinFilter; @@ -69,7 +69,28 @@ public static PostJoinCursor wrap( return postJoinCursor; } + @VisibleForTesting + public void setValueMatcher(@Nullable ValueMatcher valueMatcher) + { + this.valueMatcher = valueMatcher; + } + private void advanceToMatch() + { + if (valueMatcher != null) { + while (!isDone() && !valueMatcher.matches(false)) { + baseCursor.advance(); + } + } + } + + /** + * Matches tuples coming out of a join to a post-join condition uninterruptibly, and hence can be a long-running call. + * For this reason, {@link PostJoinCursor#advance()} instead calls {@link PostJoinCursor#advanceToMatch()} (unlike + * other cursors) that allows interruptions, thereby resolving issues where the + * CPU thread running PostJoinCursor cannot be terminated + */ + private void advanceToMatchUninterruptibly() { if (valueMatcher != null) { while (!isDone() && !valueMatcher.matches(false)) { @@ -99,15 +120,17 @@ public Filter getPostJoinFilter() @Override public void advance() { - advanceUninterruptibly(); - BaseQuery.checkInterrupted(); + baseCursor.advance(); + // Relies on baseCursor.advance() call inside this for BaseQuery.checkInterrupted() checks -- unlike other cursors + // which call advanceInterruptibly() and hence have to explicitly provision for interrupts. + advanceToMatch(); } @Override public void advanceUninterruptibly() { baseCursor.advanceUninterruptibly(); - advanceToMatch(); + advanceToMatchUninterruptibly(); } @Override diff --git a/processing/src/test/java/org/apache/druid/segment/join/PostJoinCursorTest.java b/processing/src/test/java/org/apache/druid/segment/join/PostJoinCursorTest.java new file mode 100644 index 0000000000000..6813c04bb10f9 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/join/PostJoinCursorTest.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.join; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.granularity.Granularity; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.query.QueryInterruptedException; +import org.apache.druid.query.QueryMetrics; +import org.apache.druid.query.filter.Filter; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.QueryableIndex; +import org.apache.druid.segment.QueryableIndexSegment; +import org.apache.druid.segment.QueryableIndexStorageAdapter; +import org.apache.druid.segment.StorageAdapter; +import org.apache.druid.segment.VirtualColumns; +import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis; +import org.apache.druid.timeline.SegmentId; +import org.joda.time.DateTime; +import org.joda.time.Interval; +import org.junit.Test; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static java.lang.Thread.sleep; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class PostJoinCursorTest extends BaseHashJoinSegmentStorageAdapterTest +{ + + public QueryableIndexSegment infiniteFactSegment; + + /** + * Simulates infinite segment by using a base cursor with advance() and advanceInterruptibly() + * reduced to a no-op. + */ + private static class TestInfiniteQueryableIndexSegment extends QueryableIndexSegment + { + + private static class InfiniteQueryableIndexStorageAdapter extends QueryableIndexStorageAdapter + { + CountDownLatch countDownLatch; + + public InfiniteQueryableIndexStorageAdapter(QueryableIndex index, CountDownLatch countDownLatch) + { + super(index); + this.countDownLatch = countDownLatch; + } + + @Override + public Sequence makeCursors( + @Nullable Filter filter, + Interval interval, + VirtualColumns virtualColumns, + Granularity gran, + boolean descending, + @Nullable QueryMetrics queryMetrics + ) + { + return super.makeCursors(filter, interval, virtualColumns, gran, descending, queryMetrics) + .map(cursor -> new CursorNoAdvance(cursor, countDownLatch)); + } + + private static class CursorNoAdvance implements Cursor + { + Cursor cursor; + CountDownLatch countDownLatch; + + public CursorNoAdvance(Cursor cursor, CountDownLatch countDownLatch) + { + this.cursor = cursor; + this.countDownLatch = countDownLatch; + } + + @Override + public ColumnSelectorFactory getColumnSelectorFactory() + { + return cursor.getColumnSelectorFactory(); + } + + @Override + public DateTime getTime() + { + return cursor.getTime(); + } + + @Override + public void advance() + { + // Do nothing to simulate infinite rows + countDownLatch.countDown(); + } + + @Override + public void advanceUninterruptibly() + { + // Do nothing to simulate infinite rows + countDownLatch.countDown(); + + } + + @Override + public boolean isDone() + { + return false; + } + + @Override + public boolean isDoneOrInterrupted() + { + return cursor.isDoneOrInterrupted(); + } + + @Override + public void reset() + { + + } + } + } + + private final StorageAdapter testStorageAdaptor; + + public TestInfiniteQueryableIndexSegment(QueryableIndex index, SegmentId segmentId, CountDownLatch countDownLatch) + { + super(index, segmentId); + testStorageAdaptor = new InfiniteQueryableIndexStorageAdapter(index, countDownLatch); + } + + @Override + public StorageAdapter asStorageAdapter() + { + return testStorageAdaptor; + } + } + + + private static class ExceptionHandler implements Thread.UncaughtExceptionHandler + { + + Throwable exception; + + @Override + public void uncaughtException(Thread t, Throwable e) + { + exception = e; + } + + public Throwable getException() + { + return exception; + } + } + + @Test + public void testAdvanceWithInterruption() throws IOException, InterruptedException + { + + final int rowsBeforeInterrupt = 1000; + + CountDownLatch countDownLatch = new CountDownLatch(rowsBeforeInterrupt); + + infiniteFactSegment = new TestInfiniteQueryableIndexSegment( + JoinTestHelper.createFactIndexBuilder(temporaryFolder.newFolder()).buildMMappedIndex(), + SegmentId.dummy("facts"), + countDownLatch + ); + + countriesTable = JoinTestHelper.createCountriesIndexedTable(); + + Thread joinCursorThread = new Thread(() -> makeCursorAndAdvance()); + ExceptionHandler exceptionHandler = new ExceptionHandler(); + joinCursorThread.setUncaughtExceptionHandler(exceptionHandler); + joinCursorThread.start(); + + countDownLatch.await(1, TimeUnit.SECONDS); + joinCursorThread.interrupt(); + + // Wait for a max of 1 sec for the exception to be set. + for (int i = 0; i < 1000; i++) { + if (exceptionHandler.getException() == null) { + sleep(1); + } else { + assertTrue(exceptionHandler.getException() instanceof QueryInterruptedException); + return; + } + } + fail(); + } + + public void makeCursorAndAdvance() + { + + List joinableClauses = ImmutableList.of( + factToCountryOnIsoCode(JoinType.LEFT) + ); + + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + + HashJoinSegmentStorageAdapter hashJoinSegmentStorageAdapter = new HashJoinSegmentStorageAdapter( + infiniteFactSegment.asStorageAdapter(), + joinableClauses, + joinFilterPreAnalysis + ); + + Cursor cursor = Iterables.getOnlyElement(hashJoinSegmentStorageAdapter.makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ).toList()); + + ((PostJoinCursor) cursor).setValueMatcher(new ValueMatcher() + { + @Override + public boolean matches(boolean includeUnknown) + { + return false; + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + + } + }); + + cursor.advance(); + } +}