diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
index e85cfd8bdb6fe..4294ff683826d 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
@@ -22,6 +22,9 @@
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.api.connector.sink2.WriterInitContext;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.CheckpointingOptions;
@@ -38,7 +41,6 @@
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
 import org.apache.flink.testutils.junit.SharedObjects;
@@ -152,6 +154,7 @@ private String runJobAndGetCheckpoint(
             int maxParallelism,
             MiniCluster miniCluster)
             throws Exception {
+        JobID jobID = null;
         try {
             JobGraph jobGraph =
                     createJobGraphWithKeyedState(
@@ -163,15 +166,18 @@ private String runJobAndGetCheckpoint(
                             true,
                             100,
                             miniCluster);
+            jobID = jobGraph.getJobID();
             miniCluster.submitJob(jobGraph).get();
-            miniCluster.requestJobResult(jobGraph.getJobID()).get();
-            return getLatestCompletedCheckpointPath(jobGraph.getJobID(), miniCluster)
+            miniCluster.requestJobResult(jobID).get();
+            return getLatestCompletedCheckpointPath(jobID, miniCluster)
                     .orElseThrow(
                             () ->
                                     new IllegalStateException(
                                             "Cannot get completed checkpoint, job failed before completing checkpoint"));
         } finally {
-            CollectionSink.clearElementsSet();
+            if (jobID != null) {
+                CollectionSink.clearElementsSet(jobID);
+            }
         }
     }
 
@@ -184,6 +190,7 @@ private void restoreAndAssert(
             MiniCluster miniCluster,
             String restorePath)
             throws Exception {
+        JobID jobID = null;
         try {
             JobGraph scaledJobGraph =
                     createJobGraphWithKeyedState(
@@ -195,13 +202,14 @@ private void restoreAndAssert(
                             false,
                             100,
                             miniCluster);
+            jobID = scaledJobGraph.getJobID();
             scaledJobGraph.setSavepointRestoreSettings(forPath(restorePath));
             miniCluster.submitJob(scaledJobGraph).get();
-            miniCluster.requestJobResult(scaledJobGraph.getJobID()).get();
+            miniCluster.requestJobResult(jobID).get();
 
-            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet(jobID);
 
             Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
 
             for (int key = 0; key < numberKeys; key++) {
                 int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
@@ -215,7 +223,9 @@ private void restoreAndAssert(
             }
             assertEquals(expectedResult, actualResult);
         } finally {
-            CollectionSink.clearElementsSet();
+            if (jobID != null) {
+                CollectionSink.clearElementsSet(jobID);
+            }
         }
     }
 
@@ -282,7 +292,7 @@ public Integer getKey(Integer value) {
         DataStream<Tuple2<Integer, Integer>> result =
                 input.flatMap(new SubtaskIndexFlatMapper(numberElementsExpect));
 
-        result.addSink(new CollectionSink<>());
+        result.sinkTo(new CollectionSink<>());
 
         return env.getStreamGraph().getJobGraph(env.getClass().getClassLoader(), jobID.get());
     }
@@ -389,25 +399,59 @@ public void initializeState(FunctionInitializationContext context) throws Except
         }
     }
 
-    private static class CollectionSink<IN> implements SinkFunction<IN> {
+    private static class CollectionSink<IN> implements Sink<IN> {
 
-        private static final Set<Object> elements =
-                Collections.newSetFromMap(new ConcurrentHashMap<>());
+        private static final ConcurrentHashMap<JobID, CollectionSinkWriter<?>> writers =
+                new ConcurrentHashMap<>();
 
         private static final long serialVersionUID = 1L;
 
         @SuppressWarnings("unchecked")
-        public static <IN> Set<IN> getElementsSet() {
-            return (Set<IN>) elements;
+        public static <IN> Set<IN> getElementsSet(JobID jobID) {
+            CollectionSinkWriter<IN> writer = (CollectionSinkWriter<IN>) writers.get(jobID);
+            if (writer == null) {
+                return Collections.emptySet();
+            } else {
+                return writer.getElementsSet();
+            }
         }
 
-        public static void clearElementsSet() {
-            elements.clear();
+        public static void clearElementsSet(JobID jobID) {
+            writers.remove(jobID);
         }
 
         @Override
-        public void invoke(IN value) throws Exception {
-            elements.add(value);
+        @SuppressWarnings("unchecked")
+        public SinkWriter<IN> createWriter(WriterInitContext context) throws IOException {
+            final CollectionSinkWriter<IN> writer =
+                    (CollectionSinkWriter<IN>)
+                            writers.computeIfAbsent(
+                                    context.getJobInfo().getJobId(),
+                                    (k) -> new CollectionSinkWriter<IN>());
+            return writer;
+        }
+
+        private static class CollectionSinkWriter<IN> implements SinkWriter<IN> {
+
+            private final Set<Object> elements =
+                    Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+            @Override
+            public void write(IN element, Context context)
+                    throws IOException, InterruptedException {
+                elements.add(element);
+            }
+
+            @Override
+            public void flush(boolean endOfInput) throws IOException, InterruptedException {}
+
+            @Override
+            public void close() throws Exception {}
+
+            @SuppressWarnings("unchecked")
+            public Set<IN> getElementsSet() {
+                return (Set<IN>) elements;
+            }
+        }
     }
 }
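
A note on the pattern, separate from the diff itself: the change replaces the deprecated legacy SinkFunction test sink, whose single static element set was shared by every job in the JVM, with a Sink V2 implementation that keys its state by JobID, so jobs running in the same MiniCluster can no longer see or clear each other's results. Below is a minimal self-contained sketch of the same idea, assuming the newer org.apache.flink.api.connector.sink2 API in which createWriter(WriterInitContext) is the method to implement (as in the diff); the class name PerJobCollectSink and its helper methods are illustrative, not part of the patch:

    import java.io.IOException;
    import java.util.Collections;
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;

    import org.apache.flink.api.common.JobID;
    import org.apache.flink.api.connector.sink2.Sink;
    import org.apache.flink.api.connector.sink2.SinkWriter;
    import org.apache.flink.api.connector.sink2.WriterInitContext;

    // Sketch: a test sink whose collected elements are isolated per job, so
    // concurrent or consecutive jobs in one JVM cannot observe each other's output.
    public class PerJobCollectSink<IN> implements Sink<IN> {

        // One element set per JobID instead of a single shared static set.
        private static final ConcurrentHashMap<JobID, Set<Object>> ELEMENTS =
                new ConcurrentHashMap<>();

        public static Set<Object> resultsFor(JobID jobId) {
            return ELEMENTS.getOrDefault(jobId, Collections.emptySet());
        }

        public static void clear(JobID jobId) {
            ELEMENTS.remove(jobId);
        }

        @Override
        public SinkWriter<IN> createWriter(WriterInitContext context) throws IOException {
            // Resolve this job's set once at writer creation; the writer only appends to it.
            final Set<Object> target =
                    ELEMENTS.computeIfAbsent(
                            context.getJobInfo().getJobId(),
                            k -> Collections.newSetFromMap(new ConcurrentHashMap<>()));
            return new SinkWriter<IN>() {
                @Override
                public void write(IN element, Context ctx) {
                    target.add(element);
                }

                @Override
                public void flush(boolean endOfInput) {}

                @Override
                public void close() {}
            };
        }
    }

A test would then read PerJobCollectSink.resultsFor(jobGraph.getJobID()) after the job finishes and call clear(...) in a finally block, mirroring the getElementsSet(jobID)/clearElementsSet(jobID) calls in the diff.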