diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/RecordContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/RecordContext.java index 39e9d5ac1b6e7..1341c8a6f337a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/RecordContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/RecordContext.java @@ -64,7 +64,7 @@ public class RecordContext extends ReferenceCounted, Object> namespaces = null; /** User-defined variables. */ - private final AtomicReferenceArray declaredVariables; + private final AtomicReferenceArray contextVariables; /** * The extra context info which is used to hold customized data defined by state backend. The @@ -100,7 +100,7 @@ public RecordContext( this.disposer = disposer; this.keyGroup = keyGroup; this.epoch = epoch; - this.declaredVariables = variables; + this.contextVariables = variables; } public Object getRecord() { @@ -152,16 +152,16 @@ public void setNamespace(InternalPartitionedState state, N namespace) { @SuppressWarnings("unchecked") public T getVariable(int i) { checkVariableIndex(i); - return (T) declaredVariables.get(i); + return (T) contextVariables.get(i); } public void setVariable(int i, T value) { checkVariableIndex(i); - declaredVariables.set(i, value); + contextVariables.set(i, value); } private void checkVariableIndex(int i) { - if (i >= declaredVariables.length()) { + if (i >= contextVariables.length()) { throw new UnsupportedOperationException( "Variable index out of bounds. Maybe you are accessing " + "a variable that have not been declared."); @@ -169,7 +169,7 @@ private void checkVariableIndex(int i) { } AtomicReferenceArray getVariablesReference() { - return declaredVariables; + return contextVariables; } public void setExtra(Object extra) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/ContextVariable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/ContextVariable.java new file mode 100644 index 0000000000000..c3976dd83131b --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/ContextVariable.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.asyncprocessing.declare; + +import javax.annotation.Nullable; + +import java.util.function.Supplier; + +/** A value that will have different values across different contexts. */ +public class ContextVariable { + + final DeclarationManager manager; + + final int ordinal; + + @Nullable final Supplier initializer; + + boolean initialized = false; + + ContextVariable(DeclarationManager manager, int ordinal, Supplier initializer) { + this.manager = manager; + this.ordinal = ordinal; + this.initializer = initializer; + } + + public T get() { + if (!initialized && initializer != null) { + manager.setVariableValue(ordinal, initializer.get()); + initialized = true; + } + return manager.getVariableValue(ordinal); + } + + public void set(T newValue) { + manager.setVariableValue(ordinal, newValue); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationContext.java index 43608bd472799..33cbfd72ab24c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationContext.java @@ -91,7 +91,19 @@ public NamedBiFunction declare( public DeclaredVariable declareVariable( TypeSerializer serializer, String name, @Nullable Supplier initialValue) throws DeclarationException { - return manager.register(serializer, name, initialValue); + return manager.registerVariable(serializer, name, initialValue); + } + + /** + * Declare a variable that will keep value across callback with same context. This value cannot + * be serialized into checkpoint. + * + * @param initializer the initializer of variable. Can be null if no need to initialize. + * @param The type of value. + */ + public ContextVariable declareVariable(@Nullable Supplier initializer) + throws DeclarationException { + return manager.registerVariable(initializer); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationManager.java index 26c4b1a9a4f64..1f9cae174f1f0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclarationManager.java @@ -38,6 +38,8 @@ public class DeclarationManager { private int nextValidNameSequence = 0; + private int contextVariableCount = 0; + public DeclarationManager() { this.knownCallbacks = new HashMap<>(); this.knownVariables = new HashMap<>(); @@ -50,14 +52,19 @@ T register(T knownCallback) throws DeclarationExceptio return knownCallback; } - DeclaredVariable register( + ContextVariable registerVariable(@Nullable Supplier initializer) + throws DeclarationException { + return new ContextVariable<>(this, contextVariableCount++, initializer); + } + + DeclaredVariable registerVariable( TypeSerializer serializer, String name, @Nullable Supplier initializer) throws DeclarationException { if (knownVariables.containsKey(name)) { throw new DeclarationException("Duplicated variable key " + name); } DeclaredVariable variable = - new DeclaredVariable<>(this, knownVariables.size(), serializer, name, initializer); + new DeclaredVariable<>(this, contextVariableCount++, serializer, name, initializer); knownVariables.put(name, variable); return variable; } @@ -81,7 +88,7 @@ public void setVariableValue(int ordinal, T value) { } public int variableCount() { - return knownVariables.size(); + return contextVariableCount; } String nextAssignedName(String prefix) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclaredVariable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclaredVariable.java index 305a87fd7c6ae..150febe4558ae 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclaredVariable.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/declare/DeclaredVariable.java @@ -25,41 +25,20 @@ import java.util.function.Supplier; /** A variable declared in async state processing. The value could be persisted in checkpoint. */ -public class DeclaredVariable { - - final DeclarationManager manager; - - final int ordinal; +public class DeclaredVariable extends ContextVariable { final TypeSerializer typeSerializer; final String name; - @Nullable final Supplier initializer; - DeclaredVariable( DeclarationManager manager, int ordinal, TypeSerializer typeSerializer, String name, @Nullable Supplier initializer) { - this.manager = manager; - this.ordinal = ordinal; + super(manager, ordinal, initializer); this.typeSerializer = typeSerializer; this.name = name; - this.initializer = initializer; - } - - public T get() { - T t = manager.getVariableValue(ordinal); - if (t == null && initializer != null) { - t = initializer.get(); - manager.setVariableValue(ordinal, t); - } - return t; - } - - public void set(T newValue) { - manager.setVariableValue(ordinal, newValue); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/asyncprocessing/operators/AsyncKeyedProcessOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/asyncprocessing/operators/AsyncKeyedProcessOperatorTest.java index 576169017d9b5..5699447572a91 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/asyncprocessing/operators/AsyncKeyedProcessOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/asyncprocessing/operators/AsyncKeyedProcessOperatorTest.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.state.StateFutureUtils; +import org.apache.flink.runtime.asyncprocessing.declare.ContextVariable; import org.apache.flink.runtime.asyncprocessing.declare.DeclarationContext; import org.apache.flink.runtime.asyncprocessing.declare.DeclarationException; import org.apache.flink.runtime.asyncprocessing.declare.NamedCallback; @@ -61,11 +62,11 @@ public void testNormalProcessor(boolean chained) throws Exception { testOperator, (e) -> e.f0, TypeInformation.of(Integer.class))) { testHarness.open(); testHarness.processElement(new StreamRecord<>(Tuple2.of(5, "5"))); - expectedOutput.add(new StreamRecord<>("12")); - assertThat(function.getValue()).isEqualTo(12); + expectedOutput.add(new StreamRecord<>("11")); + assertThat(function.getValue()).isEqualTo(11); testHarness.processElement(new StreamRecord<>(Tuple2.of(6, "6"))); - expectedOutput.add(new StreamRecord<>("38")); - assertThat(function.getValue()).isEqualTo(38); + expectedOutput.add(new StreamRecord<>("24")); + assertThat(function.getValue()).isEqualTo(24); assertThat(testHarness.getOutput()).containsExactly(expectedOutput.toArray()); } } @@ -86,13 +87,13 @@ public void testTimerProcessor() throws Exception { testHarness.processElement(new StreamRecord<>(Tuple2.of(6, "5"))); assertThat(function.getValue()).isEqualTo(0); testHarness.processWatermark(5L); - expectedOutput.add(new StreamRecord<>("12", 5L)); + expectedOutput.add(new StreamRecord<>("11", 5L)); expectedOutput.add(new Watermark(5L)); - assertThat(function.getValue()).isEqualTo(12); + assertThat(function.getValue()).isEqualTo(11); testHarness.processWatermark(6L); - expectedOutput.add(new StreamRecord<>("38", 6L)); + expectedOutput.add(new StreamRecord<>("24", 6L)); expectedOutput.add(new Watermark(6L)); - assertThat(function.getValue()).isEqualTo(38); + assertThat(function.getValue()).isEqualTo(24); assertThat(testHarness.getOutput()).containsExactly(expectedOutput.toArray()); } } @@ -112,6 +113,7 @@ private static class TestNormalDeclarationFunction extends TestDeclarationFuncti public ThrowingConsumer, Exception> declareProcess( DeclarationContext context, Context ctx, Collector out) throws DeclarationException { + ContextVariable inputValue = context.declareVariable(null); NamedFunction> adder = context.declare( "adder", @@ -122,12 +124,15 @@ public ThrowingConsumer, Exception> declareProcess( context.declare( "doubler", (v) -> { - value.addAndGet(v); + value.addAndGet(inputValue.get()); out.collect(String.valueOf(value.get())); }); assertThat(adder).isInstanceOf(NamedCallback.class); assertThat(doubler).isInstanceOf(NamedCallback.class); return (e) -> { + if (inputValue.get() == null) { + inputValue.set(e.f0); + } value.addAndGet(e.f0); StateFutureUtils.completedVoidFuture().thenCompose(adder).thenAccept(doubler); }; @@ -140,9 +145,13 @@ private static class TestChainDeclarationFunction extends TestDeclarationFunctio public ThrowingConsumer, Exception> declareProcess( DeclarationContext context, Context ctx, Collector out) throws DeclarationException { + ContextVariable inputValue = context.declareVariable(null); return context.>declareChain() .thenCompose( e -> { + if (inputValue.get() == null) { + inputValue.set(e.f0); + } value.addAndGet(e.f0); return StateFutureUtils.completedVoidFuture(); }) @@ -150,7 +159,7 @@ public ThrowingConsumer, Exception> declareProcess( .withName("adder") .thenAccept( (v) -> { - value.addAndGet(v); + value.addAndGet(inputValue.get()); out.collect(String.valueOf(value.get())); }) .withName("doubler") @@ -176,9 +185,13 @@ public ThrowingConsumer, Exception> declareProcess( public ThrowingConsumer declareOnTimer( DeclarationContext context, OnTimerContext ctx, Collector out) throws DeclarationException { + ContextVariable inputValue = context.declareVariable(null); return context.declareChain() .thenCompose( e -> { + if (inputValue.get() == null) { + inputValue.set(e.intValue()); + } value.addAndGet(e.intValue()); return StateFutureUtils.completedVoidFuture(); }) @@ -186,7 +199,7 @@ public ThrowingConsumer declareOnTimer( .withName("adder") .thenAccept( (v) -> { - value.addAndGet(v); + value.addAndGet(inputValue.get()); out.collect(String.valueOf(value.get())); }) .withName("doubler")