diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 3dc2e099c5e8..c108c7d679e9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -1429,15 +1429,13 @@ private void publishAllSegments(final Set segments) throws IOExcept .submit(new MarkSegmentsAsUnusedAction(task.getDataSource(), interval)); } } else { - Set versionsToAwait = segmentsWithTombstones.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), task.getId(), task.getDataSource(), - versionsToAwait, - segmentsWithTombstones.size(), + segmentsWithTombstones, true ); } @@ -1447,15 +1445,13 @@ private void publishAllSegments(final Set segments) throws IOExcept ); } } else if (!segments.isEmpty()) { - Set versionsToAwait = segments.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), task.getId(), task.getDataSource(), - versionsToAwait, - segments.size(), + segments, true ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java index 17f46bad23a2..1546766f856f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java @@ -29,22 +29,27 @@ import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.discovery.BrokerClient; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.http.client.Request; import org.apache.druid.sql.http.ResultFormat; import org.apache.druid.sql.http.SqlQuery; +import org.apache.druid.timeline.DataSegment; import org.jboss.netty.handler.codec.http.HttpMethod; import org.joda.time.DateTime; import org.joda.time.Interval; import javax.annotation.Nullable; import javax.ws.rs.core.MediaType; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; /** * Class that periodically checks with the broker if all the segments generated are loaded by querying the sys table @@ -73,7 +78,7 @@ public class SegmentLoadStatusFetcher implements AutoCloseable * - If replication_factor is -1, the replication factor is not known currently and will become known after a load rule * evaluation. *
- * See https://github.com/apache/druid/pull/14403 for more details about replication_factor + * See this for more details about replication_factor */ private static final String LOAD_QUERY = "SELECT COUNT(*) AS usedSegments,\n" + "COUNT(*) FILTER (WHERE is_published = 1 AND replication_factor > 0) AS precachedSegments,\n" @@ -81,14 +86,14 @@ public class SegmentLoadStatusFetcher implements AutoCloseable + "COUNT(*) FILTER (WHERE is_available = 0 AND is_published = 1 AND replication_factor != 0) AS pendingSegments,\n" + "COUNT(*) FILTER (WHERE replication_factor = -1) AS unknownSegments\n" + "FROM sys.segments\n" - + "WHERE datasource = '%s' AND is_overshadowed = 0 AND version in (%s)"; + + "WHERE datasource = '%s' AND is_overshadowed = 0 AND (%s)"; private final BrokerClient brokerClient; private final ObjectMapper objectMapper; // Map of version vs latest load status. private final AtomicReference versionLoadStatusReference; private final String datasource; - private final String versionsInClauseString; + private final String versionsConditionString; private final int totalSegmentsGenerated; private final boolean doWait; // since live reports fetch the value in another thread, we need to use AtomicReference @@ -101,20 +106,16 @@ public SegmentLoadStatusFetcher( ObjectMapper objectMapper, String taskId, String datasource, - Set versionsToAwait, - int totalSegmentsGenerated, + Set dataSegments, boolean doWait ) { this.brokerClient = brokerClient; this.objectMapper = objectMapper; this.datasource = datasource; - this.versionsInClauseString = String.join( - ",", - versionsToAwait.stream().map(s -> StringUtils.format("'%s'", s)).collect(Collectors.toSet()) - ); + this.versionsConditionString = createVersionCondition(dataSegments); + this.totalSegmentsGenerated = dataSegments.size(); this.versionLoadStatusReference = new AtomicReference<>(new VersionLoadStatus(0, 0, 0, 0, totalSegmentsGenerated)); - this.totalSegmentsGenerated = totalSegmentsGenerated; this.status = new AtomicReference<>(new SegmentLoadWaiterStatus( State.INIT, null, @@ -163,9 +164,8 @@ public void waitForSegmentsToLoad() if (runningMillis - lastLogMillis >= TimeUnit.MINUTES.toMillis(1)) { lastLogMillis = runningMillis; log.info( - "Fetching segment load status for datasource[%s] from broker for segment versions[%s]", - datasource, - versionsInClauseString + "Fetching segment load status for datasource[%s] from broker", + datasource ); } @@ -237,7 +237,7 @@ private void updateStatus(State state, DateTime startTime) private VersionLoadStatus fetchLoadStatusFromBroker() throws Exception { Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/"); - SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, versionsInClauseString), + SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, versionsConditionString), ResultFormat.OBJECTLINES, false, false, false, null, null ); @@ -255,6 +255,37 @@ private VersionLoadStatus fetchLoadStatusFromBroker() throws Exception } } + /** + * Takes a list of segments and creates the condition for the broker query. Directly creates a string to avoid + * computing it repeatedly. + */ + private static String createVersionCondition(Set dataSegments) + { + // Creates a map of version to earliest and latest partition numbers created. These would be contiguous since the task + // holds the lock. + Map> versionsVsPartitionNumberRangeMap = new HashMap<>(); + + dataSegments.forEach(segment -> { + final String version = segment.getVersion(); + final int partitionNum = segment.getId().getPartitionNum(); + versionsVsPartitionNumberRangeMap.computeIfPresent(version, (k, v) -> Pair.of( + partitionNum < v.lhs ? partitionNum : v.lhs, + partitionNum > v.rhs ? partitionNum : v.rhs + )); + versionsVsPartitionNumberRangeMap.computeIfAbsent(version, k -> Pair.of(partitionNum, partitionNum)); + }); + + // Create a condition for each version / partition + List versionConditionList = new ArrayList<>(); + for (Map.Entry> stringPairEntry : versionsVsPartitionNumberRangeMap.entrySet()) { + Pair pair = stringPairEntry.getValue(); + versionConditionList.add( + StringUtils.format("(version = '%s' AND partition_num BETWEEN %s AND %s)", stringPairEntry.getKey(), pair.lhs, pair.rhs) + ); + } + return String.join(" OR ", versionConditionList); + } + /** * Returns the current status of the load. */ diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java index f2ffa0c9ec72..548a7ac473e9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java @@ -20,14 +20,19 @@ package org.apache.druid.msq.exec; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableSet; import org.apache.druid.discovery.BrokerClient; +import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.http.client.Request; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.partition.NumberedShardSpec; import org.junit.Assert; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; @@ -76,8 +81,7 @@ public String answer(InvocationOnMock invocation) throws Throwable new ObjectMapper(), "id", TEST_DATASOURCE, - ImmutableSet.of("version1"), - 5, + IntStream.range(0, 5).boxed().map(partitionNum -> createTestDataSegment("version1", partitionNum)).collect(Collectors.toSet()), false ); segmentLoadWaiter.waitForSegmentsToLoad(); @@ -114,8 +118,7 @@ public String answer(InvocationOnMock invocation) throws Throwable new ObjectMapper(), "id", TEST_DATASOURCE, - ImmutableSet.of("version1"), - 5, + IntStream.range(0, 5).boxed().map(partitionNum -> createTestDataSegment("version1", partitionNum)).collect(Collectors.toSet()), false ); segmentLoadWaiter.waitForSegmentsToLoad(); @@ -153,8 +156,7 @@ public String answer(InvocationOnMock invocation) throws Throwable new ObjectMapper(), "id", TEST_DATASOURCE, - ImmutableSet.of("version1"), - 5, + IntStream.range(0, 5).boxed().map(partitionNum -> createTestDataSegment("version1", partitionNum)).collect(Collectors.toSet()), true ); @@ -169,4 +171,18 @@ public String answer(InvocationOnMock invocation) throws Throwable Assert.assertTrue(segmentLoadWaiter.status().getState() == SegmentLoadStatusFetcher.State.FAILED); } + private static DataSegment createTestDataSegment(String version, int partitionNumber) + { + return new DataSegment( + TEST_DATASOURCE, + Intervals.ETERNITY, + version, + null, + null, + null, + new NumberedShardSpec(partitionNumber, 1), + 0, + 0 + ); + } }