diff --git a/docs/Changelog.md b/docs/Changelog.md index 3986d46db21..964d82c3cd9 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -28,6 +28,7 @@ * [#7323](https://github.com/TouK/nussknacker/pull/7323) Improve Periodic DeploymentManager db queries * [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored * [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme +* [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators ## 1.18 diff --git a/docs/scenarios_authoring/AggregatesInTimeWindows.md b/docs/scenarios_authoring/AggregatesInTimeWindows.md index 23ed9fb7baf..322fb5b8e4a 100644 --- a/docs/scenarios_authoring/AggregatesInTimeWindows.md +++ b/docs/scenarios_authoring/AggregatesInTimeWindows.md @@ -81,6 +81,10 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c * Set - the result is a set of inputs received by the aggregator. Can be very ineffective for large sets, try to use ApproximateSetCardinality in this case * CountWhen - accepts boolean values, returns how many of them are true * Average - computes average of values +* StddevPop - computes population standard deviation +* StddevSamp - computes sample standard deviation +* VarPop - computes population variance +* VarSamp - computes sample variance * ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result. If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator. 
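For the doc addition above: StddevPop and VarPop divide the squared deviations by n, while StddevSamp and VarSamp divide by n − 1 (Bessel's correction). Below is a minimal, standalone Scala sketch — illustrative only, not part of this PR — that reproduces the reference values used by the new tests later in this diff (inputs 1..5 give a population variance of 2.0 and a sample variance of 2.5):

```scala
// Reference formulas for the four new aggregators, checked against the
// expectations used in AggregatesSpec in this diff (values 1..5).
object StddevVarianceReference extends App {
  val xs: Seq[Double] = Seq(1.0, 2.0, 3.0, 4.0, 5.0)
  val n    = xs.size
  val mean = xs.sum / n

  val varPop  = xs.map(x => math.pow(x - mean, 2)).sum / n        // population variance
  val varSamp = xs.map(x => math.pow(x - mean, 2)).sum / (n - 1)  // sample variance (Bessel's correction)

  println(s"varPop     = $varPop")                 // 2.0
  println(s"varSamp    = $varSamp")                // 2.5
  println(s"stddevPop  = ${math.sqrt(varPop)}")    // sqrt(2.0)
  println(s"stddevSamp = ${math.sqrt(varSamp)}")   // sqrt(2.5)
}
```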
diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala index 1a8a5e13686..07588ed4335 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregatesSpec.scala @@ -3,8 +3,9 @@ package pl.touk.nussknacker.engine.flink.util.transformer.aggregate import org.scalatest.prop.TableDrivenPropertyChecks import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectTypingResult, TypingResult, Unknown} -import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates._ +import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE} +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator} import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap import java.lang.{Integer => JInt, Long => JLong} @@ -83,73 +84,215 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat test("should calculate correct results for first aggregator") { val agg = FirstAggregator - agg.result( - agg.addElement(8.asInstanceOf[agg.Element], agg.addElement(5.asInstanceOf[agg.Element], agg.zero)) - ) shouldEqual 5 + addElementsAndComputeResult(List(5, 8), agg) shouldEqual 5 } test("should calculate correct results for countWhen aggregator") { val agg = CountWhenAggregator - agg.result( - agg.addElement( - false.asInstanceOf[agg.Element], - agg.addElement(true.asInstanceOf[agg.Element], agg.addElement(true.asInstanceOf[agg.Element], agg.zero)) - ) - ) shouldEqual 2 + addElementsAndComputeResult(List(true, true), agg) shouldEqual 2 } test("should calculate correct results for average aggregator") { val agg = AverageAggregator - agg.result( - agg.addElement(8.asInstanceOf[agg.Element], agg.addElement(7.asInstanceOf[agg.Element], agg.zero)) - ) shouldEqual 7.5 + addElementsAndComputeResult(List(7, 8), agg) shouldEqual 7.5 } test("should calculate correct results for average aggregator on BigInt") { val agg = AverageAggregator - agg.result( - agg.addElement( - new BigInteger("8").asInstanceOf[agg.Element], - agg.addElement(new BigInteger("7").asInstanceOf[agg.Element], agg.zero) - ) - ) shouldEqual new java.math.BigDecimal("7.5") + addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), agg) shouldEqual new java.math.BigDecimal("7.5") } test("should calculate correct results for average aggregator on float") { val agg = AverageAggregator - agg.result( - agg.addElement( - 8.0f.asInstanceOf[agg.Element], - agg.addElement(7.0f.asInstanceOf[agg.Element], agg.zero) - ) - ) shouldEqual 7.5 + addElementsAndComputeResult(List(7.0f, 8.0f), agg) shouldEqual 7.5 } 
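The refactor above replaces the nested `agg.addElement(...)` calls with the `addElementsAndComputeResult` helper introduced near the end of this spec (see further down in this diff). A minimal sketch of the fold that helper performs, assuming the aggregators imported at the top of AggregatesSpec:

```scala
// Illustrative sketch of the Aggregator lifecycle exercised by the helper:
// start from `zero`, add each element, then ask for the `result`.
val agg = AverageAggregator
val finalState = List(7, 8).foldLeft(agg.zero) { (state, element) =>
  agg.addElement(element.asInstanceOf[agg.Element], state)
}
agg.result(finalState) // 7.5, as asserted in the test above
```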
test("should calculate correct results for average aggregator on BigDecimal") { val agg = AverageAggregator - agg.result( - agg.addElement( - new java.math.BigDecimal("8").asInstanceOf[agg.Element], - agg.addElement(new java.math.BigDecimal("7").asInstanceOf[agg.Element], agg.zero) - ) + addElementsAndComputeResult( + List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8")), + agg ) shouldEqual new java.math.BigDecimal("7.5") } - test("AverageAggregator should calculate correct results for empty aggregation set") { - val agg = AverageAggregator - val result = agg.result( - agg.zero + test("some aggregators should produce null on single null input") { + forAll (Table( + "aggregator", + AverageAggregator, + SampleStandardDeviationAggregator, + PopulationStandardDeviationAggregator, + SampleVarianceAggregator, + PopulationVarianceAggregator, + MaxAggregator, + MinAggregator, + FirstAggregator, + LastAggregator, + SumAggregator + )) { agg => + addElementsAndComputeResult(List(null), agg) shouldEqual null + } + } + + test("should calculate correct results for standard deviation and variance on doubles") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(5.0, 4.0, 3.0, 2.0, 1.0), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("should calculate correct results for standard deviation and variance on integers") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ) ) - // null is returned because method alignToExpectedType did not run - require(result == null) + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(5, 4, 3, 2, 1), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("should calculate correct results for standard deviation and variance on BigInt") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, BigDecimal(Math.sqrt(2.5)) ), + ( PopulationStandardDeviationAggregator, BigDecimal(Math.sqrt(2)) ), + ( SampleVarianceAggregator, BigDecimal(2.5) ), + ( PopulationVarianceAggregator, BigDecimal(2.0) ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult( + List(new BigInteger("5"), new BigInteger("4"), new BigInteger("3"), new BigInteger("2"), new BigInteger("1")), + agg + ) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe expectedResult +- EPS_BIG_DECIMAL + } + } + + test("should calculate correct results for standard deviation and variance on float") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(5.0f, 4.0f, 3.0f, 2.0f, 1.0f), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("should calculate correct results for standard deviation and variance on BigDecimals") { + val table = 
Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, BigDecimal(Math.sqrt(2.5)) ), + ( PopulationStandardDeviationAggregator, BigDecimal(Math.sqrt(2)) ), + ( SampleVarianceAggregator, BigDecimal(2.5) ), + ( PopulationVarianceAggregator, BigDecimal(2.0) ) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult( + List( + new java.math.BigDecimal("5"), + new java.math.BigDecimal("4"), + new java.math.BigDecimal("3"), + new java.math.BigDecimal("2"), + new java.math.BigDecimal("1") + ), + agg + ) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe expectedResult +- EPS_BIG_DECIMAL + } + } + + test("some aggregators should ignore nulls ") { + val table = Table( + ("aggregator", "value"), + ( SampleStandardDeviationAggregator, Math.sqrt(2.5) ), + ( PopulationStandardDeviationAggregator, Math.sqrt(2) ), + ( SampleVarianceAggregator, 2.5 ), + ( PopulationVarianceAggregator, 2.0 ), + ( SumAggregator, 15.0), + ( MaxAggregator, 5.0), + ( MinAggregator, 1.0), + ( AverageAggregator, 3.0) + ) + + forAll(table) { (agg, expectedResult) => + val result = addElementsAndComputeResult(List(null, 5.0, 4.0, null, 3.0, 2.0, 1.0), agg) + result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE + } + } + + test("some aggregators should produce null on empty set") { + forAll (Table( + "aggregator", + AverageAggregator, + SampleStandardDeviationAggregator, + PopulationStandardDeviationAggregator, + SampleVarianceAggregator, + PopulationVarianceAggregator, + MaxAggregator, + MinAggregator, + FirstAggregator, + LastAggregator, + SumAggregator + )) { agg => + val result = addElementsAndComputeResult(List(), agg) + result shouldBe null + } + } + + test("should calculate correct results for population standard deviation and variance on single element double set") { + val table = Table( + "aggregator", + SampleStandardDeviationAggregator, + PopulationStandardDeviationAggregator, + SampleVarianceAggregator, + PopulationVarianceAggregator + ) + forAll(table) { agg => + val result = addElementsAndComputeResult(List(1.0d), agg) + result.asInstanceOf[Double] shouldBe 0 + } + } + + test("should calculate correct results for population standard deviation on single element float set") { + val agg = PopulationStandardDeviationAggregator + val result = addElementsAndComputeResult(List(1.0f), agg) + result.asInstanceOf[Double] shouldBe 0 + } + + test("should calculate correct results for population standard deviation on single element BigDecimal set") { + val agg = PopulationStandardDeviationAggregator + val result = addElementsAndComputeResult(List(new java.math.BigDecimal("1.0")), agg) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe BigDecimal(0) +- EPS_BIG_DECIMAL + } + + test("should calculate correct results for population standard deviation on single element BigInteger set") { + val agg = PopulationStandardDeviationAggregator + val result = addElementsAndComputeResult(List(new java.math.BigInteger("1")), agg) + BigDecimal(result.asInstanceOf[java.math.BigDecimal]) shouldBe BigDecimal(0) +- EPS_BIG_DECIMAL } test("should calculate correct results for last aggregator") { val agg = LastAggregator - agg.result( - agg.addElement(8.asInstanceOf[agg.Element], agg.addElement(5.asInstanceOf[agg.Element], agg.zero)) - ) shouldEqual 8 + addElementsAndComputeResult(List(5, 8), agg) shouldEqual 8 } test("should compute output and stored type for simple aggregators") { @@ -344,5 +487,18 @@ class AggregatesSpec extends AnyFunSuite with 
TableDrivenPropertyChecks with Mat aggregator.mergeAggregates(rightElemState, leftElemState) shouldBe combinedState } + private def addElementsAndComputeResult[T](elements: List[T], aggregator: Aggregator): AnyRef = { + aggregator.result( + elements.foldLeft(aggregator.zero)((state, element) => + aggregator.addElement(element.asInstanceOf[aggregator.Element], state) + ) + ) + } + class JustAnyClass } + +object AggregatesSpec { + val EPS_DOUBLE = 0.000001; + val EPS_BIG_DECIMAL = BigDecimal(new java.math.BigDecimal("0.000001")) +} diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala index a2f00006bc4..579ebf06e58 100644 --- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala +++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/TransformersTest.scala @@ -6,6 +6,8 @@ import com.typesafe.config.ConfigFactory import org.scalatest.Inside import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks.forAll +import org.scalatest.prop.Tables.Table import pl.touk.nussknacker.engine.api.component.ComponentDefinition import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.{ CannotCreateObjectError, @@ -74,6 +76,26 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateOk("#AGG.average", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) validateOk("#AGG.average", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.stddevPop", "1", Typed[Double]) + validateOk("#AGG.stddevSamp", "1", Typed[Double]) + validateOk("#AGG.varPop", "1", Typed[Double]) + validateOk("#AGG.varSamp", "1", Typed[Double]) + + validateOk("#AGG.stddevPop", "1.5", Typed[Double]) + validateOk("#AGG.stddevSamp", "1.5", Typed[Double]) + validateOk("#AGG.varPop", "1.5", Typed[Double]) + validateOk("#AGG.varSamp", "1.5", Typed[Double]) + + validateOk("#AGG.stddevPop", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.stddevSamp", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varPop", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varSamp", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal]) + + validateOk("#AGG.stddevPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.stddevSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal]) + validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]]) validateOk( "#AGG.map({f1: #AGG.sum, f2: #AGG.set})", @@ -84,6 +106,12 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number") validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean") validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number") + + 
validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.varPop", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError("#AGG.varSamp", "#input.str", "Invalid aggregate type: String, should be: Number") + validateError( "#AGG.map({f1: #AGG.set, f2: #AGG.set})", "{f1: #input.str}", @@ -142,6 +170,31 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldBe List(1.0d, 1.5, 3.5) } + test("standard deviation and average aggregates") { + val table = Table( + ("aggregate", "secondValue"), + ("#AGG.stddevPop", Math.sqrt(0.25)), + ("#AGG.stddevSamp", Math.sqrt(0.5)), + ("#AGG.varPop", 0.25), + ("#AGG.varSamp", 0.5) + ) + + forAll(table) { (aggregationName, secondValue) => + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"))) + val testProcess = sliding(aggregationName, "#input.eId", emitWhenEventLeft = false) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + val mapped = aggregateVariables + .map(e => e.asInstanceOf[Double]) + mapped.size shouldBe 2 + mapped(0) shouldBe 0.0 +- 0.0001 + mapped(1) shouldBe secondValue +- 0.0001 + } + } + test("sliding aggregate should emit context of variables") { val id = "1" @@ -399,7 +452,32 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) aggregateVariables.length shouldEqual (2) aggregateVariables(0) shouldEqual 1.0 - require((aggregateVariables(1).asInstanceOf[Double].isNaN)) + aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true + } + + test( + "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double" + ) { + val table = Table( + "aggregatorExpression", + "#AGG.stddevPop", + "#AGG.stddevSamp", + "#AGG.varPop", + "#AGG.varSamp" + ) + + forAll(table) { aggregatorName => + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = tumbling(aggregatorName, "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables.length shouldEqual (2) + aggregateVariables(0) shouldEqual 0.0 + aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true + } } test("emit aggregate for extra window when no data come for average aggregator for return type BigDecimal") { @@ -414,6 +492,34 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null) } + test( + "emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal" + ) { + val table = Table( + "aggregatorExpression", + "#AGG.stddevPop", + "#AGG.stddevSamp", + "#AGG.varPop", + "#AGG.varSamp" + ) + + forAll(table) { aggregatorName => + val id = "1" + + val model = + modelData(List(TestRecordHours(id, 0, 1, "a"))) + val testProcess = + tumbling( + aggregatorName, + """T(java.math.BigDecimal).ONE""", + emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow + ) + + val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess) + aggregateVariables shouldEqual List(new 
java.math.BigDecimal("0"), null) + } + } + test("emit aggregate for extra window when no data come - out of order elements") { val id = "1" diff --git a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java index a6329c746e5..7872db9cba4 100644 --- a/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java +++ b/engine/flink/components/base-unbounded/src/main/java/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/AggregateHelper.java @@ -27,6 +27,10 @@ public class AggregateHelper implements Serializable { new FixedExpressionValue("#AGG.last", "Last"), new FixedExpressionValue("#AGG.countWhen", "CountWhen"), new FixedExpressionValue("#AGG.average", "Average"), + new FixedExpressionValue("#AGG.stddevPop", "StddevPop"), + new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"), + new FixedExpressionValue("#AGG.varPop", "VarPop"), + new FixedExpressionValue("#AGG.varSamp", "VarSamp"), new FixedExpressionValue("#AGG.min", "Min"), new FixedExpressionValue("#AGG.max", "Max"), new FixedExpressionValue("#AGG.sum", "Sum"), @@ -46,6 +50,10 @@ public class AggregateHelper implements Serializable { private static final Aggregator LAST = aggregates.LastAggregator$.MODULE$; private static final Aggregator COUNT_WHEN = aggregates.CountWhenAggregator$.MODULE$; private static final Aggregator AVERAGE = aggregates.AverageAggregator$.MODULE$; + private static final Aggregator STDDEV_POP = aggregates.PopulationStandardDeviationAggregator$.MODULE$; + private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$; + private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$; + private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$; private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10); // Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used. 
@@ -67,6 +75,11 @@ public class AggregateHelper implements Serializable { public Aggregator countWhen = COUNT_WHEN; public Aggregator average = AVERAGE; + public Aggregator stddevPop = STDDEV_POP; + public Aggregator stddevSamp = STDDEV_SAMP; + public Aggregator varPop = VAR_POP; + public Aggregator varSamp = VAR_SAMP; + public Aggregator approxCardinality = APPROX_CARDINALITY; public Aggregator map(@ParamName("parts") Map parts) { diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala index 09d79139607..3f54f4cbbc1 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/aggregates.scala @@ -201,48 +201,67 @@ object aggregates { computeOutputType(input) } - object AverageAggregator extends Aggregator { + @TypeInfo(classOf[LargeFloatSumState.TypeInfoFactory]) + // it would be natural to use type Number instead of this class + // it is done this way so that it is serialized properly + case class LargeFloatSumState( + sumDouble: java.lang.Double, + sumBigDecimal: java.math.BigDecimal, + ) { + def asNumber: Number = Option(sumDouble).getOrElse(sumBigDecimal) + + def withAddedElement(element: Number): LargeFloatSumState = { + LargeFloatSumState.fromNumber(MathUtils.largeFloatingSum(element, asNumber)) + } - override type Element = java.lang.Number + def withMergedState(other: LargeFloatSumState): LargeFloatSumState = { + LargeFloatSumState.fromNumber(MathUtils.largeFloatingSum(asNumber, other.asNumber)) + } + + } + + object LargeFloatSumState { + class TypeInfoFactory extends CaseClassTypeInfoFactory[LargeFloatSumState] + + private def fromNumber(sum: Number): LargeFloatSumState = { + sum match { + case null => LargeFloatSumState(null, null) + case sumDouble: java.lang.Double => LargeFloatSumState(sumDouble, null) + case sumBigDecimal: java.math.BigDecimal => LargeFloatSumState(null, sumBigDecimal) + } + } + + def emptyState: LargeFloatSumState = { + LargeFloatSumState.fromNumber(null) + } + + } + + object AverageAggregator extends Aggregator with LargeFloatingNumberAggregate { override type Aggregate = AverageAggregatorState - override def zero: AverageAggregatorState = AverageAggregatorState(null, 0) + override def zero: AverageAggregatorState = AverageAggregatorState(LargeFloatSumState.emptyState, 0) - override def addElement(element: java.lang.Number, aggregate: Aggregate): Aggregate = - AverageAggregatorState(MathUtils.largeFloatingSum(element, aggregate.sum), aggregate.count + 1) + override def addElement(element: Element, aggregate: Aggregate): Aggregate = { + if (element == null) aggregate + else + AverageAggregatorState(aggregate.sum.withAddedElement(element), aggregate.count + 1) + } override def mergeAggregates(aggregate1: Aggregate, aggregate2: Aggregate): Aggregate = AverageAggregatorState( - MathUtils.largeFloatingSum(aggregate1.sum, aggregate2.sum), + aggregate1.sum.withMergedState(aggregate2.sum), aggregate1.count + aggregate2.count ) override def result(finalAggregate: Aggregate): AnyRef = { val count = finalAggregate.count - finalAggregate.sum match { - case null => - // will be replaced to Double.Nan in alignToExpectedType iff return type is known to be 
Double - null - case sum: java.lang.Double => (sum / count).asInstanceOf[AnyRef] - case sum: java.math.BigDecimal => (BigDecimal(sum) / BigDecimal(count)).bigDecimal - } - } - - override def alignToExpectedType(value: AnyRef, outputType: TypingResult): AnyRef = { - if (value == null && outputType == Typed(classOf[Double])) { - Double.NaN.asInstanceOf[AnyRef] - } else { - value - } - } - - override def computeOutputType(input: typing.TypingResult): Validated[String, typing.TypingResult] = { - - if (!input.canBeConvertedTo(Typed[Number])) { - Invalid(s"Invalid aggregate type: ${input.display}, should be: ${Typed[Number].display}") + val sum = finalAggregate.sum.asNumber + if (sum == null) { + null } else { - Valid(ForLargeFloatingNumbersOperation.promoteSingle(input)) + MathUtils.divideWithDefaultBigDecimalScale(sum, count) } } @@ -250,31 +269,109 @@ object aggregates { Valid(Typed[AverageAggregatorState]) @TypeInfo(classOf[AverageAggregatorState.TypeInfoFactory]) - // it would be natural to have one field sum: Number instead of nullable sumDouble and sumBigDecimal, - // it is done this way to have types serialized properly case class AverageAggregatorState( - sumDouble: java.lang.Double, - sumBigDecimal: java.math.BigDecimal, + sum: LargeFloatSumState, count: java.lang.Long - ) { - def sum: Number = Option(sumDouble).getOrElse(sumBigDecimal) - } + ) object AverageAggregatorState { class TypeInfoFactory extends CaseClassTypeInfoFactory[AverageAggregatorState] + } + + } + + @TypeInfo(classOf[StandardDeviationState.TypeInfoFactory]) + case class StandardDeviationState( + sum: LargeFloatSumState, + squaresSum: LargeFloatSumState, + count: java.lang.Long + ) - def apply(sum: Number, count: java.lang.Long): AverageAggregatorState = { - sum match { - case null => AverageAggregatorState(null, null, count) - case sumDouble: java.lang.Double => AverageAggregatorState(sumDouble, null, count) - case sumBigDecimal: java.math.BigDecimal => AverageAggregatorState(null, sumBigDecimal, count) + object StandardDeviationState { + class TypeInfoFactory extends CaseClassTypeInfoFactory[StandardDeviationState] + } + + private sealed trait StandardDeviationOrVarianceAggregationType + private case object SampleStandardDeviation extends StandardDeviationOrVarianceAggregationType + private case object PopulationStandardDeviation extends StandardDeviationOrVarianceAggregationType + private case object SampleVariance extends StandardDeviationOrVarianceAggregationType + private case object PopulationVariance extends StandardDeviationOrVarianceAggregationType + + class GeneralStandardDeviationAndVarianceAggregator( + private val standardDeviationVarianceType: StandardDeviationOrVarianceAggregationType + ) extends Aggregator + with LargeFloatingNumberAggregate { + + override type Aggregate = StandardDeviationState + + override def zero: StandardDeviationState = + StandardDeviationState(LargeFloatSumState.emptyState, LargeFloatSumState.emptyState, 0) + + override def addElement(element: Element, aggregate: Aggregate): Aggregate = { + if (element == null) aggregate + else + StandardDeviationState( + sum = aggregate.sum.withAddedElement(element), + squaresSum = aggregate.squaresSum.withAddedElement(MathUtils.largeFloatSquare(element)), + count = aggregate.count + 1 + ) + } + + override def mergeAggregates(aggregate1: Aggregate, aggregate2: Aggregate): Aggregate = + StandardDeviationState( + sum = aggregate1.sum.withMergedState(aggregate2.sum), + squaresSum = aggregate1.squaresSum.withMergedState(aggregate2.squaresSum), + 
count = aggregate1.count + aggregate2.count + ) + + override def result(finalAggregate: Aggregate): AnyRef = { + if (finalAggregate.count == 0 || finalAggregate.sum.asNumber == null || finalAggregate.squaresSum == null) { + // will be replaced to Double.Nan in alignToExpectedType iff return type is known to be Double + null + } else if (finalAggregate.count == 1) { + // zero of the same type as aggregated number + MathUtils.minus(finalAggregate.sum.asNumber, finalAggregate.sum.asNumber) + } else { + val count = finalAggregate.count + val average = MathUtils.divideWithDefaultBigDecimalScale(finalAggregate.sum.asNumber, count) + val averageSquare = MathUtils.divideWithDefaultBigDecimalScale(finalAggregate.squaresSum.asNumber, count) + val populationVariance = MathUtils.minus(averageSquare, MathUtils.largeFloatSquare(average)) + val sampleVariance = MathUtils.multiply(count.toDouble / (count - 1), populationVariance) + + standardDeviationVarianceType match { + case SampleStandardDeviation => MathUtils.largeFloatSqrt(sampleVariance) + case PopulationStandardDeviation => MathUtils.largeFloatSqrt(populationVariance) + case SampleVariance => sampleVariance + case PopulationVariance => populationVariance } } - } + override def computeStoredType(input: typing.TypingResult): Validated[String, typing.TypingResult] = + Valid(Typed[StandardDeviationState]) + } + object SampleStandardDeviationAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + SampleStandardDeviation + ) + + object PopulationStandardDeviationAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + PopulationStandardDeviation + ) + + object SampleVarianceAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + SampleVariance + ) + + object PopulationVarianceAggregator + extends GeneralStandardDeviationAndVarianceAggregator( + PopulationVariance + ) + /* This is more complex aggregator, as it is composed from smaller ones. 
The idea is that we define aggregator: @@ -454,4 +551,26 @@ object aggregates { protected def promotedType(typ: TypingResult): TypingResult = promotionStrategy.promoteSingle(typ) } + trait LargeFloatingNumberAggregate { self: Aggregator => + override type Element = java.lang.Number + + override def alignToExpectedType(value: AnyRef, outputType: TypingResult): AnyRef = { + if (value == null && outputType == Typed(classOf[Double])) { + Double.NaN.asInstanceOf[AnyRef] + } else { + value + } + } + + override def computeOutputType(input: typing.TypingResult): Validated[String, typing.TypingResult] = { + + if (!input.canBeConvertedTo(Typed[Number])) { + Invalid(s"Invalid aggregate type: ${input.display}, should be: ${Typed[Number].display}") + } else { + Valid(ForLargeFloatingNumbersOperation.promoteSingle(input)) + } + } + + } + } diff --git a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala index a992b6eb5cb..cb4ab891e21 100644 --- a/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala +++ b/engine/flink/components/base-unbounded/src/main/scala/pl/touk/nussknacker/engine/flink/util/transformer/aggregate/sampleTransformers.scala @@ -41,6 +41,10 @@ object sampleTransformers { new LabeledExpression(label = "Sum", expression = "#AGG.sum"), new LabeledExpression(label = "Average", expression = "#AGG.average"), new LabeledExpression(label = "CountWhen", expression = "#AGG.countWhen"), + new LabeledExpression(label = "StddevPop", expression = "#AGG.stddevPop"), + new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), + new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), + new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -94,6 +98,10 @@ object sampleTransformers { new LabeledExpression(label = "Sum", expression = "#AGG.sum"), new LabeledExpression(label = "Average", expression = "#AGG.average"), new LabeledExpression(label = "CountWhen", expression = "#AGG.countWhen"), + new LabeledExpression(label = "StddevPop", expression = "#AGG.stddevPop"), + new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), + new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), + new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") @@ -151,6 +159,10 @@ object sampleTransformers { new LabeledExpression(label = "Sum", expression = "#AGG.sum"), new LabeledExpression(label = "Average", expression = "#AGG.average"), new LabeledExpression(label = "CountWhen", expression = "#AGG.countWhen"), + new LabeledExpression(label = "StddevPop", expression = "#AGG.stddevPop"), + new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"), + new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"), + new LabeledExpression(label = "VarSamp", 
expression = "#AGG.varSamp"), new LabeledExpression(label = "List", expression = "#AGG.list"), new LabeledExpression(label = "Set", expression = "#AGG.set"), new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality") diff --git a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json index 03d4cc6889f..12f3af1ca0c 100644 --- a/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json +++ b/engine/flink/tests/src/test/resources/extractedTypes/defaultModel.json @@ -14894,6 +14894,42 @@ } } ], + "stddevPop": [ + { + "name": "stddevPop", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], + "stddevSamp": [ + { + "name": "stddevSamp", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], + "varPop": [ + { + "name": "varPop", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], + "varSamp": [ + { + "name": "varSamp", + "signature": { + "noVarArgs": [], + "result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"} + } + } + ], "first": [ { "name": "first", @@ -17733,4 +17769,4 @@ ] } } -] \ No newline at end of file +] diff --git a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala index 2455028ad53..2af9e71b786 100644 --- a/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala +++ b/utils/math-utils/src/main/scala/pl/touk/nussknacker/engine/util/MathUtils.scala @@ -8,11 +8,13 @@ import pl.touk.nussknacker.engine.api.typed.supertype.{ } import java.lang +import java.math.MathContext import java.math.RoundingMode +import javax.annotation.Nullable trait MathUtils { - def min(n1: Number, n2: Number): Number = { + def min(@Nullable n1: Number, @Nullable n2: Number): Number = { implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = NumberTypesPromotionStrategy.ForMinMax withNotNullValues(n1, n2) { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandlerReturningNumber { @@ -36,7 +38,7 @@ trait MathUtils { } } - def max(n1: Number, n2: Number): Number = { + def max(@Nullable n1: Number, @Nullable n2: Number): Number = { implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = NumberTypesPromotionStrategy.ForMinMax withNotNullValues(n1, n2) { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandlerReturningNumber { @@ -79,6 +81,27 @@ trait MathUtils { promoteThenSum(n1, n2) } + @Hidden + def largeFloatSquare(number: Number): Number = { + implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = + NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation + val converted = convertToPromotedType(number) + multiply(converted, converted) + } + + @Hidden + def largeFloatSqrt(number: Number): Number = { + implicit val promotionStrategy: ReturningSingleClassPromotionStrategy = + NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation + + val converted = convertToPromotedType(number) + + converted match { + case converted: java.lang.Double => Math.sqrt(converted) + case converted: java.math.BigDecimal => converted.sqrt(MathContext.DECIMAL128) + 
} + } + def plus(n1: Number, n2: Number): Number = sum(n1, n2) def minus(n1: Number, n2: Number): Number = { @@ -107,6 +130,18 @@ trait MathUtils { })(NumberTypesPromotionStrategy.ForMathOperation) } + // divide method has peculiar behaviour when it comes to BigDecimals (see its implementation), hence this method is sometimes needed + @Hidden + def divideWithDefaultBigDecimalScale(n1: Number, n2: Number): Number = { + if (n1.isInstanceOf[java.math.BigDecimal] || n2.isInstanceOf[java.math.BigDecimal]) { + (BigDecimal(SpringNumberUtils.convertNumberToTargetClass(n1, classOf[java.math.BigDecimal])) + / + BigDecimal(SpringNumberUtils.convertNumberToTargetClass(n2, classOf[java.math.BigDecimal]))).bigDecimal + } else { + divide(n1, n2) + } + } + def divide(n1: Number, n2: Number): Number = { withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandlerForPromotingMathOp { override def onInts(n1: java.lang.Integer, n2: java.lang.Integer): java.lang.Integer = n1 / n2 @@ -118,6 +153,8 @@ trait MathUtils { override def onBigDecimals(n1: java.math.BigDecimal, n2: java.math.BigDecimal): java.math.BigDecimal = { n1.divide( n2, + // This is copied behaviour of divide operation in spel (class OpDivide) but it can lead to issues when both big decimals have small scales. + // Small scales happen when integer is converted to BigDecimal using SpringNumberUtils.convertNumberToTargetClass Math.max(n1.scale(), n2.scale), RoundingMode.HALF_EVEN ) // same scale and rounding as used by OpDivide in SpelExpression.java @@ -169,7 +206,7 @@ trait MathUtils { def equal(n1: Number, n2: Number): Boolean = compare(n1, n2) == 0 def notEqual(n1: Number, n2: Number): Boolean = compare(n1, n2) != 0 - private def promoteThenSum(n1: Number, n2: Number)( + private def promoteThenSum(@Nullable n1: Number, @Nullable n2: Number)( implicit promotionStrategy: ReturningSingleClassPromotionStrategy ) = { withNotNullValues(n1, n2) { @@ -186,7 +223,7 @@ trait MathUtils { } } - protected def withNotNullValues(n1: Number, n2: Number)( + protected def withNotNullValues(@Nullable n1: Number, @Nullable n2: Number)( f: => Number )(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = { if (n1 == null) {