[FLINK-36613] Fix unstable RescaleCheckpointManuallyITCase in parallel run
Zakelly committed Nov 13, 2024
1 parent 584dc46 commit 73c120d
Showing 1 changed file with 61 additions and 17 deletions.
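The instability this commit fixes is visible in the diff below: the old CollectionSink accumulated results in a single JVM-wide static set, so two test jobs running in parallel in the same MiniCluster JVM saw each other's elements. A minimal, self-contained sketch of that shared-static race (thread and element names are illustrative, not from the commit):

    import java.util.Collections;
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;

    // Sketch of the pre-fix failure mode: every job writes into one static set.
    class SharedSinkProblem {
        // Pre-fix pattern: a single set shared across the whole JVM.
        static final Set<String> elements =
                Collections.newSetFromMap(new ConcurrentHashMap<>());

        public static void main(String[] args) throws InterruptedException {
            Thread jobA = new Thread(() -> elements.add("from-job-A"));
            Thread jobB = new Thread(() -> elements.add("from-job-B"));
            jobA.start();
            jobB.start();
            jobA.join();
            jobB.join();
            // A test asserting only job A's output now also sees job B's element.
            System.out.println(elements); // [from-job-A, from-job-B] in some order
        }
    }

The diff resolves this by keying the collected elements on the owning job's JobID.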
@@ -22,6 +22,9 @@
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.api.connector.sink2.WriterInitContext;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.CheckpointingOptions;
@@ -38,7 +41,6 @@
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
 import org.apache.flink.testutils.junit.SharedObjects;
@@ -152,6 +154,7 @@ private String runJobAndGetCheckpoint(
             int maxParallelism,
             MiniCluster miniCluster)
             throws Exception {
+        JobID jobID = null;
         try {
             JobGraph jobGraph =
                     createJobGraphWithKeyedState(
@@ -163,15 +166,18 @@
                             true,
                             100,
                             miniCluster);
+            jobID = jobGraph.getJobID();
             miniCluster.submitJob(jobGraph).get();
-            miniCluster.requestJobResult(jobGraph.getJobID()).get();
-            return getLatestCompletedCheckpointPath(jobGraph.getJobID(), miniCluster)
+            miniCluster.requestJobResult(jobID).get();
+            return getLatestCompletedCheckpointPath(jobID, miniCluster)
                     .orElseThrow(
                             () ->
                                     new IllegalStateException(
                                             "Cannot get completed checkpoint, job failed before completing checkpoint"));
         } finally {
-            CollectionSink.clearElementsSet();
+            if (jobID != null) {
+                CollectionSink.clearElementsSet(jobID);
+            }
         }
     }

@@ -184,6 +190,7 @@ private void restoreAndAssert(
             MiniCluster miniCluster,
             String restorePath)
             throws Exception {
+        JobID jobID = null;
         try {
             JobGraph scaledJobGraph =
                     createJobGraphWithKeyedState(
@@ -195,13 +202,14 @@
                             false,
                             100,
                             miniCluster);
+            jobID = scaledJobGraph.getJobID();
 
             scaledJobGraph.setSavepointRestoreSettings(forPath(restorePath));
 
             miniCluster.submitJob(scaledJobGraph).get();
-            miniCluster.requestJobResult(scaledJobGraph.getJobID()).get();
+            miniCluster.requestJobResult(jobID).get();
 
-            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet(jobID);
 
             Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
@@ -215,7 +223,9 @@
             }
             assertEquals(expectedResult, actualResult);
         } finally {
-            CollectionSink.clearElementsSet();
+            if (jobID != null) {
+                CollectionSink.clearElementsSet(jobID);
+            }
         }
     }

@@ -282,7 +292,7 @@ public Integer getKey(Integer value) {
         DataStream<Tuple2<Integer, Integer>> result =
                 input.flatMap(new SubtaskIndexFlatMapper(numberElementsExpect));
 
-        result.addSink(new CollectionSink<>());
+        result.sinkTo(new CollectionSink<>());
 
         return env.getStreamGraph().getJobGraph(env.getClass().getClassLoader(), jobID.get());
     }
@@ -389,25 +399,59 @@ public void initializeState(FunctionInitializationContext context) throws Exception {
         }
     }
 
-    private static class CollectionSink<IN> implements SinkFunction<IN> {
+    private static class CollectionSink<IN> implements Sink<IN> {
 
-        private static final Set<Object> elements =
-                Collections.newSetFromMap(new ConcurrentHashMap<>());
+        private static final ConcurrentHashMap<JobID, CollectionSinkWriter<?>> writers =
+                new ConcurrentHashMap<>();
 
         private static final long serialVersionUID = 1L;
 
         @SuppressWarnings("unchecked")
-        public static <IN> Set<IN> getElementsSet() {
-            return (Set<IN>) elements;
+        public static <IN> Set<IN> getElementsSet(JobID jobID) {
+            CollectionSinkWriter<IN> writer = (CollectionSinkWriter<IN>) writers.get(jobID);
+            if (writer == null) {
+                return Collections.emptySet();
+            } else {
+                return writer.getElementsSet();
+            }
         }
 
-        public static void clearElementsSet() {
-            elements.clear();
+        public static void clearElementsSet(JobID jobID) {
+            writers.remove(jobID);
        }
 
         @Override
-        public void invoke(IN value) throws Exception {
-            elements.add(value);
+        @SuppressWarnings("unchecked")
+        public SinkWriter<IN> createWriter(WriterInitContext context) throws IOException {
+            final CollectionSinkWriter<IN> writer =
+                    (CollectionSinkWriter<IN>)
+                            writers.computeIfAbsent(
+                                    context.getJobInfo().getJobId(),
+                                    (k) -> new CollectionSinkWriter<IN>());
+            return writer;
         }
+
+        private static class CollectionSinkWriter<IN> implements SinkWriter<IN> {
+
+            private final Set<Object> elements =
+                    Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+            @Override
+            public void write(IN element, Context context)
+                    throws IOException, InterruptedException {
+                elements.add(element);
+            }
+
+            @Override
+            public void flush(boolean endOfInput) throws IOException, InterruptedException {}
+
+            @Override
+            public void close() throws Exception {}
+
+            @SuppressWarnings("unchecked")
+            public <IN> Set<IN> getElementsSet() {
+                return (Set<IN>) elements;
+            }
+        }
     }
 }
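Taken together, the patch moves CollectionSink from the legacy SinkFunction to the Sink V2 API and scopes collected elements per job: each writer registers itself in a static ConcurrentHashMap under its job's ID via computeIfAbsent, and each test's finally block clears only its own entry. A condensed sketch of that per-job-keyed pattern outside Flink (class and method names are illustrative; plain strings stand in for JobID):

    import java.util.Collections;
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;

    // Condensed version of the per-job collection pattern the new CollectionSink uses.
    class PerJobCollector {
        // One element set per job id; computeIfAbsent makes registration race-free.
        private static final ConcurrentHashMap<String, Set<Object>> sets =
                new ConcurrentHashMap<>();

        static Set<Object> writerFor(String jobId) {
            return sets.computeIfAbsent(
                    jobId, k -> Collections.newSetFromMap(new ConcurrentHashMap<>()));
        }

        static Set<Object> elementsOf(String jobId) {
            Set<Object> s = sets.get(jobId);
            return s == null ? Collections.emptySet() : s;
        }

        static void clear(String jobId) {
            sets.remove(jobId); // per-job cleanup, mirroring the patched finally blocks
        }

        public static void main(String[] args) {
            writerFor("jobA").add("a1");
            writerFor("jobB").add("b1");
            System.out.println(elementsOf("jobA")); // [a1] only; no cross-job leakage
            clear("jobA");
            System.out.println(elementsOf("jobA")); // []
        }
    }

Because lookups, assertions, and cleanup all take the job's ID, concurrently running test jobs can no longer interfere with one another's collected results.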
