From f00c47a80dcd7bee78c63cd69ede087467c77d22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czajka?= Date: Wed, 18 Dec 2024 10:18:51 +0100 Subject: [PATCH 1/3] [NU-1921] Add standard deviation and variance aggregations (#7307) Co-authored-by: Pawel Czajka --- docs/Changelog.md | 1 + .../AggregatesInTimeWindows.md | 4 + .../aggregate/AggregatesSpec.scala | 236 +++++++++++++++--- .../aggregate/TransformersTest.scala | 108 +++++++- .../aggregate/AggregateHelper.java | 13 + .../transformer/aggregate/aggregates.scala | 201 ++++++++++++--- .../aggregate/sampleTransformers.scala | 12 + .../extractedTypes/defaultModel.json | 38 ++- .../nussknacker/engine/util/MathUtils.scala | 45 +++- 9 files changed, 571 insertions(+), 87 deletions(-) diff --git a/docs/Changelog.md b/docs/Changelog.md index 3986d46db21..964d82c3cd9 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -28,6 +28,7 @@ * [#7323](https://github.com/TouK/nussknacker/pull/7323) Improve Periodic DeploymentManager db queries * [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored * [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme +* [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators ## 1.18 diff --git a/docs/scenarios_authoring/AggregatesInTimeWindows.md b/docs/scenarios_authoring/AggregatesInTimeWindows.md index 23ed9fb7baf..322fb5b8e4a 100644 --- a/docs/scenarios_authoring/AggregatesInTimeWindows.md +++ b/docs/scenarios_authoring/AggregatesInTimeWindows.md @@ -81,6 +81,10 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c * Set - the result is a set of inputs received by the aggregator. Can be very ineffective for large sets, try to use ApproximateSetCardinality in this case * CountWhen - accepts boolean values, returns how many of them are true * Average - computes average of values +* StddevPop - computes population standard deviation +* StddevSamp - computes sample standard deviation +* VarPop - computes population variance +* VarSamp - computes sample variance * ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result. If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator. diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala index 1a8a5e13686..07588ed4335 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala @@ -3,8 +3,9 @@ package pl.touk.nussknacker.engine.flink.util.transformer.aggregate import org.scalatest.prop.TableDrivenPropertyChecks import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectTypingResult, TypingResult, Unknown} -import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates._ +import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator} import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap import java.lang.{Integer => JInt, Long => JLong} @@ -83,73 +84,215 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat test("should calculate correct results for first aggregator") { val agg = FirstAggregator - agg.result( - agg.addElement(8.asInstanceOf[agg.Element], agg.addElement(5.asInstanceOf[agg.Element], agg.zero)) - ) shouldEqual 5 + addElementsAndComputeResult(List(5, 8), agg) shouldEqual 5 } test("should calculate correct results for countWhen aggregator") { val agg = CountWhenAggregator - agg.result( - agg.addElement( - false.asInstanceOf[agg.Element], - agg.addElement(true.asInstanceOf[agg.Element], agg.addElement(true.asInstanceOf[agg.Element], agg.zero)) - ) - ) shouldEqual 2 + addElementsAndComputeResult(List(true, true), agg) shouldEqual 2 } test("should calculate correct results for average aggregator") { val agg = AverageAggregator - agg.result( - agg.addElement(8.asInstanceOf[agg.Element], agg.addElement(7.asInstanceOf[agg.Element], agg.zero)) - ) shouldEqual 7.5 + addElementsAndComputeResult(List(7, 8), agg) shouldEqual 7.5 } test("should calculate correct results for average aggregator on BigInt") { val agg = AverageAggregator - agg.result( - agg.addElement( - new BigInteger("8").asInstanceOf[agg.Element], - agg.addElement(new BigInteger("7").asInstanceOf[agg.Element], agg.zero) - ) - ) shouldEqual new java.math.BigDecimal("7.5") + addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), agg) shouldEqual new java.math.BigDecimal("7.5") } test("should calculate correct results for average aggregator on float") { val agg = AverageAggregator - agg.result( - agg.addElement( - 8.0f.asInstanceOf[agg.Element], - agg.addElement(7.0f.asInstanceOf[agg.Element], agg.zero) - ) - ) shouldEqual 7.5 + addElementsAndComputeResult(List(7.0f, 8.0f), agg) shouldEqual 7.5 } test("should calculate correct results for average aggregator on BigDecimal") { val agg = AverageAggregator - agg.result( - agg.addElement( - new java.math.BigDecimal("8").asInstanceOf[agg.Element], - agg.addElement(new java.math.BigDecimal("7").asInstanceOf[agg.Element], agg.zero) - ) + addElementsAndComputeResult( + List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8")), + agg ) shouldEqual new java.math.BigDecimal("7.5") } - test("AverageAggregator should calculate correct results for empty aggregation set") { - val agg = AverageAggregator - val result = agg.result( - agg.zero + test("some aggregators should produce null on single null input") { + forAll (Table( + "aggregator", + AverageAggregator, + SampleStandardDeviationAggregator, + PopulationStandardDeviationAggregator, + SampleVarianceAggregator, + PopulationVarianceAggregator, + MaxAggregator, + MinAggregator, + FirstAggregator, + LastAggregator, + SumAggregator + )) { agg => + addElementsAndComputeResult(List(null), agg) shouldEqual null + } + } + + test("should calculate correct results for standard deviation and variance on doubles") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(5.0, 4.0, 3.0, 2.0, 1.0), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("should calculate correct results for standard deviation and variance on integers") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ) ) - // null is returned because method alignToExpectedType did not run - require(result == null) + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(5, 4, 3, 2, 1), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("should calculate correct results for standard deviation and variance on BigInt") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, BigDecimal(Math.sqrt(2.5)) ), + ( PopulationStandardDeviationAggregator, BigDecimal(Math.sqrt(2)) ), + ( SampleVarianceAggregator, BigDecimal(2.5) ), + ( PopulationVarianceAggregator, BigDecimal(2.0) ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult( + List(new BigInteger("5"), new BigInteger("4"), new BigInteger("3"), new BigInteger("2"), new BigInteger("1")), + agg + ) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe expectedResult +- EPS_BIG_DECIMAL + } + } + + test("should calculate correct results for standard deviation and variance on float") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(5.0f, 4.0f, 3.0f, 2.0f, 1.0f), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("should calculate correct results for standard deviation and variance on BigDecimals") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, BigDecimal(Math.sqrt(2.5)) ), + ( PopulationStandardDeviationAggregator, BigDecimal(Math.sqrt(2)) ), + ( SampleVarianceAggregator, BigDecimal(2.5) ), + ( PopulationVarianceAggregator, BigDecimal(2.0) ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult( + List( + new java.math.BigDecimal("5"), + new java.math.BigDecimal("4"), + new java.math.BigDecimal("3"), + new java.math.BigDecimal("2"), + new java.math.BigDecimal("1") + ), + agg + ) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe expectedResult +- EPS_BIG_DECIMAL + } + } + + test("some aggregators should ignore nulls ") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ), + ( SumAggregator, 15.0), + ( MaxAggregator, 5.0), + ( MinAggregator, 1.0), + ( AverageAggregator, 3.0) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(null, 5.0, 4.0, null, 3.0, 2.0, 1.0), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("some aggregators should produce null on empty set") { + forAll (Table( + "aggregator", + AverageAggregator, + SampleStandardDeviationAggregator, + PopulationStandardDeviationAggregator, + SampleVarianceAggregator, + PopulationVarianceAggregator, + MaxAggregator, + MinAggregator, + FirstAggregator, + LastAggregator, + SumAggregator + )) { agg => + val result = addElementsAndComputeResult(List(), agg) + result shouldBe null + } + } + + test("should calculate correct results for population standard deviation and variance on single element double set") { + val table = Table( + "aggregator", + SampleStandardDeviationAggregator, + PopulationStandardDeviationAggregator, + SampleVarianceAggregator, + PopulationVarianceAggregator + ) + forAll(table) { agg => + val result = addElementsAndComputeResult(List(1.0d), agg) + result.asInstanceOf[Double] shouldBe 0 + } + } + + test("should calculate correct results for population standard deviation on single element float set") { + val agg = PopulationStandardDeviationAggregator + val result = addElementsAndComputeResult(List(1.0f), agg) + result.asInstanceOf[Double] shouldBe 0 + } + + test("should calculate correct results for population standard deviation on single element BigDecimal set") { + val agg = PopulationStandardDeviationAggregator + val result = addElementsAndComputeResult(List(new java.math.BigDecimal("1.0")), agg) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe BigDecimal(0) +- EPS_BIG_DECIMAL + } + + test("should calculate correct results for population standard deviation on single element BigInteger set") { + val agg = PopulationStandardDeviationAggregator + val result = addElementsAndComputeResult(List(new java.math.BigInteger("1")), agg) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe BigDecimal(0) +- EPS_BIG_DECIMAL } test("should calculate correct results for last aggregator") { val agg = LastAggregator - agg.result( - agg.addElement(8.asInstanceOf[agg.Element], agg.addElement(5.asInstanceOf[agg.Element], agg.zero)) - ) shouldEqual 8 + addElementsAndComputeResult(List(5, 8), agg) shouldEqual 8 } test("should compute output and stored type for simple aggregators") { @@ -344,5 +487,18 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat aggregator.mergeAggregates(rightElemState, leftElemState) shouldBe combinedState } + private def addElementsAndComputeResult[T](elements: List[T], aggregator: Aggregator): AnyRef = { + aggregator.result( + elements.foldLeft(aggregator.zero)((state, element) => + aggregator.addElement(element.asInstanceOf[aggregator.Element], state) + ) + ) + } + class JustAnyClass } + +object AggregatesSpec { + val EPS_DOUBLE = 0.000001; + val EPS_BIG_DECIMAL = BigDecimal(new java.math.BigDecimal("0.000001")) +} diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala index a2f00006bc4..579ebf06e58 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala @@ -6,6 +6,8 @@ import com.typesafe.config.ConfigFactory import org.scalatest.Inside import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks.forAll +import org.scalatest.prop.Tables.Table import pl.touk.nussknacker.engine.api.component.ComponentDefinition import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.{ CannotCreateObjectError, @@ -74,6 +76,26 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateOk("#AGG.average", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) validateOk("#AGG.average", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.stddevPop", "1", Typed[Double]) + validateOk("#AGG.stddevSamp", "1", Typed[Double]) + validateOk("#AGG.varPop", "1", Typed[Double]) + validateOk("#AGG.varSamp", "1", Typed[Double]) + + validateOk("#AGG.stddevPop", "1.5", Typed[Double]) + validateOk("#AGG.stddevSamp", "1.5", Typed[Double]) + validateOk("#AGG.varPop", "1.5", Typed[Double]) + validateOk("#AGG.varSamp", "1.5", Typed[Double]) + + validateOk("#AGG.stddevPop", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.stddevSamp", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varPop", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varSamp", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + + validateOk("#AGG.stddevPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.stddevSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]]) validateOk( "#AGG.map({f1: #AGG.sum, f2: #AGG.set})", @@ -84,6 +106,12 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean") validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number") + + validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.varPop", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.varSamp", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError( "#AGG.map({f1: #AGG.set, f2: #AGG.set})", "{f1: #input.str}", @@ -142,6 +170,31 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldBe List(1.0d, 1.5, 3.5) } + test("standard deviation and average aggregates") { + val table = Table( + ("aggregate", "secondValue"), + ("#AGG.stddevPop", Math.sqrt(0.25)), + ("#AGG.stddevSamp", Math.sqrt(0.5)), + ("#AGG.varPop", 0.25), + ("#AGG.varSamp", 0.5) + ) + + forAll(table) { (aggregationName, secondValue) => + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"))) + val testProcess = sliding(aggregationName, "#input.eId", emitWhenEventLeft = false) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + val mapped = aggregateVariables + .map(e => e.asInstanceOf[Double]) + mapped.size shouldBe 2 + mapped(0) shouldBe 0.0 +- 0.0001 + mapped(1) shouldBe secondValue +- 0.0001 + } + } + test("sliding aggregate should emit context of variables") { val id = "1" @@ -399,7 +452,32 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) aggregateVariables.length shouldEqual (2) aggregateVariables(0) shouldEqual 1.0 - require((aggregateVariables(1).asInstanceOf[Double].isNaN)) + aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true + } + + test( + "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double" + ) { + val table = Table( + "aggregatorExpression", + "#AGG.stddevPop", + "#AGG.stddevSamp", + "#AGG.varPop", + "#AGG.varSamp" + ) + + forAll(table) { aggregatorName => + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = tumbling(aggregatorName, "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables.length shouldEqual (2) + aggregateVariables(0) shouldEqual 0.0 + aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true + } } test("emit aggregate for extra window when no data come for average aggregator for return type BigDecimal") { @@ -414,6 +492,34 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) } + test( + "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal" + ) { + val table = Table( + "aggregatorExpression", + "#AGG.stddevPop", + "#AGG.stddevSamp", + "#AGG.varPop", + "#AGG.varSamp" + ) + + forAll(table) { aggregatorName => + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = + tumbling( + aggregatorName, + """T(java.math.BigDecimal).ONE""", + emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow + ) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables shouldEqual List(new java.math.BigDecimal("0"), null) + } + } + test("emit aggregate for extra window when no data come - out of order elements") { val id = "1" diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java index a6329c746e5..7872db9cba4 100644 --- a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java +++ b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java @@ -27,6 +27,10 @@ public class AggregateHelper implements Serializable { new FixedExpressionValue("#AGG.last", "Last"), new FixedExpressionValue("#AGG.countWhen", "CountWhen"), new FixedExpressionValue("#AGG.average", "Average"), + new FixedExpressionValue("#AGG.stddevPop", "StddevPop"), + new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"), + new FixedExpressionValue("#AGG.varPop", "VarPop"), + new FixedExpressionValue("#AGG.varSamp", "VarSamp"), new FixedExpressionValue("#AGG.min", "Min"), new FixedExpressionValue("#AGG.max", "Max"), new FixedExpressionValue("#AGG.sum", "Sum"), @@ -46,6 +50,10 @@ public class AggregateHelper implements Serializable { private static final Aggregator LAST = aggregates.LastAggregator$.MODULE$; private static final Aggregator COUNT_WHEN = aggregates.CountWhenAggregator$.MODULE$; private static final Aggregator AVERAGE = aggregates.AverageAggregator$.MODULE$; + private static final Aggregator STDDEV_POP = aggregates.PopulationStandardDeviationAggregator$.MODULE$; + private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$; + private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$; + private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$; private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10); // Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used. @@ -67,6 +75,11 @@ public class AggregateHelper implements Serializable { public Aggregator countWhen = COUNT_WHEN; public Aggregator average = AVERAGE; + public Aggregator stddevPop = STDDEV_POP; + public Aggregator stddevSamp = STDDEV_SAMP; + public Aggregator varPop = VAR_POP; + public Aggregator varSamp = VAR_SAMP; + public Aggregator approxCardinality = APPROX_CARDINALITY; public Aggregator map(@ParamName("parts") Map parts) { diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala index 09d79139607..3f54f4cbbc1 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala @@ -201,48 +201,67 @@ object aggregates { computeOutputType(input) } - object AverageAggregator extends Aggregator { + @TypeInfo(classOf[LargeFloatSumState.TypeInfoFactory]) + // it would be natural to use type Number instead of this class + // it is done this way so that it is serialized properly + case class LargeFloatSumState( + sumDouble: java.lang.Double, + sumBigDecimal: java.math.BigDecimal, + ) { + def asNumber: Number = Option(sumDouble).getOrElse(sumBigDecimal) + + def withAddedElement(element: Number): LargeFloatSumState = { + LargeFloatSumState.fromNumber(MathUtils.largeFloatingSum(element, asNumber)) + } - override type Element = java.lang.Number + def withMergedState(other: LargeFloatSumState): LargeFloatSumState = { + LargeFloatSumState.fromNumber(MathUtils.largeFloatingSum(asNumber, other.asNumber)) + } + + } + + object LargeFloatSumState { + class TypeInfoFactory extends CaseClassTypeInfoFactory[LargeFloatSumState] + + private def fromNumber(sum: Number): LargeFloatSumState = { + sum match { + case null => LargeFloatSumState(null, null) + case sumDouble: java.lang.Double => LargeFloatSumState(sumDouble, null) + case sumBigDecimal: java.math.BigDecimal => LargeFloatSumState(null, sumBigDecimal) + } + } + + def emptyState: LargeFloatSumState = { + LargeFloatSumState.fromNumber(null) + } + + } + + object AverageAggregator extends Aggregator with LargeFloatingNumberAggregate { override type Aggregate = AverageAggregatorState - override def zero: AverageAggregatorState = AverageAggregatorState(null, 0) + override def zero: AverageAggregatorState = AverageAggregatorState(LargeFloatSumState.emptyState, 0) - override def addElement(element: java.lang.Number, aggregate: Aggregate): Aggregate = - AverageAggregatorState(MathUtils.largeFloatingSum(element, aggregate.sum), aggregate.count + 1) + override def addElement(element: Element, aggregate: Aggregate): Aggregate = { + if (element == null) aggregate + else + AverageAggregatorState(aggregate.sum.withAddedElement(element), aggregate.count + 1) + } override def mergeAggregates(aggregate1: Aggregate, aggregate2: Aggregate): Aggregate = AverageAggregatorState( - MathUtils.largeFloatingSum(aggregate1.sum, aggregate2.sum), + aggregate1.sum.withMergedState(aggregate2.sum), aggregate1.count + aggregate2.count ) override def result(finalAggregate: Aggregate): AnyRef = { val count = finalAggregate.count - finalAggregate.sum match { - case null => - // will be replaced to Double.Nan in alignToExpectedType iff return type is known to be Double - null - case sum: java.lang.Double => (sum / count).asInstanceOf[AnyRef] - case sum: java.math.BigDecimal => (BigDecimal(sum) / BigDecimal(count)).bigDecimal - } - } - - override def alignToExpectedType(value: AnyRef, outputType: TypingResult): AnyRef = { - if (value == null && outputType == Typed(classOf[Double])) { - Double.NaN.asInstanceOf[AnyRef] - } else { - value - } - } - - override def computeOutputType(input: typing.TypingResult): Validated[String, typing.TypingResult] = { - - if (!input.canBeConvertedTo(Typed[Number])) { - Invalid(s"Invalid aggregate type: ${input.display}, should be: ${Typed[Number].display}") + val sum = finalAggregate.sum.asNumber + if (sum == null) { + null } else { - Valid(ForLargeFloatingNumbersOperation.promoteSingle(input)) + MathUtils.divideWithDefaultBigDecimalScale(sum, count) } } @@ -250,31 +269,109 @@ object aggregates { Valid(Typed[AverageAggregatorState]) @TypeInfo(classOf[AverageAggregatorState.TypeInfoFactory]) - // it would be natural to have one field sum: Number instead of nullable sumDouble and sumBigDecimal, - // it is done this way to have types serialized properly case class AverageAggregatorState( - sumDouble: java.lang.Double, - sumBigDecimal: java.math.BigDecimal, + sum: LargeFloatSumState, count: java.lang.Long - ) { - def sum: Number = Option(sumDouble).getOrElse(sumBigDecimal) - } + ) object AverageAggregatorState { class TypeInfoFactory extends CaseClassTypeInfoFactory[AverageAggregatorState] + } + + } + + @TypeInfo(classOf[StandardDeviationState.TypeInfoFactory]) + case class StandardDeviationState( + sum: LargeFloatSumState, + squaresSum: LargeFloatSumState, + count: java.lang.Long + ) - def apply(sum: Number, count: java.lang.Long): AverageAggregatorState = { - sum match { - case null => AverageAggregatorState(null, null, count) - case sumDouble: java.lang.Double => AverageAggregatorState(sumDouble, null, count) - case sumBigDecimal: java.math.BigDecimal => AverageAggregatorState(null, sumBigDecimal, count) + object StandardDeviationState { + class TypeInfoFactory extends CaseClassTypeInfoFactory[StandardDeviationState] + } + + private sealed trait StandardDeviationOrVarianceAggregationType + private case object SampleStandardDeviation extends StandardDeviationOrVarianceAggregationType + private case object PopulationStandardDeviation extends StandardDeviationOrVarianceAggregationType + private case object SampleVariance extends StandardDeviationOrVarianceAggregationType + private case object PopulationVariance extends StandardDeviationOrVarianceAggregationType + + class GeneralStandardDeviationAndVarianceAggregator( + private val standardDeviationVarianceType: StandardDeviationOrVarianceAggregationType + ) extends Aggregator + with LargeFloatingNumberAggregate { + + override type Aggregate = StandardDeviationState + + override def zero: StandardDeviationState = + StandardDeviationState(LargeFloatSumState.emptyState, LargeFloatSumState.emptyState, 0) + + override def addElement(element: Element, aggregate: Aggregate): Aggregate = { + if (element == null) aggregate + else + StandardDeviationState( + sum = aggregate.sum.withAddedElement(element), + squaresSum = aggregate.squaresSum.withAddedElement(MathUtils.largeFloatSquare(element)), + count = aggregate.count + 1 + ) + } + + override def mergeAggregates(aggregate1: Aggregate, aggregate2: Aggregate): Aggregate = + StandardDeviationState( + sum = aggregate1.sum.withMergedState(aggregate2.sum), + squaresSum = aggregate1.squaresSum.withMergedState(aggregate2.squaresSum), + count = aggregate1.count + aggregate2.count + ) + + override def result(finalAggregate: Aggregate): AnyRef = { + if (finalAggregate.count == 0 || finalAggregate.sum.asNumber == null || finalAggregate.squaresSum == null) { + // will be replaced to Double.Nan in alignToExpectedType iff return type is known to be Double + null + } else if (finalAggregate.count == 1) { + // zero of the same type as aggregated number + MathUtils.minus(finalAggregate.sum.asNumber, finalAggregate.sum.asNumber) + } else { + val count = finalAggregate.count + val average = MathUtils.divideWithDefaultBigDecimalScale(finalAggregate.sum.asNumber, count) + val averageSquare = MathUtils.divideWithDefaultBigDecimalScale(finalAggregate.squaresSum.asNumber, count) + val populationVariance = MathUtils.minus(averageSquare, MathUtils.largeFloatSquare(average)) + val sampleVariance = MathUtils.multiply(count.toDouble / (count - 1), populationVariance) + + standardDeviationVarianceType match { + case SampleStandardDeviation => MathUtils.largeFloatSqrt(sampleVariance) + case PopulationStandardDeviation => MathUtils.largeFloatSqrt(populationVariance) + case SampleVariance => sampleVariance + case PopulationVariance => populationVariance } } - } + override def computeStoredType(input: typing.TypingResult): Validated[String, typing.TypingResult] = + Valid(Typed[StandardDeviationState]) + } + object SampleStandardDeviationAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + SampleStandardDeviation + ) + + object PopulationStandardDeviationAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + PopulationStandardDeviation + ) + + object SampleVarianceAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + SampleVariance + ) + + object PopulationVarianceAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + PopulationVariance + ) + /* This is more complex aggregator, as it is composed from smaller ones. The idea is that we define aggregator: @@ -454,4 +551,26 @@ object aggregates { protected def promotedType(typ: TypingResult): TypingResult = promotionStrategy.promoteSingle(typ) } + trait LargeFloatingNumberAggregate { self: Aggregator => + override type Element = java.lang.Number + + override def alignToExpectedType(value: AnyRef, outputType: TypingResult): AnyRef = { + if (value == null && outputType == Typed(classOf[Double])) { + Double.NaN.asInstanceOf[AnyRef] + } else { + value + } + } + + override def computeOutputType(input: typing.TypingResult): Validated[String, typing.TypingResult] = { + + if (!input.canBeConvertedTo(Typed[Number])) { + Invalid(s"Invalid aggregate type: ${input.display}, should be: ${Typed[Number].display}") + } else { + Valid(ForLargeFloatingNumbersOperation.promoteSingle(input)) + } + } + + } + } diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala index a992b6eb5cb..cb4ab891e21 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala @@ -41,6 +41,10 @@ object sampleTransformers { new LabeledExpression(label = "Sum", expression = "#AGG.sum"), new LabeledExpression(label = "Average", expression = "#AGG.average"), new LabeledExpression(label = "CountWhen", expression = "#AGG.countWhen"), + new LabeledExpression(label = "StddevPop", expression = "#AGG.stddevPop"), + new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), + new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), + new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -94,6 +98,10 @@ object sampleTransformers { new LabeledExpression(label = "Sum", expression = "#AGG.sum"), new LabeledExpression(label = "Average", expression = "#AGG.average"), new LabeledExpression(label = "CountWhen", expression = "#AGG.countWhen"), + new LabeledExpression(label = "StddevPop", expression = "#AGG.stddevPop"), + new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), + new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), + new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -151,6 +159,10 @@ object sampleTransformers { new LabeledExpression(label = "Sum", expression = "#AGG.sum"), new LabeledExpression(label = "Average", expression = "#AGG.average"), new LabeledExpression(label = "CountWhen", expression = "#AGG.countWhen"), + new LabeledExpression(label = "StddevPop", expression = "#AGG.stddevPop"), + new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), + new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), + new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") diff --git a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json index 03d4cc6889f..12f3af1ca0c 100644 --- a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json +++ b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json @@ -14894,6 +14894,42 @@ } } ], + "stddevPop": [ + { + "name": "stddevPop", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], + "stddevSamp": [ + { + "name": "stddevSamp", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], + "varPop": [ + { + "name": "varPop", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], + "varSamp": [ + { + "name": "varSamp", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], "first": [ { "name": "first", @@ -17733,4 +17769,4 @@ ] } } -] \ No newline at end of file +] diff --git a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala index 2455028ad53..2af9e71b786 100644 --- a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala +++ b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala @@ -8,11 +8,13 @@ import pl.touk.nussknacker.engine.api.typed.supertype.{ } import java.lang +import java.math.MathContext import java.math.RoundingMode +import javax.annotation.Nullable trait MathUtils { - def min(n1: Number, n2: Number): Number = { + def min(@Nullable n1: Number, @Nullable n2: Number): Number = { implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = NumberTypesPromotionStrategy.ForMinMax withNotNullValues(n1, n2) { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandlerReturningNumber { @@ -36,7 +38,7 @@ trait MathUtils { } } - def max(n1: Number, n2: Number): Number = { + def max(@Nullable n1: Number, @Nullable n2: Number): Number = { implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = NumberTypesPromotionStrategy.ForMinMax withNotNullValues(n1, n2) { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandlerReturningNumber { @@ -79,6 +81,27 @@ trait MathUtils { promoteThenSum(n1, n2) } + @Hidden + def largeFloatSquare(number: Number): Number = { + implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = + NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation + val converted = convertToPromotedType(number) + multiply(converted, converted) + } + + @Hidden + def largeFloatSqrt(number: Number): Number = { + implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = + NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation + + val converted = convertToPromotedType(number) + + converted match { + case converted: java.lang.Double => Math.sqrt(converted) + case converted: java.math.BigDecimal => converted.sqrt(MathContext.DECIMAL128) + } + } + def plus(n1: Number, n2: Number): Number = sum(n1, n2) def minus(n1: Number, n2: Number): Number = { @@ -107,6 +130,18 @@ trait MathUtils { })(NumberTypesPromotionStrategy.ForMathOperation) } + // divide method has peculiar behaviour when it comes to BigDecimals (see its implementation), hence this method is sometimes needed + @Hidden + def divideWithDefaultBigDecimalScale(n1: Number, n2: Number): Number = { + if (n1.isInstanceOf[java.math.BigDecimal] || n2.isInstanceOf[java.math.BigDecimal]) { + (BigDecimal(SpringNumberUtils.convertNumberToTargetClass(n1, classOf[java.math.BigDecimal])) + / + BigDecimal(SpringNumberUtils.convertNumberToTargetClass(n2, classOf[java.math.BigDecimal]))).bigDecimal + } else { + divide(n1, n2) + } + } + def divide(n1: Number, n2: Number): Number = { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandlerForPromotingMathOp { override def onInts(n1: java.lang.Integer, n2: java.lang.Integer): java.lang.Integer = n1 / n2 @@ -118,6 +153,8 @@ trait MathUtils { override def onBigDecimals(n1: java.math.BigDecimal, n2: java.math.BigDecimal): java.math.BigDecimal = { n1.divide( n2, + // This is copied behaviour of divide operation in spel (class OpDivide) but it can lead to issues when both big decimals have small scales. + // Small scales happen when integer is converted to BigDecimal using SpringNumberUtils.convertNumberToTargetClass Math.max(n1.scale(), n2.scale), RoundingMode.HALF_EVEN ) // same scale and rounding as used by OpDivide in SpelExpression.java @@ -169,7 +206,7 @@ trait MathUtils { def equal(n1: Number, n2: Number): Boolean = compare(n1, n2) == 0 def notEqual(n1: Number, n2: Number): Boolean = compare(n1, n2) != 0 - private def promoteThenSum(n1: Number, n2: Number)( + private def promoteThenSum(@Nullable n1: Number, @Nullable n2: Number)( implicit promotionStrategy: ReturningSingleClassPromotionStrategy ) = { withNotNullValues(n1, n2) { @@ -186,7 +223,7 @@ trait MathUtils { } } - protected def withNotNullValues(n1: Number, n2: Number)( + protected def withNotNullValues(@Nullable n1: Number, @Nullable n2: Number)( f: => Number )(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = { if (n1 == null) { From ace3d3c20220dccef03e006c36e532f9c581402c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czajka?= Date: Wed, 18 Dec 2024 10:24:16 +0100 Subject: [PATCH 2/3] [NU-1921] Add median aggregator #7321 Co-authored-by: Pawel Czajka --- docs/Changelog.md | 1 + .../AggregatesInTimeWindows.md | 1 + .../aggregate/AggregatesSpec.scala | 81 ++++++++++++++++++- .../aggregate/TransformersTest.scala | 44 ++++++++++ .../aggregate/AggregateHelper.java | 4 + .../transformer/aggregate/aggregates.scala | 26 +++++- .../aggregate/median/MedianHelper.scala | 58 +++++++++++++ .../aggregate/sampleTransformers.scala | 3 + .../extractedTypes/defaultModel.json | 9 +++ .../nussknacker/engine/util/MathUtils.scala | 6 +- 10 files changed, 227 insertions(+), 6 deletions(-) create mode 100644 engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala diff --git a/docs/Changelog.md b/docs/Changelog.md index 964d82c3cd9..8d47f6c0de5 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -29,6 +29,7 @@ * [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored * [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme * [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators +* [#7321](https://github.com/TouK/nussknacker/pull/7321) Added Median aggregator ## 1.18 diff --git a/docs/scenarios_authoring/AggregatesInTimeWindows.md b/docs/scenarios_authoring/AggregatesInTimeWindows.md index 322fb5b8e4a..f354d815d94 100644 --- a/docs/scenarios_authoring/AggregatesInTimeWindows.md +++ b/docs/scenarios_authoring/AggregatesInTimeWindows.md @@ -85,6 +85,7 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c * StddevSamp - computes sample standard deviation * VarPop - computes population variance * VarSamp - computes sample variance +* Median - computes median * ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result. If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator. diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala index 07588ed4335..c4243070a9d 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala @@ -5,7 +5,24 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown} import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE} -import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{ + AverageAggregator, + CountWhenAggregator, + FirstAggregator, + LastAggregator, + ListAggregator, + MapAggregator, + MaxAggregator, + MedianAggregator, + MinAggregator, + OptionAggregator, + PopulationStandardDeviationAggregator, + PopulationVarianceAggregator, + SampleStandardDeviationAggregator, + SampleVarianceAggregator, + SetAggregator, + SumAggregator +} import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap import java.lang.{Integer => JInt, Long => JLong} @@ -127,7 +144,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat MinAggregator, FirstAggregator, LastAggregator, - SumAggregator + SumAggregator, + MedianAggregator )) { agg => addElementsAndComputeResult(List(null), agg) shouldEqual null } @@ -148,6 +166,62 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat } } + test("should calculate correct results for median aggregator on integers") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(7, 8), agg) + result shouldBe a[Double] + result shouldEqual 7.5 + } + + test("should calculate correct results for median aggregator on integers on single value") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(7), agg) + result shouldBe a[Double] + result shouldEqual 7 + } + + test("should calculate correct results for median aggregator on BigInt") { + val agg = MedianAggregator + addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), agg) shouldEqual new java.math.BigDecimal("7.5") + } + + test("should calculate correct results for median aggregator on floats") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(7.0f, 8.0f), agg) + result shouldBe a[Double] + result shouldEqual 7.5 + } + + test("should calculate correct results for median aggregator on BigDecimals") { + val agg = MedianAggregator + addElementsAndComputeResult( + List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8")), + agg + ) shouldEqual new java.math.BigDecimal("7.5") + } + + test("should ignore nulls for median aggregator") { + val agg = MedianAggregator + addElementsAndComputeResult( + List(null, new java.math.BigDecimal("7"), null, new java.math.BigDecimal("8")), + agg + ) shouldEqual new java.math.BigDecimal("7.5") + } + + test("MedianAggregator test on odd length list") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5, 90), agg) + + result shouldEqual 5 + } + + test("MedianAggregator test on even length list") { + val agg = MedianAggregator + val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5), agg) + + result shouldEqual 4.5 + } + test("should calculate correct results for standard deviation and variance on integers") { val table = Table( ("aggregator", "value"), @@ -230,7 +304,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat ( SumAggregator, 15.0), ( MaxAggregator, 5.0), ( MinAggregator, 1.0), - ( AverageAggregator, 3.0) + ( AverageAggregator, 3.0), + ( MedianAggregator, 3.0) ) forAll(table) { (agg, expectedResult) => diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala index 579ebf06e58..e714c4101f0 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala @@ -96,6 +96,13 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.median", """#input.eId""", Typed[Double]) + validateOk("#AGG.median", """1""", Typed[Double]) + validateOk("#AGG.median", """1.5""", Typed[Double]) + + validateOk("#AGG.median", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.median", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]]) validateOk( "#AGG.map({f1: #AGG.sum, f2: #AGG.set})", @@ -106,6 +113,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean") validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.median", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number") @@ -170,6 +178,17 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldBe List(1.0d, 1.5, 3.5) } + test("median aggregate") { + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b"))) + val testProcess = sliding("#AGG.median", "#input.eId", emitWhenEventLeft = false) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables shouldBe List(1.0d, 1.5, 3.5) + } + test("standard deviation and average aggregates") { val table = Table( ("aggregate", "secondValue"), @@ -455,6 +474,19 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true } + test("emit aggregate for extra window when no data come for median aggregator for return type double") { + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = tumbling("#AGG.median", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables.length shouldEqual (2) + aggregateVariables(0) shouldEqual 1.0 + aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true + } + test( "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double" ) { @@ -492,6 +524,18 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) } + test("emit aggregate for extra window when no data come for median aggregator for return type BigDecimal") { + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = + tumbling("#AGG.median", """T(java.math.BigDecimal).ONE""", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) + } + test( "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal" ) { diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java index 7872db9cba4..9be0d53b6fb 100644 --- a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java +++ b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java @@ -31,6 +31,7 @@ public class AggregateHelper implements Serializable { new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"), new FixedExpressionValue("#AGG.varPop", "VarPop"), new FixedExpressionValue("#AGG.varSamp", "VarSamp"), + new FixedExpressionValue("#AGG.median", "Median"), new FixedExpressionValue("#AGG.min", "Min"), new FixedExpressionValue("#AGG.max", "Max"), new FixedExpressionValue("#AGG.sum", "Sum"), @@ -54,6 +55,7 @@ public class AggregateHelper implements Serializable { private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$; private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$; private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$; + private static final Aggregator MEDIAN = aggregates.MedianAggregator$.MODULE$; private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10); // Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used. @@ -80,6 +82,8 @@ public class AggregateHelper implements Serializable { public Aggregator varPop = VAR_POP; public Aggregator varSamp = VAR_SAMP; + public Aggregator median = MEDIAN; + public Aggregator approxCardinality = APPROX_CARDINALITY; public Aggregator map(@ParamName("parts") Map parts) { diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala index 3f54f4cbbc1..0aa361f8574 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala @@ -5,15 +5,19 @@ import cats.data.{NonEmptyList, Validated} import cats.instances.list._ import org.apache.flink.api.common.typeinfo.TypeInfo import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy -import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation +import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{ + ForLargeFloatingNumbersOperation, +} import pl.touk.nussknacker.engine.api.typed.typing._ import pl.touk.nussknacker.engine.api.typed.{NumberTypeUtils, typing} import pl.touk.nussknacker.engine.flink.api.typeinfo.caseclass.CaseClassTypeInfoFactory +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median.MedianHelper import pl.touk.nussknacker.engine.util.Implicits._ import pl.touk.nussknacker.engine.util.MathUtils import pl.touk.nussknacker.engine.util.validated.ValidatedSyntax._ import java.util +import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ /* @@ -69,6 +73,26 @@ object aggregates { } + object MedianAggregator extends Aggregator with LargeFloatingNumberAggregate { + + override type Aggregate = ListBuffer[Number] + + override type Element = Number + + override def zero: Aggregate = ListBuffer.empty + + override def addElement(el: Element, agg: Aggregate): Aggregate = if (el == null) agg else agg.addOne(el) + + override def mergeAggregates(agg1: Aggregate, agg2: Aggregate): Aggregate = agg1 ++ agg2 + + override def result(finalAggregate: Aggregate): AnyRef = MedianHelper.calculateMedian(finalAggregate.toList).orNull + + override def computeStoredType(input: TypingResult): Validated[String, TypingResult] = Valid( + Typed.genericTypeClass[ListBuffer[_]](List(input)) + ) + + } + object ListAggregator extends Aggregator { override type Aggregate = List[AnyRef] diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala new file mode 100644 index 00000000000..a4685a3b9b5 --- /dev/null +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala @@ -0,0 +1,58 @@ +package pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median + +import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{ + ForLargeFloatingNumbersOperation, +} +import pl.touk.nussknacker.engine.util.MathUtils + +import scala.annotation.tailrec +import scala.util.Random + +object MedianHelper { + private val rand = new Random(42) + + def calculateMedian(numbers: List[Number]): Option[Number] = { + if (numbers.isEmpty) { + None + } else if (numbers.size % 2 == 1) { + Some(MathUtils.convertToPromotedType(quickSelect(numbers, (numbers.size - 1) / 2))(ForLargeFloatingNumbersOperation)) + } else { + // it is possible to fetch both numbers with single recursion, but it would complicate code + val firstNumber = quickSelect(numbers, numbers.size / 2 - 1) + val secondNumber = quickSelect(numbers, numbers.size / 2) + + val sum = MathUtils.largeFloatingSum(firstNumber, secondNumber) + Some(MathUtils.divideWithDefaultBigDecimalScale(sum, 2)) + } + } + + // https://en.wikipedia.org/wiki/Quickselect + @tailrec + private def quickSelect(numbers: List[Number], indexToTake: Int): Number = { + require(numbers.nonEmpty) + require(indexToTake >= 0) + require(indexToTake < numbers.size) + + val randomElement = numbers(rand.nextInt(numbers.size)) + val groupedBy = numbers.groupBy(e => { + val cmp = MathUtils.compare(e, randomElement) + if (cmp < 0) { + -1 + } else if (cmp == 0) { + 0 + } else 1 + }) + val smallerNumbers = groupedBy.getOrElse(-1, Nil) + val equalNumbers = groupedBy.getOrElse(0, Nil) + val largerNumbers = groupedBy.getOrElse(1, Nil) + + if (indexToTake < smallerNumbers.size) { + quickSelect(smallerNumbers, indexToTake) + } else if (indexToTake < smallerNumbers.size + equalNumbers.size) { + equalNumbers(indexToTake - smallerNumbers.size) + } else { + quickSelect(largerNumbers, indexToTake - smallerNumbers.size - equalNumbers.size) + } + } + +} diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala index cb4ab891e21..d38ce11e8a8 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala @@ -45,6 +45,7 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), + new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -102,6 +103,7 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), + new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -163,6 +165,7 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), + new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") diff --git a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json index 12f3af1ca0c..029dedb5548 100644 --- a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json +++ b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json @@ -14930,6 +14930,15 @@ } } ], + "median": [ + { + "name": "median", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], "first": [ { "name": "first", diff --git a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala index 2af9e71b786..8a78b354d9b 100644 --- a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala +++ b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala @@ -186,7 +186,8 @@ trait MathUtils { case n1: java.math.BigDecimal => n1.negate() } - private def compare(n1: Number, n2: Number): Int = { + @Hidden + def compare(n1: Number, n2: Number): Int = { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandler[Int] { override def onBytes(n1: java.lang.Byte, n2: java.lang.Byte): Int = n1.compareTo(n2) override def onShorts(n1: java.lang.Short, n2: java.lang.Short): Int = n1.compareTo(n2) @@ -286,7 +287,8 @@ trait MathUtils { } } - private def convertToPromotedType( + @Hidden + def convertToPromotedType( n: Number )(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = { // In some cases type can be promoted to other class e.g. Byte is promoted to Int for sum From f9cdc9eb6dc8160c5f91a1913c4241f7764df7dc Mon Sep 17 00:00:00 2001 From: Pawel Czajka Date: Wed, 18 Dec 2024 11:33:47 +0100 Subject: [PATCH 3/3] Revert "[NU-1921] Add median aggregator #7321" This reverts commit ace3d3c20220dccef03e006c36e532f9c581402c. --- docs/Changelog.md | 1 - .../AggregatesInTimeWindows.md | 1 - .../aggregate/AggregatesSpec.scala | 81 +------------------ .../aggregate/TransformersTest.scala | 44 ---------- .../aggregate/AggregateHelper.java | 4 - .../transformer/aggregate/aggregates.scala | 26 +----- .../aggregate/median/MedianHelper.scala | 58 ------------- .../aggregate/sampleTransformers.scala | 3 - .../extractedTypes/defaultModel.json | 9 --- .../nussknacker/engine/util/MathUtils.scala | 6 +- 10 files changed, 6 insertions(+), 227 deletions(-) delete mode 100644 engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala diff --git a/docs/Changelog.md b/docs/Changelog.md index 8d47f6c0de5..964d82c3cd9 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -29,7 +29,6 @@ * [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored * [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme * [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators -* [#7321](https://github.com/TouK/nussknacker/pull/7321) Added Median aggregator ## 1.18 diff --git a/docs/scenarios_authoring/AggregatesInTimeWindows.md b/docs/scenarios_authoring/AggregatesInTimeWindows.md index f354d815d94..322fb5b8e4a 100644 --- a/docs/scenarios_authoring/AggregatesInTimeWindows.md +++ b/docs/scenarios_authoring/AggregatesInTimeWindows.md @@ -85,7 +85,6 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c * StddevSamp - computes sample standard deviation * VarPop - computes population variance * VarSamp - computes sample variance -* Median - computes median * ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result. If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator. diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala index c4243070a9d..07588ed4335 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala @@ -5,24 +5,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown} import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE} -import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{ - AverageAggregator, - CountWhenAggregator, - FirstAggregator, - LastAggregator, - ListAggregator, - MapAggregator, - MaxAggregator, - MedianAggregator, - MinAggregator, - OptionAggregator, - PopulationStandardDeviationAggregator, - PopulationVarianceAggregator, - SampleStandardDeviationAggregator, - SampleVarianceAggregator, - SetAggregator, - SumAggregator -} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator} import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap import java.lang.{Integer => JInt, Long => JLong} @@ -144,8 +127,7 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat MinAggregator, FirstAggregator, LastAggregator, - SumAggregator, - MedianAggregator + SumAggregator )) { agg => addElementsAndComputeResult(List(null), agg) shouldEqual null } @@ -166,62 +148,6 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat } } - test("should calculate correct results for median aggregator on integers") { - val agg = MedianAggregator - val result = addElementsAndComputeResult(List(7, 8), agg) - result shouldBe a[Double] - result shouldEqual 7.5 - } - - test("should calculate correct results for median aggregator on integers on single value") { - val agg = MedianAggregator - val result = addElementsAndComputeResult(List(7), agg) - result shouldBe a[Double] - result shouldEqual 7 - } - - test("should calculate correct results for median aggregator on BigInt") { - val agg = MedianAggregator - addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), agg) shouldEqual new java.math.BigDecimal("7.5") - } - - test("should calculate correct results for median aggregator on floats") { - val agg = MedianAggregator - val result = addElementsAndComputeResult(List(7.0f, 8.0f), agg) - result shouldBe a[Double] - result shouldEqual 7.5 - } - - test("should calculate correct results for median aggregator on BigDecimals") { - val agg = MedianAggregator - addElementsAndComputeResult( - List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8")), - agg - ) shouldEqual new java.math.BigDecimal("7.5") - } - - test("should ignore nulls for median aggregator") { - val agg = MedianAggregator - addElementsAndComputeResult( - List(null, new java.math.BigDecimal("7"), null, new java.math.BigDecimal("8")), - agg - ) shouldEqual new java.math.BigDecimal("7.5") - } - - test("MedianAggregator test on odd length list") { - val agg = MedianAggregator - val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5, 90), agg) - - result shouldEqual 5 - } - - test("MedianAggregator test on even length list") { - val agg = MedianAggregator - val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5), agg) - - result shouldEqual 4.5 - } - test("should calculate correct results for standard deviation and variance on integers") { val table = Table( ("aggregator", "value"), @@ -304,8 +230,7 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat ( SumAggregator, 15.0), ( MaxAggregator, 5.0), ( MinAggregator, 1.0), - ( AverageAggregator, 3.0), - ( MedianAggregator, 3.0) + ( AverageAggregator, 3.0) ) forAll(table) { (agg, expectedResult) => diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala index e714c4101f0..579ebf06e58 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala @@ -96,13 +96,6 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) - validateOk("#AGG.median", """#input.eId""", Typed[Double]) - validateOk("#AGG.median", """1""", Typed[Double]) - validateOk("#AGG.median", """1.5""", Typed[Double]) - - validateOk("#AGG.median", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) - validateOk("#AGG.median", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) - validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]]) validateOk( "#AGG.map({f1: #AGG.sum, f2: #AGG.set})", @@ -113,7 +106,6 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean") validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number") - validateError("#AGG.median", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number") @@ -178,17 +170,6 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldBe List(1.0d, 1.5, 3.5) } - test("median aggregate") { - val id = "1" - - val model = - modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b"))) - val testProcess = sliding("#AGG.median", "#input.eId", emitWhenEventLeft = false) - - val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) - aggregateVariables shouldBe List(1.0d, 1.5, 3.5) - } - test("standard deviation and average aggregates") { val table = Table( ("aggregate", "secondValue"), @@ -474,19 +455,6 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true } - test("emit aggregate for extra window when no data come for median aggregator for return type double") { - val id = "1" - - val model = - modelData(List(TestRecordHours(id, 0, 1, "a"))) - val testProcess = tumbling("#AGG.median", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) - - val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) - aggregateVariables.length shouldEqual (2) - aggregateVariables(0) shouldEqual 1.0 - aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true - } - test( "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double" ) { @@ -524,18 +492,6 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) } - test("emit aggregate for extra window when no data come for median aggregator for return type BigDecimal") { - val id = "1" - - val model = - modelData(List(TestRecordHours(id, 0, 1, "a"))) - val testProcess = - tumbling("#AGG.median", """T(java.math.BigDecimal).ONE""", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) - - val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) - aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) - } - test( "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal" ) { diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java index 9be0d53b6fb..7872db9cba4 100644 --- a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java +++ b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java @@ -31,7 +31,6 @@ public class AggregateHelper implements Serializable { new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"), new FixedExpressionValue("#AGG.varPop", "VarPop"), new FixedExpressionValue("#AGG.varSamp", "VarSamp"), - new FixedExpressionValue("#AGG.median", "Median"), new FixedExpressionValue("#AGG.min", "Min"), new FixedExpressionValue("#AGG.max", "Max"), new FixedExpressionValue("#AGG.sum", "Sum"), @@ -55,7 +54,6 @@ public class AggregateHelper implements Serializable { private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$; private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$; private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$; - private static final Aggregator MEDIAN = aggregates.MedianAggregator$.MODULE$; private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10); // Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used. @@ -82,8 +80,6 @@ public class AggregateHelper implements Serializable { public Aggregator varPop = VAR_POP; public Aggregator varSamp = VAR_SAMP; - public Aggregator median = MEDIAN; - public Aggregator approxCardinality = APPROX_CARDINALITY; public Aggregator map(@ParamName("parts") Map parts) { diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala index 0aa361f8574..3f54f4cbbc1 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala @@ -5,19 +5,15 @@ import cats.data.{NonEmptyList, Validated} import cats.instances.list._ import org.apache.flink.api.common.typeinfo.TypeInfo import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy -import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{ - ForLargeFloatingNumbersOperation, -} +import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation import pl.touk.nussknacker.engine.api.typed.typing._ import pl.touk.nussknacker.engine.api.typed.{NumberTypeUtils, typing} import pl.touk.nussknacker.engine.flink.api.typeinfo.caseclass.CaseClassTypeInfoFactory -import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median.MedianHelper import pl.touk.nussknacker.engine.util.Implicits._ import pl.touk.nussknacker.engine.util.MathUtils import pl.touk.nussknacker.engine.util.validated.ValidatedSyntax._ import java.util -import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ /* @@ -73,26 +69,6 @@ object aggregates { } - object MedianAggregator extends Aggregator with LargeFloatingNumberAggregate { - - override type Aggregate = ListBuffer[Number] - - override type Element = Number - - override def zero: Aggregate = ListBuffer.empty - - override def addElement(el: Element, agg: Aggregate): Aggregate = if (el == null) agg else agg.addOne(el) - - override def mergeAggregates(agg1: Aggregate, agg2: Aggregate): Aggregate = agg1 ++ agg2 - - override def result(finalAggregate: Aggregate): AnyRef = MedianHelper.calculateMedian(finalAggregate.toList).orNull - - override def computeStoredType(input: TypingResult): Validated[String, TypingResult] = Valid( - Typed.genericTypeClass[ListBuffer[_]](List(input)) - ) - - } - object ListAggregator extends Aggregator { override type Aggregate = List[AnyRef] diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala deleted file mode 100644 index a4685a3b9b5..00000000000 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/median/MedianHelper.scala +++ /dev/null @@ -1,58 +0,0 @@ -package pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median - -import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{ - ForLargeFloatingNumbersOperation, -} -import pl.touk.nussknacker.engine.util.MathUtils - -import scala.annotation.tailrec -import scala.util.Random - -object MedianHelper { - private val rand = new Random(42) - - def calculateMedian(numbers: List[Number]): Option[Number] = { - if (numbers.isEmpty) { - None - } else if (numbers.size % 2 == 1) { - Some(MathUtils.convertToPromotedType(quickSelect(numbers, (numbers.size - 1) / 2))(ForLargeFloatingNumbersOperation)) - } else { - // it is possible to fetch both numbers with single recursion, but it would complicate code - val firstNumber = quickSelect(numbers, numbers.size / 2 - 1) - val secondNumber = quickSelect(numbers, numbers.size / 2) - - val sum = MathUtils.largeFloatingSum(firstNumber, secondNumber) - Some(MathUtils.divideWithDefaultBigDecimalScale(sum, 2)) - } - } - - // https://en.wikipedia.org/wiki/Quickselect - @tailrec - private def quickSelect(numbers: List[Number], indexToTake: Int): Number = { - require(numbers.nonEmpty) - require(indexToTake >= 0) - require(indexToTake < numbers.size) - - val randomElement = numbers(rand.nextInt(numbers.size)) - val groupedBy = numbers.groupBy(e => { - val cmp = MathUtils.compare(e, randomElement) - if (cmp < 0) { - -1 - } else if (cmp == 0) { - 0 - } else 1 - }) - val smallerNumbers = groupedBy.getOrElse(-1, Nil) - val equalNumbers = groupedBy.getOrElse(0, Nil) - val largerNumbers = groupedBy.getOrElse(1, Nil) - - if (indexToTake < smallerNumbers.size) { - quickSelect(smallerNumbers, indexToTake) - } else if (indexToTake < smallerNumbers.size + equalNumbers.size) { - equalNumbers(indexToTake - smallerNumbers.size) - } else { - quickSelect(largerNumbers, indexToTake - smallerNumbers.size - equalNumbers.size) - } - } - -} diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala index d38ce11e8a8..cb4ab891e21 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala @@ -45,7 +45,6 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), - new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -103,7 +102,6 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), - new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -165,7 +163,6 @@ object sampleTransformers { new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), - new LabeledExpression(label = "Median", expression = "#AGG.median"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") diff --git a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json index 029dedb5548..12f3af1ca0c 100644 --- a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json +++ b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json @@ -14930,15 +14930,6 @@ } } ], - "median": [ - { - "name": "median", - "signature": { - "noVarArgs": [], - "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} - } - } - ], "first": [ { "name": "first", diff --git a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala index 8a78b354d9b..2af9e71b786 100644 --- a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala +++ b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala @@ -186,8 +186,7 @@ trait MathUtils { case n1: java.math.BigDecimal => n1.negate() } - @Hidden - def compare(n1: Number, n2: Number): Int = { + private def compare(n1: Number, n2: Number): Int = { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandler[Int] { override def onBytes(n1: java.lang.Byte, n2: java.lang.Byte): Int = n1.compareTo(n2) override def onShorts(n1: java.lang.Short, n2: java.lang.Short): Int = n1.compareTo(n2) @@ -287,8 +286,7 @@ trait MathUtils { } } - @Hidden - def convertToPromotedType( + private def convertToPromotedType( n: Number )(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = { // In some cases type can be promoted to other class e.g. Byte is promoted to Int for sum