From 63e50876bcdf163afccb2a36ad7f4cff4de8a849 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 16 Oct 2023 10:44:53 +0530 Subject: [PATCH] Fix issue with checking segment load status (#15147) (#15156) This PR addresses a bug with waiting for segments to be loaded. In the case of append, segments would be created with the same version. This caused the number of segments returned to be incorrect. This PR changes this to keep track of the range of partition numbers as well for each version, which lets the task wait for the correct set of segments. The partition numbers are expected to be continuous since the task obtains the lock for the segment while running. Co-authored-by: Adarsh Sanjeev --- .../apache/druid/msq/exec/ControllerImpl.java | 8 +-- .../msq/exec/SegmentLoadStatusFetcher.java | 61 ++++++++++++++----- .../exec/SegmentLoadStatusFetcherTest.java | 30 ++++++--- 3 files changed, 71 insertions(+), 28 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 3dc2e099c5e8..c108c7d679e9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -1429,15 +1429,13 @@ private void publishAllSegments(final Set segments) throws IOExcept .submit(new MarkSegmentsAsUnusedAction(task.getDataSource(), interval)); } } else { - Set versionsToAwait = segmentsWithTombstones.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), task.getId(), task.getDataSource(), - versionsToAwait, - segmentsWithTombstones.size(), + segmentsWithTombstones, true ); } @@ -1447,15 +1445,13 @@ private void publishAllSegments(final Set segments) throws IOExcept ); } } else if (!segments.isEmpty()) { - Set versionsToAwait = segments.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), task.getId(), task.getDataSource(), - versionsToAwait, - segments.size(), + segments, true ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java index 17f46bad23a2..1546766f856f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java @@ -29,22 +29,27 @@ import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.discovery.BrokerClient; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.http.client.Request; import org.apache.druid.sql.http.ResultFormat; import org.apache.druid.sql.http.SqlQuery; +import org.apache.druid.timeline.DataSegment; import org.jboss.netty.handler.codec.http.HttpMethod; import org.joda.time.DateTime; import org.joda.time.Interval; import javax.annotation.Nullable; import javax.ws.rs.core.MediaType; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; /** * Class that periodically checks with the broker if all the segments generated are loaded by querying the sys table @@ -73,7 +78,7 @@ public class SegmentLoadStatusFetcher implements AutoCloseable * - If replication_factor is -1, the replication factor is not known currently and will become known after a load rule * evaluation. *
- * See https://github.com/apache/druid/pull/14403 for more details about replication_factor + * See this for more details about replication_factor */ private static final String LOAD_QUERY = "SELECT COUNT(*) AS usedSegments,\n" + "COUNT(*) FILTER (WHERE is_published = 1 AND replication_factor > 0) AS precachedSegments,\n" @@ -81,14 +86,14 @@ public class SegmentLoadStatusFetcher implements AutoCloseable + "COUNT(*) FILTER (WHERE is_available = 0 AND is_published = 1 AND replication_factor != 0) AS pendingSegments,\n" + "COUNT(*) FILTER (WHERE replication_factor = -1) AS unknownSegments\n" + "FROM sys.segments\n" - + "WHERE datasource = '%s' AND is_overshadowed = 0 AND version in (%s)"; + + "WHERE datasource = '%s' AND is_overshadowed = 0 AND (%s)"; private final BrokerClient brokerClient; private final ObjectMapper objectMapper; // Map of version vs latest load status. private final AtomicReference versionLoadStatusReference; private final String datasource; - private final String versionsInClauseString; + private final String versionsConditionString; private final int totalSegmentsGenerated; private final boolean doWait; // since live reports fetch the value in another thread, we need to use AtomicReference @@ -101,20 +106,16 @@ public SegmentLoadStatusFetcher( ObjectMapper objectMapper, String taskId, String datasource, - Set versionsToAwait, - int totalSegmentsGenerated, + Set dataSegments, boolean doWait ) { this.brokerClient = brokerClient; this.objectMapper = objectMapper; this.datasource = datasource; - this.versionsInClauseString = String.join( - ",", - versionsToAwait.stream().map(s -> StringUtils.format("'%s'", s)).collect(Collectors.toSet()) - ); + this.versionsConditionString = createVersionCondition(dataSegments); + this.totalSegmentsGenerated = dataSegments.size(); this.versionLoadStatusReference = new AtomicReference<>(new VersionLoadStatus(0, 0, 0, 0, totalSegmentsGenerated)); - this.totalSegmentsGenerated = totalSegmentsGenerated; this.status = new AtomicReference<>(new SegmentLoadWaiterStatus( State.INIT, null, @@ -163,9 +164,8 @@ public void waitForSegmentsToLoad() if (runningMillis - lastLogMillis >= TimeUnit.MINUTES.toMillis(1)) { lastLogMillis = runningMillis; log.info( - "Fetching segment load status for datasource[%s] from broker for segment versions[%s]", - datasource, - versionsInClauseString + "Fetching segment load status for datasource[%s] from broker", + datasource ); } @@ -237,7 +237,7 @@ private void updateStatus(State state, DateTime startTime) private VersionLoadStatus fetchLoadStatusFromBroker() throws Exception { Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/"); - SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, versionsInClauseString), + SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, versionsConditionString), ResultFormat.OBJECTLINES, false, false, false, null, null ); @@ -255,6 +255,37 @@ private VersionLoadStatus fetchLoadStatusFromBroker() throws Exception } } + /** + * Takes a list of segments and creates the condition for the broker query. Directly creates a string to avoid + * computing it repeatedly. + */ + private static String createVersionCondition(Set dataSegments) + { + // Creates a map of version to earliest and latest partition numbers created. These would be contiguous since the task + // holds the lock. + Map> versionsVsPartitionNumberRangeMap = new HashMap<>(); + + dataSegments.forEach(segment -> { + final String version = segment.getVersion(); + final int partitionNum = segment.getId().getPartitionNum(); + versionsVsPartitionNumberRangeMap.computeIfPresent(version, (k, v) -> Pair.of( + partitionNum < v.lhs ? partitionNum : v.lhs, + partitionNum > v.rhs ? partitionNum : v.rhs + )); + versionsVsPartitionNumberRangeMap.computeIfAbsent(version, k -> Pair.of(partitionNum, partitionNum)); + }); + + // Create a condition for each version / partition + List versionConditionList = new ArrayList<>(); + for (Map.Entry> stringPairEntry : versionsVsPartitionNumberRangeMap.entrySet()) { + Pair pair = stringPairEntry.getValue(); + versionConditionList.add( + StringUtils.format("(version = '%s' AND partition_num BETWEEN %s AND %s)", stringPairEntry.getKey(), pair.lhs, pair.rhs) + ); + } + return String.join(" OR ", versionConditionList); + } + /** * Returns the current status of the load. */ diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java index f2ffa0c9ec72..548a7ac473e9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java @@ -20,14 +20,19 @@ package org.apache.druid.msq.exec; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableSet; import org.apache.druid.discovery.BrokerClient; +import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.http.client.Request; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.partition.NumberedShardSpec; import org.junit.Assert; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; @@ -76,8 +81,7 @@ public String answer(InvocationOnMock invocation) throws Throwable new ObjectMapper(), "id", TEST_DATASOURCE, - ImmutableSet.of("version1"), - 5, + IntStream.range(0, 5).boxed().map(partitionNum -> createTestDataSegment("version1", partitionNum)).collect(Collectors.toSet()), false ); segmentLoadWaiter.waitForSegmentsToLoad(); @@ -114,8 +118,7 @@ public String answer(InvocationOnMock invocation) throws Throwable new ObjectMapper(), "id", TEST_DATASOURCE, - ImmutableSet.of("version1"), - 5, + IntStream.range(0, 5).boxed().map(partitionNum -> createTestDataSegment("version1", partitionNum)).collect(Collectors.toSet()), false ); segmentLoadWaiter.waitForSegmentsToLoad(); @@ -153,8 +156,7 @@ public String answer(InvocationOnMock invocation) throws Throwable new ObjectMapper(), "id", TEST_DATASOURCE, - ImmutableSet.of("version1"), - 5, + IntStream.range(0, 5).boxed().map(partitionNum -> createTestDataSegment("version1", partitionNum)).collect(Collectors.toSet()), true ); @@ -169,4 +171,18 @@ public String answer(InvocationOnMock invocation) throws Throwable Assert.assertTrue(segmentLoadWaiter.status().getState() == SegmentLoadStatusFetcher.State.FAILED); } + private static DataSegment createTestDataSegment(String version, int partitionNumber) + { + return new DataSegment( + TEST_DATASOURCE, + Intervals.ETERNITY, + version, + null, + null, + null, + new NumberedShardSpec(partitionNumber, 1), + 0, + 0 + ); + } }