diff --git a/docs/Changelog.md b/docs/Changelog.md index 964d82c3cd9..8d47f6c0de5 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -29,6 +29,7 @@ * [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored * [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme * [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators +* [#7321](https://github.com/TouK/nussknacker/pull/7321) Added Median aggregator ## 1.18 diff --git a/docs/scenarios_authoring/AggregatesInTimeWindows.md b/docs/scenarios_authoring/AggregatesInTimeWindows.md index 322fb5b8e4a..f354d815d94 100644 --- a/docs/scenarios_authoring/AggregatesInTimeWindows.md +++ b/docs/scenarios_authoring/AggregatesInTimeWindows.md @@ -85,6 +85,7 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c * StddevSamp - computes sample standard deviation * VarPop - computes population variance * VarSamp - computes sample variance +* Median - computes median * ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result. If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator. 
diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala index 07588ed4335..c4243070a9d 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala @@ -5,7 +5,24 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown} import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE} -import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{ + AverageAggregator, + CountWhenAggregator, + FirstAggregator, + LastAggregator, + ListAggregator, + MapAggregator, + MaxAggregator, + MedianAggregator, + MinAggregator, + OptionAggregator, + PopulationStandardDeviationAggregator, + PopulationVarianceAggregator, + SampleStandardDeviationAggregator, + SampleVarianceAggregator, + SetAggregator, + SumAggregator +} import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap import java.lang.{Integer => JInt, Long => JLong} @@ -127,7 +144,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat MinAggregator, 
FirstAggregator, LastAggregator, - SumAggregator + SumAggregator, + MedianAggregator )) { agg => addElementsAndComputeResult(List(null), agg) shouldEqual null } @@ -148,6 +166,62 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat } } + test("should calculate correct results for median aggregator on integers") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(7, 8), agg) + result shouldBe a[Double] + result shouldEqual 7.5 + } + + test("should calculate correct results for median aggregator on integers on single value") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(7), agg) + result shouldBe a[Double] + result shouldEqual 7 + } + + test("should calculate correct results for median aggregator on BigInt") { + val agg = MedianAggregator + addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), agg) shouldEqual new java.math.BigDecimal("7.5") + } + + test("should calculate correct results for median aggregator on floats") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(7.0f, 8.0f), agg) + result shouldBe a[Double] + result shouldEqual 7.5 + } + + test("should calculate correct results for median aggregator on BigDecimals") { + val agg = MedianAggregator + addElementsAndComputeResult( + List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8")), + agg + ) shouldEqual new java.math.BigDecimal("7.5") + } + + test("should ignore nulls for median aggregator") { + val agg = MedianAggregator + addElementsAndComputeResult( + List(null, new java.math.BigDecimal("7"), null, new java.math.BigDecimal("8")), + agg + ) shouldEqual new java.math.BigDecimal("7.5") + } + + test("MedianAggregator test on odd length list") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5, 90), agg) + + result shouldEqual 5 + } + + test("MedianAggregator test on even length list") { + val agg = 
MedianAggregator + val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5), agg) + + result shouldEqual 4.5 + } + test("should calculate correct results for standard deviation and variance on integers") { val table = Table( ("aggregator", "value"), @@ -230,7 +304,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat ( SumAggregator, 15.0), ( MaxAggregator, 5.0), ( MinAggregator, 1.0), - ( AverageAggregator, 3.0) + ( AverageAggregator, 3.0), + ( MedianAggregator, 3.0) ) forAll(table) { (agg, expectedResult) => diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala index 579ebf06e58..e714c4101f0 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala @@ -96,6 +96,13 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.median", """#input.eId""", Typed[Double]) + validateOk("#AGG.median", """1""", Typed[Double]) + validateOk("#AGG.median", """1.5""", Typed[Double]) + + validateOk("#AGG.median", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.median", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]]) validateOk( "#AGG.map({f1: #AGG.sum, f2: #AGG.set})", @@ -106,6 +113,7 @@ class TransformersTest extends AnyFunSuite 
with FlinkSpec with Matchers with Ins validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean") validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.median", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number") @@ -170,6 +178,17 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldBe List(1.0d, 1.5, 3.5) } + test("median aggregate") { + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b"))) + val testProcess = sliding("#AGG.median", "#input.eId", emitWhenEventLeft = false) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables shouldBe List(1.0d, 1.5, 3.5) + } + test("standard deviation and average aggregates") { val table = Table( ("aggregate", "secondValue"), @@ -455,6 +474,19 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true } + test("emit aggregate for extra window when no data come for median aggregator for return type double") { + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = tumbling("#AGG.median", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables.length shouldEqual (2) + aggregateVariables(0) shouldEqual 1.0 + aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true + } + test( 
"emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double" ) { @@ -492,6 +524,18 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) } + test("emit aggregate for extra window when no data come for median aggregator for return type BigDecimal") { + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = + tumbling("#AGG.median", """T(java.math.BigDecimal).ONE""", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) + } + test( "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal" ) { diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java index 7872db9cba4..9be0d53b6fb 100644 --- a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java +++ b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java @@ -31,6 +31,7 @@ public class AggregateHelper implements Serializable { new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"), new FixedExpressionValue("#AGG.varPop", "VarPop"), new FixedExpressionValue("#AGG.varSamp", "VarSamp"), + new FixedExpressionValue("#AGG.median", "Median"), new FixedExpressionValue("#AGG.min", "Min"), new FixedExpressionValue("#AGG.max", "Max"), new FixedExpressionValue("#AGG.sum", "Sum"), @@ -54,6 +55,7 @@ public class 
AggregateHelper implements Serializable { private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$; private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$; private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$; + private static final Aggregator MEDIAN = aggregates.MedianAggregator$.MODULE$; private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10); // Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used. @@ -80,6 +82,8 @@ public class AggregateHelper implements Serializable { public Aggregator varPop = VAR_POP; public Aggregator varSamp = VAR_SAMP; + public Aggregator median = MEDIAN; + public Aggregator approxCardinality = APPROX_CARDINALITY; public Aggregator map(@ParamName("parts") Map parts) { diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala index 3f54f4cbbc1..0aa361f8574 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala @@ -5,15 +5,19 @@ import cats.data.{NonEmptyList, Validated} import cats.instances.list._ import org.apache.flink.api.common.typeinfo.TypeInfo import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy -import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation +import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{ + ForLargeFloatingNumbersOperation, +} import 
/**
  * Aggregator that yields the median of the numeric values it was fed.
  *
  * Every non-null element is kept in a `ListBuffer`; the median is only computed in `result`,
  * by handing the collected values to `MedianHelper.calculateMedian`.
  */
object MedianAggregator extends Aggregator with LargeFloatingNumberAggregate {

  override type Element = Number

  override type Aggregate = ListBuffer[Number]

  // Every call builds a new, empty buffer.
  override def zero: Aggregate = ListBuffer.empty

  // Nulls are skipped. Any other element is appended in place, so the returned aggregate is the
  // very same buffer instance that was passed in.
  override def addElement(el: Element, agg: Aggregate): Aggregate = el match {
    case null  => agg
    case value => agg += value
  }

  // Produces a new buffer (elements of agg1 followed by those of agg2); neither input is modified.
  override def mergeAggregates(agg1: Aggregate, agg2: Aggregate): Aggregate = {
    val merged = ListBuffer.empty[Number]
    merged ++= agg1
    merged ++= agg2
    merged
  }

  // When calculateMedian yields None, null is returned instead.
  override def result(finalAggregate: Aggregate): AnyRef =
    MedianHelper.calculateMedian(finalAggregate.toList) match {
      case Some(median) => median
      case None         => null
    }

  // The stored aggregate is typed as a ListBuffer parameterised with the input type.
  override def computeStoredType(input: TypingResult): Validated[String, TypingResult] = {
    val storedType = Typed.genericTypeClass[ListBuffer[_]](List(input))
    Valid(storedType)
  }

}
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation
import pl.touk.nussknacker.engine.util.MathUtils

import scala.annotation.tailrec
import scala.util.Random

/**
  * Median computation built on the quickselect algorithm.
  */
object MedianHelper {
  // Supplies pivot positions for quickSelect; seeded with a constant.
  private val rand = new Random(42)

  /**
    * Computes the median of `numbers`.
    *
    *  - an empty input gives `None`,
    *  - for an odd number of elements the middle element (in sorted order) is passed through
    *    `MathUtils.convertToPromotedType` with `ForLargeFloatingNumbersOperation`,
    *  - for an even number of elements the two middle elements are added with
    *    `MathUtils.largeFloatingSum` and the sum is divided by 2 with
    *    `MathUtils.divideWithDefaultBigDecimalScale`.
    *
    * @param numbers values to take the median of; the list itself is not modified
    * @return the median, or `None` when there are no values
    */
  def calculateMedian(numbers: List[Number]): Option[Number] = {
    // `size` and positional access are linear on List and quickSelect needs both on every step,
    // so work on an indexed copy and measure it once.
    val values = numbers.toVector
    val count  = values.size

    if (count == 0) {
      None
    } else if (count % 2 == 1) {
      val middle = quickSelect(values, (count - 1) / 2)
      Some(MathUtils.convertToPromotedType(middle)(ForLargeFloatingNumbersOperation))
    } else {
      // it is possible to fetch both numbers with single recursion, but it would complicate code
      val lowerMiddle = quickSelect(values, count / 2 - 1)
      val upperMiddle = quickSelect(values, count / 2)

      val sum = MathUtils.largeFloatingSum(lowerMiddle, upperMiddle)
      Some(MathUtils.divideWithDefaultBigDecimalScale(sum, 2))
    }
  }

  /**
    * Returns the element that would sit at position `indexToTake` if `numbers` were sorted
    * according to `MathUtils.compare`.
    *
    * See https://en.wikipedia.org/wiki/Quickselect
    *
    * @param numbers     non-empty collection to select from
    * @param indexToTake zero-based position in sorted order; must be within `numbers`
    * @throws IllegalArgumentException when `numbers` is empty or `indexToTake` is out of range
    */
  @tailrec
  private def quickSelect(numbers: Vector[Number], indexToTake: Int): Number = {
    val count = numbers.size
    require(count > 0)
    require(indexToTake >= 0)
    require(indexToTake < count)

    val pivot = numbers(rand.nextInt(count))
    // Three-way split around the pivot, keyed by the sign of the comparison (-1, 0 or 1).
    // Elements equal to the pivot get their own group, so the pivot never re-enters the recursion.
    val bySign  = numbers.groupBy(n => Integer.signum(MathUtils.compare(n, pivot)))
    val smaller = bySign.getOrElse(-1, Vector.empty)
    val equal   = bySign.getOrElse(0, Vector.empty)
    val larger  = bySign.getOrElse(1, Vector.empty)

    val smallerCount        = smaller.size
    val smallerOrEqualCount = smallerCount + equal.size

    if (indexToTake < smallerCount) {
      quickSelect(smaller, indexToTake)
    } else if (indexToTake < smallerOrEqualCount) {
      equal(indexToTake - smallerCount)
    } else {
      quickSelect(larger, indexToTake - smallerOrEqualCount)
    }
  }

}
b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala index cb4ab891e21..d38ce11e8a8 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala @@ -45,6 +45,7 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), + new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -102,6 +103,7 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), + new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -163,6 +165,7 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), + new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new 
LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") diff --git a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json index 12f3af1ca0c..029dedb5548 100644 --- a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json +++ b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json @@ -14930,6 +14930,15 @@ } } ], + "median": [ + { + "name": "median", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], "first": [ { "name": "first", diff --git a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala index 2af9e71b786..8a78b354d9b 100644 --- a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala +++ b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala @@ -186,7 +186,8 @@ trait MathUtils { case n1: java.math.BigDecimal => n1.negate() } - private def compare(n1: Number, n2: Number): Int = { + @Hidden + def compare(n1: Number, n2: Number): Int = { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandler[Int] { override def onBytes(n1: java.lang.Byte, n2: java.lang.Byte): Int = n1.compareTo(n2) override def onShorts(n1: java.lang.Short, n2: java.lang.Short): Int = n1.compareTo(n2) @@ -286,7 +287,8 @@ trait MathUtils { } } - private def convertToPromotedType( + @Hidden + def convertToPromotedType( n: Number )(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = { // In some cases type can be promoted to other class e.g. Byte is promoted to Int for sum