Skip to content

Commit

Permalink
Add wait for target allocation id to appear
Browse files Browse the repository at this point in the history
Signed-off-by: Gaurav Bafna <[email protected]>
  • Loading branch information
gbbafna committed Sep 1, 2024
1 parent c308b98 commit 74ce992
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
import org.opensearch.action.StepListener;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.replication.ReplicationResponse;
import org.opensearch.cluster.routing.IndexShardRoutingTable;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.common.SetOnce;
import org.opensearch.common.concurrent.GatedCloseable;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.engine.RecoveryEngineException;
import org.opensearch.index.seqno.ReplicationTracker;
import org.opensearch.index.seqno.RetentionLease;
import org.opensearch.index.seqno.RetentionLeaseNotFoundException;
import org.opensearch.index.seqno.RetentionLeases;
Expand Down Expand Up @@ -58,21 +55,7 @@ public LocalStorePeerRecoverySourceHandler(
@Override
protected void innerRecoveryToTarget(ActionListener<RecoveryResponse> listener, Consumer<Exception> onFailure) throws IOException {
final SetOnce<RetentionLease> retentionLeaseRef = new SetOnce<>();

RunUnderPrimaryPermit.run(() -> {
final IndexShardRoutingTable routingTable = shard.getReplicationGroup().getRoutingTable();
ShardRouting targetShardRouting = routingTable.getByAllocationId(request.targetAllocationId());
if (targetShardRouting == null) {
logger.debug(
"delaying recovery of {} as it is not listed as assigned to target node {}",
request.shardId(),
request.targetNode()
);
throw new DelayRecoveryException("source node does not have the shard listed in its state as allocated on the node");
}
assert targetShardRouting.initializing() : "expected recovery target to be initializing but was " + targetShardRouting;
retentionLeaseRef.set(shard.getRetentionLeases().get(ReplicationTracker.getPeerRecoveryRetentionLeaseId(targetShardRouting)));
}, shardId + " validating recovery target [" + request.targetAllocationId() + "] registered ", shard, cancellableThreads, logger);
waitForAssignment(retentionLeaseRef);
final Closeable retentionLock = shard.acquireHistoryRetentionLock();
resources.add(retentionLock);
final long startingSeqNo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.replication.ReplicationResponse;
import org.opensearch.cluster.routing.IndexShardRoutingTable;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.common.CheckedRunnable;
import org.opensearch.common.SetOnce;
import org.opensearch.common.StopWatch;
import org.opensearch.common.concurrent.GatedCloseable;
import org.opensearch.common.lease.Releasable;
Expand All @@ -59,6 +62,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.index.engine.RecoveryEngineException;
import org.opensearch.index.seqno.ReplicationTracker;
import org.opensearch.index.seqno.RetentionLease;
import org.opensearch.index.seqno.RetentionLeaseNotFoundException;
import org.opensearch.index.seqno.RetentionLeases;
Expand Down Expand Up @@ -191,6 +195,25 @@ public void recoverToTarget(ActionListener<RecoveryResponse> listener) {
protected abstract void innerRecoveryToTarget(ActionListener<RecoveryResponse> listener, Consumer<Exception> onFailure)
throws IOException;

protected void waitForAssignment(SetOnce<RetentionLease> retentionLeaseRef) {
RunUnderPrimaryPermit.run(() -> {
final IndexShardRoutingTable routingTable = shard.getReplicationGroup().getRoutingTable();
ShardRouting targetShardRouting = routingTable.getByAllocationId(request.targetAllocationId());
if (targetShardRouting == null) {
logger.debug(
"delaying recovery of {} as it is not listed as assigned to target node {}",
request.shardId(),
request.targetNode()
);
throw new DelayRecoveryException("source node does not have the shard listed in its state as allocated on the node");
}
assert targetShardRouting.initializing() : "expected recovery target to be initializing but was " + targetShardRouting;
retentionLeaseRef.set(shard.getRetentionLeases().get(ReplicationTracker.getPeerRecoveryRetentionLeaseId(targetShardRouting)));
}, shardId + " validating recovery target [" + request.targetAllocationId() + "] registered ", shard, cancellableThreads, logger);
}



protected void finalizeStepAndCompleteFuture(
long startingSeqNo,
StepListener<SendSnapshotResult> sendSnapshotStep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,26 @@

import org.apache.lucene.index.IndexCommit;
import org.opensearch.action.StepListener;
import org.opensearch.action.bulk.BackoffPolicy;
import org.opensearch.cluster.routing.IndexShardRoutingTable;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.common.SetOnce;
import org.opensearch.common.concurrent.GatedCloseable;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.engine.RecoveryEngineException;
import org.opensearch.index.seqno.ReplicationTracker;
import org.opensearch.index.seqno.RetentionLease;
import org.opensearch.index.seqno.SequenceNumbers;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.indices.RunUnderPrimaryPermit;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transports;

import java.io.IOException;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

/**
Expand All @@ -48,7 +56,8 @@ protected void innerRecoveryToTarget(ActionListener<RecoveryResponse> listener,
// A replica of an index with remote translog does not require the translogs locally and keeps receiving the
// updated segments file on refresh, flushes, and merges. In recovery, here, only file-based recovery is performed
// and there is no translog replay done.

final SetOnce<RetentionLease> retentionLeaseRef = new SetOnce<>();
waitForAssignment(retentionLeaseRef);
final StepListener<SendFileResult> sendFileStep = new StepListener<>();
final StepListener<TimeValue> prepareEngineStep = new StepListener<>();
final StepListener<SendSnapshotResult> sendSnapshotStep = new StepListener<>();
Expand Down Expand Up @@ -102,4 +111,41 @@ protected void innerRecoveryToTarget(ActionListener<RecoveryResponse> listener,

finalizeStepAndCompleteFuture(startingSeqNo, sendSnapshotStep, sendFileStep, prepareEngineStep, onFailure);
}

protected void waitForAssignment(SetOnce<RetentionLease> retentionLeaseRef) {
BackoffPolicy EXPONENTIAL_BACKOFF_POLICY = BackoffPolicy.exponentialBackoff(
TimeValue.timeValueMillis(100),
3
);
AtomicReference<ShardRouting> targetShardRouting = new AtomicReference<>();
Iterator<TimeValue> backoffDelayIterator = EXPONENTIAL_BACKOFF_POLICY.iterator();
while (backoffDelayIterator.hasNext() ) {
RunUnderPrimaryPermit.run(() -> {
final IndexShardRoutingTable routingTable = shard.getReplicationGroup().getRoutingTable();
targetShardRouting.set(routingTable.getByAllocationId(request.targetAllocationId()));
if (targetShardRouting.get() == null) {
logger.info(
"delaying recovery of {} as it is not listed as assigned to target node {}",
request.shardId(),
request.targetNode()
);
Thread.sleep(backoffDelayIterator.next().millis());
}
if (targetShardRouting.get() != null) {
assert targetShardRouting.get().initializing() : "expected recovery target to be initializing but was " + targetShardRouting;
retentionLeaseRef.set(shard.getRetentionLeases().get(ReplicationTracker.getPeerRecoveryRetentionLeaseId(targetShardRouting.get())));
}

}, shardId + " validating recovery target [" + request.targetAllocationId() + "] registered ", shard, cancellableThreads, logger);

if (targetShardRouting.get() != null) {
return;
}
}
if (targetShardRouting.get() != null) {
return;
}
throw new DelayRecoveryException("source node does not have the shard listed in its state as allocated on the node");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,16 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.junit.After;
import org.junit.Before;
import org.opensearch.ExceptionsHelper;
import org.opensearch.Version;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.StepListener;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.routing.IndexShardRoutingTable;
import org.opensearch.common.Numbers;
import org.opensearch.common.Randomness;
import org.opensearch.common.SetOnce;
Expand Down Expand Up @@ -89,6 +92,7 @@
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardRelocatedException;
import org.opensearch.index.shard.IndexShardState;
import org.opensearch.index.shard.ReplicationGroup;
import org.opensearch.index.store.Store;
import org.opensearch.index.store.StoreFileMetadata;
import org.opensearch.index.translog.Translog;
Expand All @@ -102,8 +106,6 @@
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.io.OutputStream;
Expand Down Expand Up @@ -141,6 +143,8 @@
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
Expand Down Expand Up @@ -746,6 +750,104 @@ void phase2(
assertFalse(phase2Called.get());
}

public void testThrowExceptionOnNoTargetInRouting() throws IOException {
final RecoverySettings recoverySettings = new RecoverySettings(Settings.EMPTY, service);
final StartRecoveryRequest request = getStartRecoveryRequest();
final IndexShard shard = mock(IndexShard.class);
when(shard.seqNoStats()).thenReturn(mock(SeqNoStats.class));
when(shard.segmentStats(anyBoolean(), anyBoolean())).thenReturn(mock(SegmentsStats.class));
when(shard.isRelocatedPrimary()).thenReturn(false);
final ReplicationGroup replicationGroup = mock(ReplicationGroup.class);
final IndexShardRoutingTable routingTable = mock(IndexShardRoutingTable.class);
when(routingTable.getByAllocationId(anyString())).thenReturn(null);
when(shard.getReplicationGroup()).thenReturn(replicationGroup);
when(replicationGroup.getRoutingTable()).thenReturn(routingTable);
when(shard.acquireSafeIndexCommit()).thenReturn(mock(GatedCloseable.class));
doAnswer(invocation -> {
((ActionListener<Releasable>) invocation.getArguments()[0]).onResponse(() -> {});
return null;
}).when(shard).acquirePrimaryOperationPermit(any(), anyString(), any());

final IndexMetadata.Builder indexMetadata = IndexMetadata.builder("test")
.settings(
Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, between(0, 5))
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))
.put(IndexMetadata.SETTING_VERSION_CREATED, VersionUtils.randomVersion(random()))
.put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID(random()))
);
if (randomBoolean()) {
indexMetadata.state(IndexMetadata.State.CLOSE);
}
when(shard.indexSettings()).thenReturn(new IndexSettings(indexMetadata.build(), Settings.EMPTY));

final AtomicBoolean phase1Called = new AtomicBoolean();
final AtomicBoolean prepareTargetForTranslogCalled = new AtomicBoolean();
final AtomicBoolean phase2Called = new AtomicBoolean();
final RecoverySourceHandler handler = new LocalStorePeerRecoverySourceHandler(
shard,
mock(RecoveryTargetHandler.class),
threadPool,
request,
Math.toIntExact(recoverySettings.getChunkSize().getBytes()),
between(1, 8),
between(1, 8)
) {

@Override
void phase1(
IndexCommit snapshot,
long startingSeqNo,
IntSupplier translogOps,
ActionListener<SendFileResult> listener,
boolean skipCreateRetentionLeaseStep
) {
phase1Called.set(true);
super.phase1(snapshot, startingSeqNo, translogOps, listener, skipCreateRetentionLeaseStep);
}

@Override
void prepareTargetForTranslog(int totalTranslogOps, ActionListener<TimeValue> listener) {
prepareTargetForTranslogCalled.set(true);
super.prepareTargetForTranslog(totalTranslogOps, listener);
}

@Override
void phase2(
long startingSeqNo,
long endingSeqNo,
Translog.Snapshot snapshot,
long maxSeenAutoIdTimestamp,
long maxSeqNoOfUpdatesOrDeletes,
RetentionLeases retentionLeases,
long mappingVersion,
ActionListener<SendSnapshotResult> listener
) throws IOException {
phase2Called.set(true);
super.phase2(
startingSeqNo,
endingSeqNo,
snapshot,
maxSeenAutoIdTimestamp,
maxSeqNoOfUpdatesOrDeletes,
retentionLeases,
mappingVersion,
listener
);
}

};
PlainActionFuture<RecoveryResponse> future = new PlainActionFuture<>();
expectThrows(DelayRecoveryException.class, () -> {
handler.recoverToTarget(future);
future.actionGet();
});
verify(routingTable, times(1)).getByAllocationId(null);
assertFalse(phase1Called.get());
assertFalse(prepareTargetForTranslogCalled.get());
assertFalse(phase2Called.get());
}

public void testCancellationsDoesNotLeakPrimaryPermits() throws Exception {
final CancellableThreads cancellableThreads = new CancellableThreads();
final IndexShard shard = mock(IndexShard.class);
Expand Down
Loading

0 comments on commit 74ce992

Please sign in to comment.