From c50fd1139b25961f4d01bbe02936f334ead1b703 Mon Sep 17 00:00:00 2001 From: Arek Burdach Date: Thu, 4 Jul 2024 16:52:47 +0200 Subject: [PATCH 1/3] [NU-1679] table join component --- docs/Changelog.md | 1 + .../api/process/FlinkCustomNodeContext.scala | 17 +- .../util/transformer/join/BranchType.java | 0 .../flink/table/join/TableJoinTest.scala | 105 ++++++++ .../table/FlinkTableComponentProvider.scala | 5 + .../flink/table/join/TableJoinComponent.scala | 241 ++++++++++++++++++ .../flink/table/utils/RowConversions.scala | 47 +++- 7 files changed, 412 insertions(+), 4 deletions(-) rename engine/flink/{components/base-unbounded => components-utils}/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java (100%) create mode 100644 engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala create mode 100644 engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala diff --git a/docs/Changelog.md b/docs/Changelog.md index 146ebb6a49a..0d5d5f75551 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -22,6 +22,7 @@ in table name * [#6353](https://github.com/TouK/nussknacker/pull/6353) Performance improvement: simple types such as numbers, boolean, string, date types and arrays are serialized/deserialized more optimal in aggregates +* [#6353](https://github.com/TouK/nussknacker/pull/6353) Added `join` component available in Batch processing mode 1.16.1 (16 July 2024) ------------------------- diff --git a/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala b/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala index 4bbc46b4e62..5d76b511b13 100644 --- a/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala +++ b/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala @@ -58,10 +58,23 @@ case class FlinkCustomNodeContext( lazy val forUnknown: TypeInformation[ValueWithContext[AnyRef]] = forType[AnyRef](Unknown) } + def branchValidationContext(branchId: String): ValidationContext = asJoinContext.getOrElse( + branchId, + throw new IllegalArgumentException(s"No validation context for branchId [$branchId] is defined") + ) + private def asOneOutputContext: ValidationContext = - validationContext.left.getOrElse(throw new IllegalArgumentException("This node is a join, use asJoinContext")) + validationContext.left.getOrElse( + throw new IllegalArgumentException( + "This node is a join, asJoinContext should be used to extract validation context" + ) + ) private def asJoinContext: Map[String, ValidationContext] = - validationContext.getOrElse(throw new IllegalArgumentException()) + validationContext.getOrElse( + throw new IllegalArgumentException( + "This node is a single input node. 
asOneOutputContext should be used to extract validation context" + ) + ) } diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java b/engine/flink/components-utils/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java similarity index 100% rename from engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java rename to engine/flink/components-utils/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala new file mode 100644 index 00000000000..2fcb4fcfb87 --- /dev/null +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala @@ -0,0 +1,105 @@ +package pl.touk.nussknacker.engine.flink.table.join + +import com.typesafe.config.ConfigFactory +import org.apache.flink.api.common.RuntimeExecutionMode +import org.apache.flink.api.connector.source.Boundedness +import org.scalatest.Inside +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import pl.touk.nussknacker.engine.api.component.ComponentDefinition +import pl.touk.nussknacker.engine.build.{GraphBuilder, ScenarioBuilder} +import pl.touk.nussknacker.engine.flink.table.FlinkTableComponentProvider +import pl.touk.nussknacker.engine.flink.table.join.TableJoinTest.OrderProduct +import pl.touk.nussknacker.engine.flink.test.FlinkSpec +import pl.touk.nussknacker.engine.flink.util.transformer.join.BranchType +import pl.touk.nussknacker.engine.util.test.TestScenarioRunner +import pl.touk.nussknacker.test.ValidatedValuesDetailedMessage + +import scala.beans.BeanProperty + +class TableJoinTest extends AnyFunSuite with FlinkSpec with Matchers with Inside with ValidatedValuesDetailedMessage { + + import pl.touk.nussknacker.engine.flink.util.test.FlinkTestScenarioRunner._ + import pl.touk.nussknacker.engine.spel.SpelExtension._ + + import scala.jdk.CollectionConverters._ + + private lazy val additionalComponents: List[ComponentDefinition] = + FlinkTableComponentProvider.configIndependentComponents ::: Nil + + private lazy val runner = TestScenarioRunner + .flinkBased(ConfigFactory.empty(), flinkMiniCluster) + .withExtraComponents(additionalComponents) + .build() + + private val MainBranchId = "main" + + private val JoinedBranchId = "joined" + + private val JoinNodeId = "joined-node-id" + + test("should be able to join") { + val scenario = ScenarioBuilder + .streaming("sample-join-last") + .sources( + GraphBuilder + .source("orders-source", TestScenarioRunner.testDataSource) + .filter("orders-filter", "#input.type == 'order'".spel) + .branchEnd(MainBranchId, JoinNodeId), + GraphBuilder + .source("products-source", TestScenarioRunner.testDataSource) + .filter("product-filter", "#input.type == 'product'".spel) + .branchEnd(JoinedBranchId, JoinNodeId), + GraphBuilder + .join( + JoinNodeId, + "join", + Some("product"), + List( + MainBranchId -> List( + "branchType" -> s"T(${classOf[BranchType].getName}).MAIN".spel, + "key" -> s"#input.productId.toString".spel + ), + JoinedBranchId -> List( + "branchType" -> s"T(${classOf[BranchType].getName}).JOINED".spel, + "key" -> s"#input.id.toString".spel + ) + ), + "output" -> 
"#input".spel, + ) + .emptySink("end", TestScenarioRunner.testResultSink, "value" -> "{#input, #product}".spel) + ) + + val result = runner.runWithData( + scenario, + List( + OrderProduct("product", 1, -1), + OrderProduct("order", 10, 1), + ), + Boundedness.BOUNDED, + Some(RuntimeExecutionMode.BATCH) + ) + + result.validValue.successes shouldBe List( + List(OrderProduct("order", 10, 1), OrderProduct("product", 1, -1)).asJava, + ) + } + +} + +object TableJoinTest { + + // TODO: split into separate classes and pass two streams to separate source nodes + // productId is dedicated only for order events + // It have to by POJO in order by acceptable by table api operators + case class OrderProduct( + @BeanProperty var `type`: String, + @BeanProperty var id: Int, + @BeanProperty var productId: Int + ) { + + def this() = this(null, -1, -1) + + } + +} diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala index 3504acee442..02e78e586e7 100644 --- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala +++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala @@ -10,6 +10,7 @@ import pl.touk.nussknacker.engine.flink.table.aggregate.TableAggregationFactory import pl.touk.nussknacker.engine.flink.table.extractor.TableExtractor.extractTablesFromFlinkRuntime import pl.touk.nussknacker.engine.flink.table.extractor.SqlStatementReader import pl.touk.nussknacker.engine.flink.table.extractor.SqlStatementReader.SqlStatement +import pl.touk.nussknacker.engine.flink.table.join.TableJoinComponent import pl.touk.nussknacker.engine.flink.table.sink.TableSinkFactory import pl.touk.nussknacker.engine.flink.table.source.TableSourceFactory import pl.touk.nussknacker.engine.util.ResourceLoader @@ -88,6 +89,10 @@ object FlinkTableComponentProvider { ComponentDefinition( "aggregate", new TableAggregationFactory() + ), + ComponentDefinition( + "join", + TableJoinComponent ) ) diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala new file mode 100644 index 00000000000..3cdf5cc2da6 --- /dev/null +++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala @@ -0,0 +1,241 @@ +package pl.touk.nussknacker.engine.flink.table.join + +import org.apache.flink.api.common.functions.FlatMapFunction +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.table.api.Expressions.$ +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import pl.touk.nussknacker.engine.api._ +import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.CustomNodeError +import pl.touk.nussknacker.engine.api.context.transformation.{ + DefinedEagerBranchParameter, + DefinedSingleParameter, + JoinDynamicComponent, + NodeDependencyValue +} +import pl.touk.nussknacker.engine.api.context.{OutputVar, ValidationContext} +import pl.touk.nussknacker.engine.api.definition._ +import 
pl.touk.nussknacker.engine.api.parameter.ParameterName +import pl.touk.nussknacker.engine.flink.api.process.{ + AbstractLazyParameterInterpreterFunction, + FlinkCustomJoinTransformation, + FlinkCustomNodeContext, + FlinkLazyParameterFunctionHelper +} +import pl.touk.nussknacker.engine.flink.table.utils.RowConversions +import pl.touk.nussknacker.engine.flink.table.utils.RowConversions.{TypeInformationDetectionExtension, rowToContext} +import pl.touk.nussknacker.engine.flink.util.transformer.join.BranchType +import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap + +object TableJoinComponent extends CustomStreamTransformer with JoinDynamicComponent[FlinkCustomJoinTransformation] { + + private val contextInternalColumnName = "context" + private val mainKeyInternalColumnName = "mainKey" + private val joinedKeyInternalColumnName = "joinedKey" + private val outputInternalColumnName = "output" + + val BranchTypeParamName: ParameterName = ParameterName("branchType") + + private val BranchTypeParamDeclaration + : ParameterCreatorWithNoDependency with ParameterExtractor[Map[String, BranchType]] = + ParameterDeclaration.branchMandatory[BranchType](BranchTypeParamName).withCreator() + + val KeyParamName: ParameterName = ParameterName("key") + + private val KeyParamDeclaration + : ParameterCreatorWithNoDependency with ParameterExtractor[Map[String, LazyParameter[String]]] = + ParameterDeclaration.branchLazyMandatory[String](KeyParamName).withCreator() + + val OutputParamName: ParameterName = ParameterName("output") + + override type State = Nothing + + override def nodeDependencies: List[NodeDependency] = List(OutputVariableNameDependency) + + override def contextTransformation(contexts: Map[String, ValidationContext], dependencies: List[NodeDependencyValue])( + implicit nodeId: NodeId + ): ContextTransformationDefinition = { + case TransformationStep(Nil, _) => + NextParameters( + List(BranchTypeParamDeclaration, KeyParamDeclaration) + .map(_.createParameter()) + ) + case TransformationStep( + ( + `BranchTypeParamName`, + DefinedEagerBranchParameter(branchTypeByBranchId: Map[String, BranchType] @unchecked, _) + ) :: (`KeyParamName`, _) :: Nil, + _ + ) => + val error = + if (branchTypeByBranchId.values.toList.sorted != BranchType.values().toList) + List( + CustomNodeError( + s"Has to be exactly one MAIN and JOINED branch, got: ${branchTypeByBranchId.values.mkString(", ")}", + Some(BranchTypeParamName) + ) + ) + else + Nil + val joinedVariables = extractJoinedBranchId(branchTypeByBranchId) + .map(contexts) + .getOrElse(ValidationContext()) + .localVariables + .mapValuesNow(AdditionalVariableProvidedInRuntime(_)) + NextParameters( + List( + ParameterDeclaration + .lazyMandatory[AnyRef](OutputParamName) + .withCreator(_.copy(additionalVariables = joinedVariables)) + .createParameter() + ), + error + ) + + case TransformationStep( + ( + `BranchTypeParamName`, + DefinedEagerBranchParameter(branchTypeByBranchId: Map[String, BranchType] @unchecked, _) + ) :: + (`KeyParamName`, _) :: (`OutputParamName`, outputParameter: DefinedSingleParameter) :: Nil, + _ + ) => + val outName = OutputVariableNameDependency.extract(dependencies) + val mainContext = extractMainBranchId(branchTypeByBranchId).map(contexts).getOrElse(ValidationContext()) + FinalResults.forValidation(mainContext)( + _.withVariable(OutputVar.customNode(outName), outputParameter.returnType) + ) + + } + + override def implementation( + params: Params, + dependencies: List[NodeDependencyValue], + finalState: Option[Nothing] + ): 
FlinkCustomJoinTransformation = new FlinkCustomJoinTransformation { + + override def transform( + inputs: Map[String, DataStream[Context]], + flinkNodeContext: FlinkCustomNodeContext + ): DataStream[ValueWithContext[AnyRef]] = { + val branchTypeByBranchId: Map[String, BranchType] = BranchTypeParamDeclaration.extractValueUnsafe(params) + val mainBranchId = extractMainBranchId(branchTypeByBranchId).get + val joinedBranchId = extractJoinedBranchId(branchTypeByBranchId).get + val mainStream = inputs(mainBranchId) + val joinedStream = inputs(joinedBranchId) + + val env = mainStream.getExecutionEnvironment + val tableEnv = StreamTableEnvironment.create(env) + + val mainTable = tableEnv.fromDataStream( + mainStream.flatMap( + new MainBranchToRowFunction( + KeyParamDeclaration.extractValueUnsafe(params)(mainBranchId), + flinkNodeContext.lazyParameterHelper + ), + mainBranchTypeInfo(flinkNodeContext, mainBranchId) + ) + ) + + val outputLazyParam = params.extractUnsafe[LazyParameter[AnyRef]](OutputParamName) + val outputTypeInfo = + flinkNodeContext.valueWithContextInfo.forBranch[AnyRef](mainBranchId, outputLazyParam.returnType) + + val joinedTable = tableEnv.fromDataStream( + joinedStream.flatMap( + new JoinedBranchToRowFunction( + KeyParamDeclaration.extractValueUnsafe(params)(joinedBranchId), + outputLazyParam, + flinkNodeContext.lazyParameterHelper + ), + joinedBranchTypeInfo(flinkNodeContext, outputLazyParam) + ) + ) + + val resultTable = + mainTable.join(joinedTable, $(joinedKeyInternalColumnName).isEqual($(mainKeyInternalColumnName))) + + tableEnv + .toDataStream(resultTable) + .map( + (row: Row) => + ValueWithContext[AnyRef]( + row.getField(outputInternalColumnName), + rowToContext(row.getField(contextInternalColumnName).asInstanceOf[Row]) + ), + outputTypeInfo + ) + } + + } + + private class MainBranchToRowFunction( + mainKeyLazyParam: LazyParameter[String], + lazyParameterHelper: FlinkLazyParameterFunctionHelper + ) extends AbstractLazyParameterInterpreterFunction(lazyParameterHelper) + with FlatMapFunction[Context, Row] { + + private lazy val evaluateKey = toEvaluateFunctionConverter.toEvaluateFunction(mainKeyLazyParam) + + override def flatMap(context: Context, out: Collector[Row]): Unit = { + collectHandlingErrors(context, out) { + val row = Row.withNames() + row.setField(contextInternalColumnName, RowConversions.contextToRow(context)) + row.setField(mainKeyInternalColumnName, evaluateKey(context)) + row + } + } + + } + + private def mainBranchTypeInfo(flinkNodeContext: FlinkCustomNodeContext, mainBranchId: String) = { + Types.ROW_NAMED( + Array(contextInternalColumnName, mainKeyInternalColumnName), + flinkNodeContext.typeInformationDetection.contextRowTypeInfo( + flinkNodeContext.branchValidationContext(mainBranchId) + ), + Types.STRING + ) + } + + private class JoinedBranchToRowFunction( + joinedKeyLazyParam: LazyParameter[String], + outputLazyParam: LazyParameter[AnyRef], + lazyParameterHelper: FlinkLazyParameterFunctionHelper + ) extends AbstractLazyParameterInterpreterFunction(lazyParameterHelper) + with FlatMapFunction[Context, Row] { + + private lazy val evaluateKey = toEvaluateFunctionConverter.toEvaluateFunction(joinedKeyLazyParam) + + private lazy val evaluateOutput = toEvaluateFunctionConverter.toEvaluateFunction(outputLazyParam) + + override def flatMap(context: Context, out: Collector[Row]): Unit = { + collectHandlingErrors(context, out) { + val row = Row.withNames() + row.setField(joinedKeyInternalColumnName, evaluateKey(context)) + row.setField(outputInternalColumnName, 
evaluateOutput(context)) + row + } + } + + } + + private def joinedBranchTypeInfo(flinkNodeContext: FlinkCustomNodeContext, outputLazyParam: LazyParameter[_]) = { + Types.ROW_NAMED( + Array(joinedKeyInternalColumnName, outputInternalColumnName), + Types.STRING, + flinkNodeContext.typeInformationDetection.forType(outputLazyParam.returnType) + ) + } + + private def extractMainBranchId(branchTypeByBranchId: Map[String, BranchType]) = { + branchTypeByBranchId.find(_._2 == BranchType.MAIN).map(_._1) + } + + private def extractJoinedBranchId(branchTypeByBranchId: Map[String, BranchType]) = { + branchTypeByBranchId.find(_._2 == BranchType.JOINED).map(_._1) + } + +} diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala index 3299441789c..14312b2ad44 100644 --- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala +++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala @@ -1,19 +1,29 @@ package pl.touk.nussknacker.engine.flink.table.utils -import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeinfo.{TypeInformation, Types} import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.types.Row +import pl.touk.nussknacker.engine.api.Context +import pl.touk.nussknacker.engine.api.context.ValidationContext import pl.touk.nussknacker.engine.api.typed.typing.TypedObjectTypingResult import pl.touk.nussknacker.engine.flink.api.typeinformation.TypeInformationDetection +import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap + +import java.util object RowConversions { import scala.jdk.CollectionConverters._ def rowToMap(row: Row): java.util.Map[String, Any] = { + val fields: Map[String, AnyRef] = rowToScalaMap(row) + new util.HashMap[String, Any](fields.asJava) + } + + private def rowToScalaMap(row: Row): Map[String, AnyRef] = { val fieldNames = row.getFieldNames(true).asScala val fields = fieldNames.map(n => n -> row.getField(n)).toMap - new java.util.HashMap[String, Any](fields.asJava) + fields } def mapToRow(map: java.util.Map[String, Any], columnNames: Iterable[String]): Row = { @@ -22,6 +32,32 @@ object RowConversions { row } + private def scalaMapToRow(map: Map[String, Any]): Row = { + val row = Row.withNames() + // TODO: add type alignment, e.g. 
bigint is represented as long by flink tables + map.foreach { case (name, value) => + row.setField(name, value) + } + row + } + + def contextToRow(context: Context): Row = { + val row = Row.withPositions(context.parentContext.map(_ => 3).getOrElse(2)) + val variablesRow = scalaMapToRow(context.variables) + row.setField(0, context.id) + row.setField(1, variablesRow) + context.parentContext.map(contextToRow).foreach(row.setField(2, _)) + row + } + + def rowToContext(row: Row): Context = { + Context( + row.getField(0).asInstanceOf[String], + rowToScalaMap(row.getField(1).asInstanceOf[Row]), + Option(row).filter(_.getArity >= 3).map(_.getField(2).asInstanceOf[Row]).map(rowToContext) + ) + } + implicit class TypeInformationDetectionExtension(typeInformationDetection: TypeInformationDetection) { def rowTypeInfoWithColumnsInGivenOrder( @@ -37,6 +73,13 @@ object RowConversions { new RowTypeInfo(typeInfos.toArray[TypeInformation[_]], fieldNames.toArray) } + def contextRowTypeInfo(validationContext: ValidationContext): TypeInformation[_] = { + val (fieldNames, typeInfos) = + validationContext.localVariables.mapValuesNow(typeInformationDetection.forType).unzip + val variablesRow = new RowTypeInfo(typeInfos.toArray[TypeInformation[_]], fieldNames.toArray) + Types.ROW(Types.STRING :: variablesRow :: validationContext.parent.map(contextRowTypeInfo).toList: _*) + } + } } From e3d49f45f4c48277ce6e1093012bc13afcaaf274 Mon Sep 17 00:00:00 2001 From: Arek Burdach Date: Tue, 23 Jul 2024 16:09:33 +0200 Subject: [PATCH 2/3] review fixes --- .../flink/table/join/TableJoinTest.scala | 62 +++++++++++++++---- .../flink/table/join/TableJoinComponent.scala | 8 ++- .../flink/table/utils/RowConversions.scala | 1 - 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala index 2fcb4fcfb87..fa028899dd3 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala @@ -3,13 +3,13 @@ package pl.touk.nussknacker.engine.flink.table.join import com.typesafe.config.ConfigFactory import org.apache.flink.api.common.RuntimeExecutionMode import org.apache.flink.api.connector.source.Boundedness -import org.scalatest.Inside +import org.scalatest.{Inside, LoneElement} import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers import pl.touk.nussknacker.engine.api.component.ComponentDefinition import pl.touk.nussknacker.engine.build.{GraphBuilder, ScenarioBuilder} import pl.touk.nussknacker.engine.flink.table.FlinkTableComponentProvider -import pl.touk.nussknacker.engine.flink.table.join.TableJoinTest.OrderProduct +import pl.touk.nussknacker.engine.flink.table.join.TableJoinTest.OrderOrProduct import pl.touk.nussknacker.engine.flink.test.FlinkSpec import pl.touk.nussknacker.engine.flink.util.transformer.join.BranchType import pl.touk.nussknacker.engine.util.test.TestScenarioRunner @@ -17,7 +17,13 @@ import pl.touk.nussknacker.test.ValidatedValuesDetailedMessage import scala.beans.BeanProperty -class TableJoinTest extends AnyFunSuite with FlinkSpec with Matchers with Inside with ValidatedValuesDetailedMessage { +class TableJoinTest + extends AnyFunSuite + with FlinkSpec + with Matchers + with 
Inside
+    with ValidatedValuesDetailedMessage
+    with LoneElement {
 
   import pl.touk.nussknacker.engine.flink.util.test.FlinkTestScenarioRunner._
   import pl.touk.nussknacker.engine.spel.SpelExtension._
@@ -67,22 +73,40 @@ class TableJoinTest extends AnyFunSuite with FlinkSpec with Matchers with Inside
             ),
             "output" -> "#input".spel,
           )
-          .emptySink("end", TestScenarioRunner.testResultSink, "value" -> "{#input, #product}".spel)
+          .emptySink(
+            "end",
+            TestScenarioRunner.testResultSink,
+            "value" ->
+              """{
+                |  orderId: #input.id,
+                |  product: {
+                |    id: #product.id,
+                |    name: #product.name
+                |  }
+                |}""".stripMargin.spel
+          )
       )
 
-    val result = runner.runWithData(
+    val productId = 1
+    val orderId = 10
+    val enrichedOrders = runner.runWithData[OrderOrProduct, java.util.Map[String, AnyRef]](
       scenario,
       List(
-        OrderProduct("product", 1, -1),
-        OrderProduct("order", 10, 1),
+        OrderOrProduct.createProduct(productId, "Foo product"),
+        OrderOrProduct.createOrder(orderId, productId),
       ),
       Boundedness.BOUNDED,
       Some(RuntimeExecutionMode.BATCH)
     )
 
-    result.validValue.successes shouldBe List(
-      List(OrderProduct("order", 10, 1), OrderProduct("product", 1, -1)).asJava,
-    )
+    val expectedEnrichedOrder = Map(
+      "orderId" -> orderId,
+      "product" -> Map(
+        "id" -> productId,
+        "name" -> "Foo product"
+      ).asJava
+    ).asJava
+    enrichedOrders.validValue.successes.loneElement shouldEqual expectedEnrichedOrder
   }
 
 }
@@ -91,14 +115,28 @@ object TableJoinTest {
 
   // TODO: split into separate classes and pass two streams to separate source nodes
   // productId is dedicated only for order events
+  // name is dedicated only for product events
   // It has to be a POJO in order to be accepted by Table API operators
-  case class OrderProduct(
+  class OrderOrProduct(
       @BeanProperty var `type`: String,
       @BeanProperty var id: Int,
+      @BeanProperty var name: String,
       @BeanProperty var productId: Int
   ) {
 
-    def this() = this(null, -1, -1)
+    def this() = this(null, -1, null, -1)
+
+  }
+
+  object OrderOrProduct {
+
+    def createOrder(id: Int, productId: Int): OrderOrProduct = {
+      new OrderOrProduct("order", id, null, productId)
+    }
+
+    def createProduct(id: Int, name: String): OrderOrProduct = {
+      new OrderOrProduct("product", id, name, -1)
+    }
 
   }
 
 }
diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala
index 3cdf5cc2da6..99c1969eb5f 100644
--- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala
+++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala
@@ -231,11 +231,15 @@ object TableJoinComponent extends CustomStreamTransformer with JoinDynamicCompon
   }
 
   private def extractMainBranchId(branchTypeByBranchId: Map[String, BranchType]) = {
-    branchTypeByBranchId.find(_._2 == BranchType.MAIN).map(_._1)
+    branchTypeByBranchId.collectFirst { case (branchId, BranchType.MAIN) =>
+      branchId
+    }
   }
 
   private def extractJoinedBranchId(branchTypeByBranchId: Map[String, BranchType]) = {
-    branchTypeByBranchId.find(_._2 == BranchType.JOINED).map(_._1)
+    branchTypeByBranchId.collectFirst { case (branchId, BranchType.JOINED) =>
+      branchId
+    }
   }
 
 }
diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala 
b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala index 14312b2ad44..817b2001709 100644 --- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala +++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala @@ -34,7 +34,6 @@ object RowConversions { private def scalaMapToRow(map: Map[String, Any]): Row = { val row = Row.withNames() - // TODO: add type alignment, e.g. bigint is represented as long by flink tables map.foreach { case (name, value) => row.setField(name, value) } From 579d36089a55088b6158df915c5fde191018f6ea Mon Sep 17 00:00:00 2001 From: Arek Burdach Date: Wed, 24 Jul 2024 11:18:49 +0200 Subject: [PATCH 3/3] review fixes - join type --- .../flink/table/join/TableJoinTest.scala | 185 ++++++++++++------ .../engine/flink/table/join/JoinType.java | 5 + .../flink/table/join/TableJoinComponent.scala | 61 ++++-- 3 files changed, 172 insertions(+), 79 deletions(-) create mode 100644 engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala index fa028899dd3..c5cae20f329 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala @@ -38,77 +38,140 @@ class TableJoinTest .withExtraComponents(additionalComponents) .build() - private val MainBranchId = "main" - - private val JoinedBranchId = "joined" - - private val JoinNodeId = "joined-node-id" - - test("should be able to join") { - val scenario = ScenarioBuilder - .streaming("sample-join-last") - .sources( - GraphBuilder - .source("orders-source", TestScenarioRunner.testDataSource) - .filter("orders-filter", "#input.type == 'order'".spel) - .branchEnd(MainBranchId, JoinNodeId), - GraphBuilder - .source("products-source", TestScenarioRunner.testDataSource) - .filter("product-filter", "#input.type == 'product'".spel) - .branchEnd(JoinedBranchId, JoinNodeId), - GraphBuilder - .join( - JoinNodeId, - "join", - Some("product"), - List( - MainBranchId -> List( - "branchType" -> s"T(${classOf[BranchType].getName}).MAIN".spel, - "key" -> s"#input.productId.toString".spel - ), - JoinedBranchId -> List( - "branchType" -> s"T(${classOf[BranchType].getName}).JOINED".spel, - "key" -> s"#input.id.toString".spel - ) - ), - "output" -> "#input".spel, - ) - .emptySink( - "end", - TestScenarioRunner.testResultSink, - "value" -> - """{ - | orderId: #input.id, - | product: { - | id: #product.id, - | name: #product.name - | } - |}""".stripMargin.spel - ) - ) - - val productId = 1 - val orderId = 10 + private val mainBranchId = "main" + private val joinedBranchId = "joined" + + private val joinNodeId = "joined-node-id" + + private val someProduct = OrderOrProduct.createProduct(1, "Foo product") + private val anotherProduct = OrderOrProduct.createProduct(2, "Bar product") + private val delayedProduct = OrderOrProduct.createProduct(3, "Delayed product") + + private val orderReferringToExistingProduct = OrderOrProduct.createOrder(10, someProduct.id) + + private val nonExistingProductId = 100 + private val 
orderReferringToNonExistingProduct = OrderOrProduct.createOrder(20, nonExistingProductId) + + private val orderReferringToDelayedProduct = OrderOrProduct.createOrder(30, delayedProduct.id) + + test("should inner join stream") { + val enrichedOrders = runner.runWithData[OrderOrProduct, java.util.Map[String, AnyRef]]( + prepareJoiningScenario(JoinType.INNER), + List( + someProduct, + anotherProduct, + orderReferringToExistingProduct, + orderReferringToNonExistingProduct, + orderReferringToDelayedProduct, + delayedProduct + ), + Boundedness.BOUNDED, + Some(RuntimeExecutionMode.BATCH) + ) + + enrichedOrders.validValue.errors shouldBe empty + enrichedOrders.validValue.successes shouldEqual List( + Map( + "orderId" -> orderReferringToExistingProduct.id, + "product" -> Map( + "id" -> someProduct.id, + "name" -> someProduct.name + ).asJava + ).asJava, + Map( + "orderId" -> orderReferringToDelayedProduct.id, + "product" -> Map( + "id" -> delayedProduct.id, + "name" -> delayedProduct.name + ).asJava + ).asJava + ) + } + + test("should outer join stream") { val enrichedOrders = runner.runWithData[OrderOrProduct, java.util.Map[String, AnyRef]]( - scenario, + prepareJoiningScenario(JoinType.OUTER), List( - OrderOrProduct.createProduct(productId, "Foo product"), - OrderOrProduct.createOrder(orderId, productId), + someProduct, + anotherProduct, + orderReferringToExistingProduct, + orderReferringToNonExistingProduct, + orderReferringToDelayedProduct, + delayedProduct ), Boundedness.BOUNDED, Some(RuntimeExecutionMode.BATCH) ) - val expectedEnrichedOrder = Map( - "orderId" -> orderId, - "product" -> Map( - "id" -> productId, - "name" -> "Foo product" + enrichedOrders.validValue.errors shouldBe empty + enrichedOrders.validValue.successes shouldEqual List( + Map( + "orderId" -> orderReferringToExistingProduct.id, + "product" -> Map( + "id" -> someProduct.id, + "name" -> someProduct.name + ).asJava + ).asJava, + Map( + "orderId" -> orderReferringToNonExistingProduct.id, + "product" -> Map( + "id" -> null, + "name" -> null + ).asJava + ).asJava, + Map( + "orderId" -> orderReferringToDelayedProduct.id, + "product" -> Map( + "id" -> delayedProduct.id, + "name" -> delayedProduct.name + ).asJava ).asJava - ).asJava - enrichedOrders.validValue.successes.loneElement shouldEqual expectedEnrichedOrder + ) } + private def prepareJoiningScenario(joinType: JoinType) = ScenarioBuilder + .streaming(classOf[TableJoinTest].getSimpleName) + .sources( + GraphBuilder + .source("orders-source", TestScenarioRunner.testDataSource) + .filter("orders-filter", "#input.type == 'order'".spel) + .branchEnd(mainBranchId, joinNodeId), + GraphBuilder + .source("products-source", TestScenarioRunner.testDataSource) + .filter("product-filter", "#input.type == 'product'".spel) + .branchEnd(joinedBranchId, joinNodeId), + GraphBuilder + .join( + joinNodeId, + "join", + Some("product"), + List( + mainBranchId -> List( + "Branch Type" -> s"T(${classOf[BranchType].getName}).${BranchType.MAIN}".spel, + "Key" -> s"#input.productId.toString".spel + ), + joinedBranchId -> List( + "Branch Type" -> s"T(${classOf[BranchType].getName}).${BranchType.JOINED}".spel, + "Key" -> s"#input.id.toString".spel + ) + ), + "Join Type" -> s"T(${classOf[JoinType].getName}).$joinType".spel, + "Output" -> "#input".spel, + ) + .emptySink( + "end", + TestScenarioRunner.testResultSink, + "value" -> + """{ + | orderId: #input.id, + | product: { + | id: #product?.id, + | name: #product?.name + | } + |}""".stripMargin.spel + ) + ) + } object TableJoinTest { diff --git 
a/engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java b/engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java new file mode 100644 index 00000000000..0530d74624f --- /dev/null +++ b/engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java @@ -0,0 +1,5 @@ +package pl.touk.nussknacker.engine.flink.table.join; + +public enum JoinType { + INNER, OUTER +} diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala index 99c1969eb5f..b7d6c759f4d 100644 --- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala +++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala @@ -11,6 +11,7 @@ import pl.touk.nussknacker.engine.api._ import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.CustomNodeError import pl.touk.nussknacker.engine.api.context.transformation.{ DefinedEagerBranchParameter, + DefinedEagerParameter, DefinedSingleParameter, JoinDynamicComponent, NodeDependencyValue @@ -18,6 +19,7 @@ import pl.touk.nussknacker.engine.api.context.transformation.{ import pl.touk.nussknacker.engine.api.context.{OutputVar, ValidationContext} import pl.touk.nussknacker.engine.api.definition._ import pl.touk.nussknacker.engine.api.parameter.ParameterName +import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass} import pl.touk.nussknacker.engine.flink.api.process.{ AbstractLazyParameterInterpreterFunction, FlinkCustomJoinTransformation, @@ -29,28 +31,36 @@ import pl.touk.nussknacker.engine.flink.table.utils.RowConversions.{TypeInformat import pl.touk.nussknacker.engine.flink.util.transformer.join.BranchType import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap -object TableJoinComponent extends CustomStreamTransformer with JoinDynamicComponent[FlinkCustomJoinTransformation] { +object TableJoinComponent + extends CustomStreamTransformer + with JoinDynamicComponent[FlinkCustomJoinTransformation] + with WithExplicitTypesToExtract { private val contextInternalColumnName = "context" private val mainKeyInternalColumnName = "mainKey" private val joinedKeyInternalColumnName = "joinedKey" private val outputInternalColumnName = "output" - val BranchTypeParamName: ParameterName = ParameterName("branchType") + val BranchTypeParamName: ParameterName = ParameterName("Branch Type") private val BranchTypeParamDeclaration : ParameterCreatorWithNoDependency with ParameterExtractor[Map[String, BranchType]] = ParameterDeclaration.branchMandatory[BranchType](BranchTypeParamName).withCreator() - val KeyParamName: ParameterName = ParameterName("key") + val JoinTypeParamName: ParameterName = ParameterName("Join Type") + + private val JoinTypeParamDeclaration: ParameterCreatorWithNoDependency with ParameterExtractor[JoinType] = + ParameterDeclaration.mandatory[JoinType](JoinTypeParamName).withCreator() + + val KeyParamName: ParameterName = ParameterName("Key") private val KeyParamDeclaration : ParameterCreatorWithNoDependency with ParameterExtractor[Map[String, LazyParameter[String]]] = ParameterDeclaration.branchLazyMandatory[String](KeyParamName).withCreator() - val OutputParamName: ParameterName = ParameterName("output") + val OutputParamName: ParameterName 
= ParameterName("Output") - override type State = Nothing + override type State = JoinType override def nodeDependencies: List[NodeDependency] = List(OutputVariableNameDependency) @@ -86,6 +96,7 @@ object TableJoinComponent extends CustomStreamTransformer with JoinDynamicCompon .mapValuesNow(AdditionalVariableProvidedInRuntime(_)) NextParameters( List( + JoinTypeParamDeclaration.createParameter(), ParameterDeclaration .lazyMandatory[AnyRef](OutputParamName) .withCreator(_.copy(additionalVariables = joinedVariables)) @@ -93,27 +104,30 @@ object TableJoinComponent extends CustomStreamTransformer with JoinDynamicCompon ), error ) - case TransformationStep( ( `BranchTypeParamName`, DefinedEagerBranchParameter(branchTypeByBranchId: Map[String, BranchType] @unchecked, _) ) :: - (`KeyParamName`, _) :: (`OutputParamName`, outputParameter: DefinedSingleParameter) :: Nil, + (`KeyParamName`, _) :: + (`JoinTypeParamName`, DefinedEagerParameter(joinType: JoinType, _)) :: + (`OutputParamName`, outputParameter: DefinedSingleParameter) :: + Nil, _ ) => val outName = OutputVariableNameDependency.extract(dependencies) val mainContext = extractMainBranchId(branchTypeByBranchId).map(contexts).getOrElse(ValidationContext()) - FinalResults.forValidation(mainContext)( - _.withVariable(OutputVar.customNode(outName), outputParameter.returnType) - ) - + FinalResults + .forValidation(mainContext)( + _.withVariable(OutputVar.customNode(outName), outputParameter.returnType) + ) + .copy(state = Some(joinType)) } override def implementation( params: Params, dependencies: List[NodeDependencyValue], - finalState: Option[Nothing] + joinTypeState: Option[JoinType] ): FlinkCustomJoinTransformation = new FlinkCustomJoinTransformation { override def transform( @@ -121,10 +135,14 @@ object TableJoinComponent extends CustomStreamTransformer with JoinDynamicCompon flinkNodeContext: FlinkCustomNodeContext ): DataStream[ValueWithContext[AnyRef]] = { val branchTypeByBranchId: Map[String, BranchType] = BranchTypeParamDeclaration.extractValueUnsafe(params) - val mainBranchId = extractMainBranchId(branchTypeByBranchId).get - val joinedBranchId = extractJoinedBranchId(branchTypeByBranchId).get - val mainStream = inputs(mainBranchId) - val joinedStream = inputs(joinedBranchId) + val mainBranchId = + extractMainBranchId(branchTypeByBranchId).getOrElse(throw new IllegalStateException("Not defined main branch")) + val joinedBranchId = extractJoinedBranchId(branchTypeByBranchId).getOrElse( + throw new IllegalStateException("Not defined joined branch") + ) + val mainStream = inputs(mainBranchId) + val joinedStream = inputs(joinedBranchId) + val joinType = joinTypeState.getOrElse(throw new IllegalStateException("Not defined join type")) val env = mainStream.getExecutionEnvironment val tableEnv = StreamTableEnvironment.create(env) @@ -154,8 +172,11 @@ object TableJoinComponent extends CustomStreamTransformer with JoinDynamicCompon ) ) - val resultTable = - mainTable.join(joinedTable, $(joinedKeyInternalColumnName).isEqual($(mainKeyInternalColumnName))) + val joinPredicate = $(joinedKeyInternalColumnName).isEqual($(mainKeyInternalColumnName)) + val resultTable = joinType match { + case JoinType.INNER => mainTable.join(joinedTable, joinPredicate) + case JoinType.OUTER => mainTable.leftOuterJoin(joinedTable, joinPredicate) + } tableEnv .toDataStream(resultTable) @@ -242,4 +263,8 @@ object TableJoinComponent extends CustomStreamTransformer with JoinDynamicCompon } } + override def typesToExtract: List[TypedClass] = List( + 
Typed.typedClass[JoinType], + ) + }
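
Reviewer note: the heart of TableJoinComponent is a DataStream -> Table -> DataStream round trip,
with the join itself expressed through the Table API. A minimal, self-contained sketch of that
pattern follows (illustrative only: the object name, stream contents and column names are made up
for the example and are not part of this patch):

import org.apache.flink.api.common.typeinfo.Types
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.table.api.Expressions.$
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment
import org.apache.flink.types.Row

object TableJoinSketch {

  def main(args: Array[String]): Unit = {
    val env      = StreamExecutionEnvironment.getExecutionEnvironment
    val tableEnv = StreamTableEnvironment.create(env)

    // "Main" branch: each record carries its evaluated join key, as in mainKeyInternalColumnName.
    val orders = env
      .fromElements(Row.of("1", "order-10"), Row.of("100", "order-20"))
      .returns(Types.ROW_NAMED(Array("mainKey", "orderId"), Types.STRING, Types.STRING))

    // "Joined" branch: the pre-evaluated output expression travels in its own column,
    // as in outputInternalColumnName.
    val products = env
      .fromElements(Row.of("1", "Foo product"))
      .returns(Types.ROW_NAMED(Array("joinedKey", "output"), Types.STRING, Types.STRING))

    val mainTable   = tableEnv.fromDataStream(orders)
    val joinedTable = tableEnv.fromDataStream(products)

    // JoinType.INNER maps to join, JoinType.OUTER to leftOuterJoin, so with OUTER an
    // unmatched main-branch row survives with nulls on the joined side (hence the
    // null-safe #product?.id access in the outer-join test above).
    val resultTable = mainTable.leftOuterJoin(joinedTable, $("joinedKey").isEqual($("mainKey")))

    tableEnv.toDataStream(resultTable).print()
    env.execute("table-join-sketch")
  }

}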
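
Reviewer note: RowConversions additionally has to carry a whole Context through the table runtime
and back: field 0 of the positional Row holds the context id, field 1 a named Row with the local
variables, and an optional field 2 the parent context, recursively. A quick round-trip
illustration of that encoding (a sketch, assuming plain variable values for which Context
equality holds; not part of the patch):

import pl.touk.nussknacker.engine.api.Context
import pl.touk.nussknacker.engine.flink.table.utils.RowConversions

object ContextRoundTripSketch {

  def main(args: Array[String]): Unit = {
    val context = Context("ctx-1", Map[String, Any]("input" -> "order-10"), None)

    // contextToRow packs id + variables (+ parent, if present) into a Row;
    // rowToContext reverses the encoding on the way out of the join.
    val row          = RowConversions.contextToRow(context)
    val roundTripped = RowConversions.rowToContext(row)

    assert(roundTripped == context, s"expected $context, got $roundTripped")
  }

}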