diff --git a/docs/Changelog.md b/docs/Changelog.md
index 146ebb6a49a..0d5d5f75551 100644
--- a/docs/Changelog.md
+++ b/docs/Changelog.md
@@ -22,6 +22,7 @@ in table name
 * [#6353](https://github.com/TouK/nussknacker/pull/6353) Performance improvement: simple types such as numbers, boolean, string, date types and arrays are serialized/deserialized more optimal in aggregates
+* [#6353](https://github.com/TouK/nussknacker/pull/6353) Added `join` component, available in Batch processing mode
 
 1.16.1 (16 July 2024)
 -------------------------
diff --git a/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala b/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala
index 4bbc46b4e62..5d76b511b13 100644
--- a/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala
+++ b/engine/flink/components-api/src/main/scala/pl/touk/nussknacker/engine/flink/api/process/FlinkCustomNodeContext.scala
@@ -58,10 +58,23 @@ case class FlinkCustomNodeContext(
     lazy val forUnknown: TypeInformation[ValueWithContext[AnyRef]] = forType[AnyRef](Unknown)
   }
 
+  def branchValidationContext(branchId: String): ValidationContext = asJoinContext.getOrElse(
+    branchId,
+    throw new IllegalArgumentException(s"No validation context for branchId [$branchId] is defined")
+  )
+
   private def asOneOutputContext: ValidationContext =
-    validationContext.left.getOrElse(throw new IllegalArgumentException("This node is a join, use asJoinContext"))
+    validationContext.left.getOrElse(
+      throw new IllegalArgumentException(
+        "This node is a join. asJoinContext should be used to extract the validation context"
+      )
+    )
 
   private def asJoinContext: Map[String, ValidationContext] =
-    validationContext.getOrElse(throw new IllegalArgumentException())
+    validationContext.getOrElse(
+      throw new IllegalArgumentException(
+        "This node is a single input node. asOneOutputContext should be used to extract the validation context"
+      )
+    )
 }
diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java b/engine/flink/components-utils/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java
similarity index 100%
rename from engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java
rename to engine/flink/components-utils/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/join/BranchType.java
diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala
new file mode 100644
index 00000000000..c5cae20f329
--- /dev/null
+++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinTest.scala
@@ -0,0 +1,206 @@
+package pl.touk.nussknacker.engine.flink.table.join
+
+import com.typesafe.config.ConfigFactory
+import org.apache.flink.api.common.RuntimeExecutionMode
+import org.apache.flink.api.connector.source.Boundedness
+import org.scalatest.{Inside, LoneElement}
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+import pl.touk.nussknacker.engine.api.component.ComponentDefinition
+import pl.touk.nussknacker.engine.build.{GraphBuilder, ScenarioBuilder}
+import pl.touk.nussknacker.engine.flink.table.FlinkTableComponentProvider
+import pl.touk.nussknacker.engine.flink.table.join.TableJoinTest.OrderOrProduct
+import pl.touk.nussknacker.engine.flink.test.FlinkSpec
+import pl.touk.nussknacker.engine.flink.util.transformer.join.BranchType
+import pl.touk.nussknacker.engine.util.test.TestScenarioRunner
+import pl.touk.nussknacker.test.ValidatedValuesDetailedMessage
+
+import scala.beans.BeanProperty
+
+class TableJoinTest
+    extends AnyFunSuite
+    with FlinkSpec
+    with Matchers
+    with Inside
+    with ValidatedValuesDetailedMessage
+    with LoneElement {
+
+  import pl.touk.nussknacker.engine.flink.util.test.FlinkTestScenarioRunner._
+  import pl.touk.nussknacker.engine.spel.SpelExtension._
+
+  import scala.jdk.CollectionConverters._
+
+  private lazy val additionalComponents: List[ComponentDefinition] =
+    FlinkTableComponentProvider.configIndependentComponents ::: Nil
+
+  private lazy val runner = TestScenarioRunner
+    .flinkBased(ConfigFactory.empty(), flinkMiniCluster)
+    .withExtraComponents(additionalComponents)
+    .build()
+
+  private val mainBranchId = "main"
+  private val joinedBranchId = "joined"
+
+  private val joinNodeId = "joined-node-id"
+
+  private val someProduct = OrderOrProduct.createProduct(1, "Foo product")
+  private val anotherProduct = OrderOrProduct.createProduct(2, "Bar product")
+  private val delayedProduct = OrderOrProduct.createProduct(3, "Delayed product")
+
+  private val orderReferringToExistingProduct = OrderOrProduct.createOrder(10, someProduct.id)
+
+  private val nonExistingProductId = 100
+  private val orderReferringToNonExistingProduct = OrderOrProduct.createOrder(20, nonExistingProductId)
+
+  private val orderReferringToDelayedProduct = OrderOrProduct.createOrder(30, delayedProduct.id)
+
+  test("should inner join stream") {
+    val enrichedOrders = runner.runWithData[OrderOrProduct, java.util.Map[String, AnyRef]](
+      prepareJoiningScenario(JoinType.INNER),
+      List(
+        someProduct,
+        anotherProduct,
+        orderReferringToExistingProduct,
+        orderReferringToNonExistingProduct,
+        orderReferringToDelayedProduct,
+        delayedProduct
+      ),
+      Boundedness.BOUNDED,
+      Some(RuntimeExecutionMode.BATCH)
+    )
+
+    enrichedOrders.validValue.errors shouldBe empty
+    enrichedOrders.validValue.successes shouldEqual List(
+      Map(
+        "orderId" -> orderReferringToExistingProduct.id,
+        "product" -> Map(
+          "id" -> someProduct.id,
+          "name" -> someProduct.name
+        ).asJava
+      ).asJava,
+      Map(
+        "orderId" -> orderReferringToDelayedProduct.id,
+        "product" -> Map(
+          "id" -> delayedProduct.id,
+          "name" -> delayedProduct.name
+        ).asJava
+      ).asJava
+    )
+  }
+
+  test("should outer join stream") {
+    val enrichedOrders = runner.runWithData[OrderOrProduct, java.util.Map[String, AnyRef]](
+      prepareJoiningScenario(JoinType.OUTER),
+      List(
+        someProduct,
+        anotherProduct,
+        orderReferringToExistingProduct,
+        orderReferringToNonExistingProduct,
+        orderReferringToDelayedProduct,
+        delayedProduct
+      ),
+      Boundedness.BOUNDED,
+      Some(RuntimeExecutionMode.BATCH)
+    )
+
+    enrichedOrders.validValue.errors shouldBe empty
+    enrichedOrders.validValue.successes shouldEqual List(
+      Map(
+        "orderId" -> orderReferringToExistingProduct.id,
+        "product" -> Map(
+          "id" -> someProduct.id,
+          "name" -> someProduct.name
+        ).asJava
+      ).asJava,
+      Map(
+        "orderId" -> orderReferringToNonExistingProduct.id,
+        "product" -> Map(
+          "id" -> null,
+          "name" -> null
+        ).asJava
+      ).asJava,
+      Map(
+        "orderId" -> orderReferringToDelayedProduct.id,
+        "product" -> Map(
+          "id" -> delayedProduct.id,
+          "name" -> delayedProduct.name
+        ).asJava
+      ).asJava
+    )
+  }
+
+  private def prepareJoiningScenario(joinType: JoinType) = ScenarioBuilder
+    .streaming(classOf[TableJoinTest].getSimpleName)
+    .sources(
+      GraphBuilder
+        .source("orders-source", TestScenarioRunner.testDataSource)
+        .filter("orders-filter", "#input.type == 'order'".spel)
+        .branchEnd(mainBranchId, joinNodeId),
+      GraphBuilder
+        .source("products-source", TestScenarioRunner.testDataSource)
+        .filter("product-filter", "#input.type == 'product'".spel)
+        .branchEnd(joinedBranchId, joinNodeId),
+      GraphBuilder
+        .join(
+          joinNodeId,
+          "join",
+          Some("product"),
+          List(
+            mainBranchId -> List(
+              "Branch Type" -> s"T(${classOf[BranchType].getName}).${BranchType.MAIN}".spel,
+              "Key" -> s"#input.productId.toString".spel
+            ),
+            joinedBranchId -> List(
+              "Branch Type" -> s"T(${classOf[BranchType].getName}).${BranchType.JOINED}".spel,
+              "Key" -> s"#input.id.toString".spel
+            )
+          ),
+          "Join Type" -> s"T(${classOf[JoinType].getName}).$joinType".spel,
+          "Output" -> "#input".spel,
+        )
+        .emptySink(
+          "end",
+          TestScenarioRunner.testResultSink,
+          "value" ->
+            """{
+              |  orderId: #input.id,
+              |  product: {
+              |    id: #product?.id,
+              |    name: #product?.name
+              |  }
+              |}""".stripMargin.spel
+        )
+    )
+
+}
+
+object TableJoinTest {
+
+  // TODO: split into separate classes and pass two streams to separate source nodes
+  // productId is dedicated only for order events
+  // name is dedicated only for product events
+  // It has to be a POJO in order to be accepted by Table API operators
+  class OrderOrProduct(
+      @BeanProperty var `type`: String,
+      @BeanProperty var id: Int,
+      @BeanProperty var name: String,
+      @BeanProperty var productId: Int
+  ) {
+
+    def this() = this(null, -1, null, -1)
+
+  }
+
+  object OrderOrProduct {
+
+    def createOrder(id: Int, productId: Int): OrderOrProduct = {
+      new OrderOrProduct("order", id, null, productId)
+    }
+
+    def createProduct(id: Int, name: String): OrderOrProduct = {
+      new OrderOrProduct("product", id, name, -1)
+    }
+
+  }
+
+}
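Editor's note on the tests above: they pin down the component's semantics. An INNER join drops orders whose key matches no product, while an OUTER join keeps them and leaves the joined columns `null`. This corresponds one-to-one to the two Flink Table API operators that `TableJoinComponent` (introduced below) selects between. A self-contained sketch of that mapping, using illustrative table and column names rather than the component's internal ones:

```scala
import org.apache.flink.table.api.{EnvironmentSettings, TableEnvironment}
import org.apache.flink.table.api.Expressions.{$, row}

// Two single-column tables standing in for the keyed main/joined branches.
val tableEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode())
val mainTable = tableEnv.fromValues(row(1), row(100)).as("mainKey")   // order keys
val joinedTable = tableEnv.fromValues(row(1), row(2)).as("joinedKey") // product keys

val predicate = $("joinedKey").isEqual($("mainKey"))
mainTable.join(joinedTable, predicate)          // JoinType.INNER: key 100 is dropped
mainTable.leftOuterJoin(joinedTable, predicate) // JoinType.OUTER: key 100 kept, joinedKey is null
```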
diff --git a/engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java b/engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java
new file mode 100644
index 00000000000..0530d74624f
--- /dev/null
+++ b/engine/flink/components/table/src/main/java/pl/touk/nussknacker/engine/flink/table/join/JoinType.java
@@ -0,0 +1,5 @@
+package pl.touk.nussknacker.engine.flink.table.join;
+
+public enum JoinType {
+    INNER, OUTER
+}
diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala
index 3504acee442..02e78e586e7 100644
--- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala
+++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/FlinkTableComponentProvider.scala
@@ -10,6 +10,7 @@ import pl.touk.nussknacker.engine.flink.table.aggregate.TableAggregationFactory
 import pl.touk.nussknacker.engine.flink.table.extractor.TableExtractor.extractTablesFromFlinkRuntime
 import pl.touk.nussknacker.engine.flink.table.extractor.SqlStatementReader
 import pl.touk.nussknacker.engine.flink.table.extractor.SqlStatementReader.SqlStatement
+import pl.touk.nussknacker.engine.flink.table.join.TableJoinComponent
 import pl.touk.nussknacker.engine.flink.table.sink.TableSinkFactory
 import pl.touk.nussknacker.engine.flink.table.source.TableSourceFactory
 import pl.touk.nussknacker.engine.util.ResourceLoader
@@ -88,6 +89,10 @@ object FlinkTableComponentProvider {
     ComponentDefinition(
       "aggregate",
       new TableAggregationFactory()
+    ),
+    ComponentDefinition(
+      "join",
+      TableJoinComponent
     )
   )
 
diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala
new file mode 100644
index 00000000000..b7d6c759f4d
--- /dev/null
+++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/join/TableJoinComponent.scala
@@ -0,0 +1,270 @@
+package pl.touk.nussknacker.engine.flink.table.join
+
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.Types
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.table.api.Expressions.$
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+import pl.touk.nussknacker.engine.api._
+import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.CustomNodeError
+import pl.touk.nussknacker.engine.api.context.transformation.{
+  DefinedEagerBranchParameter,
+  DefinedEagerParameter,
+  DefinedSingleParameter,
+  JoinDynamicComponent,
+  NodeDependencyValue
+}
+import pl.touk.nussknacker.engine.api.context.{OutputVar, ValidationContext}
+import pl.touk.nussknacker.engine.api.definition._
+import pl.touk.nussknacker.engine.api.parameter.ParameterName
+import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass}
+import pl.touk.nussknacker.engine.flink.api.process.{
+  AbstractLazyParameterInterpreterFunction,
+  FlinkCustomJoinTransformation,
+  FlinkCustomNodeContext,
+  FlinkLazyParameterFunctionHelper
+}
+import pl.touk.nussknacker.engine.flink.table.utils.RowConversions
+import pl.touk.nussknacker.engine.flink.table.utils.RowConversions.{TypeInformationDetectionExtension, rowToContext}
+import pl.touk.nussknacker.engine.flink.util.transformer.join.BranchType
+import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap
+
+object TableJoinComponent
+    extends CustomStreamTransformer
+    with JoinDynamicComponent[FlinkCustomJoinTransformation]
+    with WithExplicitTypesToExtract {
+
+  private val contextInternalColumnName = "context"
+  private val mainKeyInternalColumnName = "mainKey"
+  private val joinedKeyInternalColumnName = "joinedKey"
+  private val outputInternalColumnName = "output"
+
+  val BranchTypeParamName: ParameterName = ParameterName("Branch Type")
+
+  private val BranchTypeParamDeclaration
+      : ParameterCreatorWithNoDependency with ParameterExtractor[Map[String, BranchType]] =
+    ParameterDeclaration.branchMandatory[BranchType](BranchTypeParamName).withCreator()
+
+  val JoinTypeParamName: ParameterName = ParameterName("Join Type")
+
+  private val JoinTypeParamDeclaration: ParameterCreatorWithNoDependency with ParameterExtractor[JoinType] =
+    ParameterDeclaration.mandatory[JoinType](JoinTypeParamName).withCreator()
+
+  val KeyParamName: ParameterName = ParameterName("Key")
+
+  private val KeyParamDeclaration
+      : ParameterCreatorWithNoDependency with ParameterExtractor[Map[String, LazyParameter[String]]] =
+    ParameterDeclaration.branchLazyMandatory[String](KeyParamName).withCreator()
+
+  val OutputParamName: ParameterName = ParameterName("Output")
+
+  override type State = JoinType
+
+  override def nodeDependencies: List[NodeDependency] = List(OutputVariableNameDependency)
+
+  override def contextTransformation(contexts: Map[String, ValidationContext], dependencies: List[NodeDependencyValue])(
+      implicit nodeId: NodeId
+  ): ContextTransformationDefinition = {
+    case TransformationStep(Nil, _) =>
+      NextParameters(
+        List(BranchTypeParamDeclaration, KeyParamDeclaration)
+          .map(_.createParameter())
+      )
+    case TransformationStep(
+          (
+            `BranchTypeParamName`,
+            DefinedEagerBranchParameter(branchTypeByBranchId: Map[String, BranchType] @unchecked, _)
+          ) :: (`KeyParamName`, _) :: Nil,
+          _
+        ) =>
+      val error =
+        if (branchTypeByBranchId.values.toList.sorted != BranchType.values().toList)
+          List(
+            CustomNodeError(
+              s"There has to be exactly one MAIN branch and one JOINED branch, got: ${branchTypeByBranchId.values.mkString(", ")}",
+              Some(BranchTypeParamName)
+            )
+          )
+        else
+          Nil
+      val joinedVariables = extractJoinedBranchId(branchTypeByBranchId)
+        .map(contexts)
+        .getOrElse(ValidationContext())
+        .localVariables
+        .mapValuesNow(AdditionalVariableProvidedInRuntime(_))
+      NextParameters(
+        List(
+          JoinTypeParamDeclaration.createParameter(),
+          ParameterDeclaration
+            .lazyMandatory[AnyRef](OutputParamName)
+            .withCreator(_.copy(additionalVariables = joinedVariables))
+            .createParameter()
+        ),
+        error
+      )
+    case TransformationStep(
+          (
+            `BranchTypeParamName`,
+            DefinedEagerBranchParameter(branchTypeByBranchId: Map[String, BranchType] @unchecked, _)
+          ) ::
+          (`KeyParamName`, _) ::
+          (`JoinTypeParamName`, DefinedEagerParameter(joinType: JoinType, _)) ::
+          (`OutputParamName`, outputParameter: DefinedSingleParameter) ::
+          Nil,
+          _
+        ) =>
+      val outName = OutputVariableNameDependency.extract(dependencies)
+      val mainContext = extractMainBranchId(branchTypeByBranchId).map(contexts).getOrElse(ValidationContext())
+      FinalResults
+        .forValidation(mainContext)(
+          _.withVariable(OutputVar.customNode(outName), outputParameter.returnType)
+        )
+        .copy(state = Some(joinType))
+  }
+
+  override def implementation(
+      params: Params,
+      dependencies: List[NodeDependencyValue],
+      joinTypeState: Option[JoinType]
+  ): FlinkCustomJoinTransformation = new FlinkCustomJoinTransformation {
+
+    override def transform(
+        inputs: Map[String, DataStream[Context]],
+        flinkNodeContext: FlinkCustomNodeContext
+    ): DataStream[ValueWithContext[AnyRef]] = {
+      val branchTypeByBranchId: Map[String, BranchType] = BranchTypeParamDeclaration.extractValueUnsafe(params)
+      val mainBranchId = extractMainBranchId(branchTypeByBranchId)
+        .getOrElse(throw new IllegalStateException("Main branch is not defined"))
+      val joinedBranchId = extractJoinedBranchId(branchTypeByBranchId)
+        .getOrElse(throw new IllegalStateException("Joined branch is not defined"))
+      val mainStream = inputs(mainBranchId)
+      val joinedStream = inputs(joinedBranchId)
+      val joinType = joinTypeState.getOrElse(throw new IllegalStateException("Join type is not defined"))
+
+      val env = mainStream.getExecutionEnvironment
+      val tableEnv = StreamTableEnvironment.create(env)
+
+      val mainTable = tableEnv.fromDataStream(
+        mainStream.flatMap(
+          new MainBranchToRowFunction(
+            KeyParamDeclaration.extractValueUnsafe(params)(mainBranchId),
+            flinkNodeContext.lazyParameterHelper
+          ),
+          mainBranchTypeInfo(flinkNodeContext, mainBranchId)
+        )
+      )
+
+      val outputLazyParam = params.extractUnsafe[LazyParameter[AnyRef]](OutputParamName)
+      val outputTypeInfo =
+        flinkNodeContext.valueWithContextInfo.forBranch[AnyRef](mainBranchId, outputLazyParam.returnType)
+
+      val joinedTable = tableEnv.fromDataStream(
+        joinedStream.flatMap(
+          new JoinedBranchToRowFunction(
+            KeyParamDeclaration.extractValueUnsafe(params)(joinedBranchId),
+            outputLazyParam,
+            flinkNodeContext.lazyParameterHelper
+          ),
+          joinedBranchTypeInfo(flinkNodeContext, outputLazyParam)
+        )
+      )
+
+      val joinPredicate = $(joinedKeyInternalColumnName).isEqual($(mainKeyInternalColumnName))
+      val resultTable = joinType match {
+        case JoinType.INNER => mainTable.join(joinedTable, joinPredicate)
+        case JoinType.OUTER => mainTable.leftOuterJoin(joinedTable, joinPredicate)
+      }
+
+      tableEnv
+        .toDataStream(resultTable)
+        .map(
+          (row: Row) =>
+            ValueWithContext[AnyRef](
+              row.getField(outputInternalColumnName),
+              rowToContext(row.getField(contextInternalColumnName).asInstanceOf[Row])
+            ),
+          outputTypeInfo
+        )
+    }
+
+  }
+
+  private class MainBranchToRowFunction(
+      mainKeyLazyParam: LazyParameter[String],
+      lazyParameterHelper: FlinkLazyParameterFunctionHelper
+  ) extends AbstractLazyParameterInterpreterFunction(lazyParameterHelper)
+      with FlatMapFunction[Context, Row] {
+
+    private lazy val evaluateKey = toEvaluateFunctionConverter.toEvaluateFunction(mainKeyLazyParam)
+
+    override def flatMap(context: Context, out: Collector[Row]): Unit = {
+      collectHandlingErrors(context, out) {
+        val row = Row.withNames()
+        row.setField(contextInternalColumnName, RowConversions.contextToRow(context))
+        row.setField(mainKeyInternalColumnName, evaluateKey(context))
+        row
+      }
+    }
+
+  }
+
+  private def mainBranchTypeInfo(flinkNodeContext: FlinkCustomNodeContext, mainBranchId: String) = {
+    Types.ROW_NAMED(
+      Array(contextInternalColumnName, mainKeyInternalColumnName),
+      flinkNodeContext.typeInformationDetection.contextRowTypeInfo(
+        flinkNodeContext.branchValidationContext(mainBranchId)
+      ),
+      Types.STRING
+    )
+  }
+
+  private class JoinedBranchToRowFunction(
+      joinedKeyLazyParam: LazyParameter[String],
+      outputLazyParam: LazyParameter[AnyRef],
+      lazyParameterHelper: FlinkLazyParameterFunctionHelper
+  ) extends AbstractLazyParameterInterpreterFunction(lazyParameterHelper)
+      with FlatMapFunction[Context, Row] {
+
+    private lazy val evaluateKey = toEvaluateFunctionConverter.toEvaluateFunction(joinedKeyLazyParam)
+
+    private lazy val evaluateOutput = toEvaluateFunctionConverter.toEvaluateFunction(outputLazyParam)
+
+    override def flatMap(context: Context, out: Collector[Row]): Unit = {
+      collectHandlingErrors(context, out) {
+        val row = Row.withNames()
+        row.setField(joinedKeyInternalColumnName, evaluateKey(context))
+        row.setField(outputInternalColumnName, evaluateOutput(context))
+        row
+      }
+    }
+
+  }
+
+  private def joinedBranchTypeInfo(flinkNodeContext: FlinkCustomNodeContext, outputLazyParam: LazyParameter[_]) = {
+    Types.ROW_NAMED(
+      Array(joinedKeyInternalColumnName, outputInternalColumnName),
+      Types.STRING,
+      flinkNodeContext.typeInformationDetection.forType(outputLazyParam.returnType)
+    )
+  }
+
+  private def extractMainBranchId(branchTypeByBranchId: Map[String, BranchType]) = {
+    branchTypeByBranchId.collectFirst { case (branchId, BranchType.MAIN) =>
+      branchId
+    }
+  }
+
+  private def extractJoinedBranchId(branchTypeByBranchId: Map[String, BranchType]) = {
+    branchTypeByBranchId.collectFirst { case (branchId, BranchType.JOINED) =>
+      branchId
+    }
+  }
+
+  override def typesToExtract: List[TypedClass] = List(
+    Typed.typedClass[JoinType],
+  )
+
+}
diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala
index 3299441789c..817b2001709 100644
--- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala
+++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/utils/RowConversions.scala
@@ -1,19 +1,29 @@
 package pl.touk.nussknacker.engine.flink.table.utils
 
-import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeinfo.{TypeInformation, Types}
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.types.Row
+import pl.touk.nussknacker.engine.api.Context
+import pl.touk.nussknacker.engine.api.context.ValidationContext
 import pl.touk.nussknacker.engine.api.typed.typing.TypedObjectTypingResult
 import pl.touk.nussknacker.engine.flink.api.typeinformation.TypeInformationDetection
+import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap
+
+import java.util
 
 object RowConversions {
 
   import scala.jdk.CollectionConverters._
 
   def rowToMap(row: Row): java.util.Map[String, Any] = {
+    val fields: Map[String, AnyRef] = rowToScalaMap(row)
+    new util.HashMap[String, Any](fields.asJava)
+  }
+
+  private def rowToScalaMap(row: Row): Map[String, AnyRef] = {
     val fieldNames = row.getFieldNames(true).asScala
     val fields = fieldNames.map(n => n -> row.getField(n)).toMap
-    new java.util.HashMap[String, Any](fields.asJava)
+    fields
   }
 
   def mapToRow(map: java.util.Map[String, Any], columnNames: Iterable[String]): Row = {
@@ -22,6 +32,31 @@ object RowConversions {
     row
   }
 
+  private def scalaMapToRow(map: Map[String, Any]): Row = {
+    val row = Row.withNames()
+    map.foreach { case (name, value) =>
+      row.setField(name, value)
+    }
+    row
+  }
+
+  def contextToRow(context: Context): Row = {
+    val row = Row.withPositions(context.parentContext.map(_ => 3).getOrElse(2))
+    val variablesRow = scalaMapToRow(context.variables)
+    row.setField(0, context.id)
+    row.setField(1, variablesRow)
+    context.parentContext.map(contextToRow).foreach(row.setField(2, _))
+    row
+  }
+
+  def rowToContext(row: Row): Context = {
+    Context(
+      row.getField(0).asInstanceOf[String],
+      rowToScalaMap(row.getField(1).asInstanceOf[Row]),
+      Option(row).filter(_.getArity >= 3).map(_.getField(2).asInstanceOf[Row]).map(rowToContext)
+    )
+  }
+
   implicit class TypeInformationDetectionExtension(typeInformationDetection: TypeInformationDetection) {
 
     def rowTypeInfoWithColumnsInGivenOrder(
@@ -37,6 +72,13 @@ object RowConversions {
       new RowTypeInfo(typeInfos.toArray[TypeInformation[_]], fieldNames.toArray)
     }
 
+    def contextRowTypeInfo(validationContext: ValidationContext): TypeInformation[_] = {
+      val (fieldNames, typeInfos) =
+        validationContext.localVariables.mapValuesNow(typeInformationDetection.forType).unzip
+      val variablesRow = new RowTypeInfo(typeInfos.toArray[TypeInformation[_]], fieldNames.toArray)
+      Types.ROW(Types.STRING :: variablesRow :: validationContext.parent.map(contextRowTypeInfo).toList: _*)
+    }
+
   }
 
 }
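Editor's note: the `contextToRow`/`rowToContext` pair added above is what lets the join tunnel a Nussknacker `Context` through the Table API as an internal column and restore it afterwards. A minimal round-trip sketch (variable names and values are illustrative, not taken from the PR):

```scala
import pl.touk.nussknacker.engine.api.Context
import pl.touk.nussknacker.engine.flink.table.utils.RowConversions

// A context with two variables and no parent context...
val ctx = Context("ctx-1", Map[String, Any]("orderId" -> 10, "productId" -> 1), None)

// ...is encoded as a positional Row of arity 2: (id, variables-as-named-Row);
// with a parent context present, the Row would have arity 3.
val row = RowConversions.contextToRow(ctx)

// Decoding restores an equal Context - the join's output mapper relies on this.
assert(RowConversions.rowToContext(row) == ctx)
```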