Skip to content

Commit

Permalink
[NU-1921] Add standard deviation and variance aggregations (#7307)
Browse files Browse the repository at this point in the history
Co-authored-by: Pawel Czajka <[email protected]>
  • Loading branch information
paw787878 and Pawel Czajka authored Dec 18, 2024
1 parent a5d5305 commit f00c47a
Show file tree
Hide file tree
Showing 9 changed files with 571 additions and 87 deletions.
1 change: 1 addition & 0 deletions docs/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
* [#7323](https://github.com/TouK/nussknacker/pull/7323) Improve Periodic DeploymentManager db queries
* [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored
* [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme
* [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators

## 1.18

Expand Down
4 changes: 4 additions & 0 deletions docs/scenarios_authoring/AggregatesInTimeWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c
* Set - the result is a set of inputs received by the aggregator. Can be very inefficient for large sets; try to use ApproximateSetCardinality in this case
* CountWhen - accepts boolean values, returns how many of them are true
* Average - computes average of values
* StddevPop - computes population standard deviation
* StddevSamp - computes sample standard deviation
* VarPop - computes population variance
* VarSamp - computes sample variance
* ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result.

If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package pl.touk.nussknacker.engine.flink.util.transformer.aggregate
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectTypingResult, TypingResult, Unknown}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates._
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator}
import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap

import java.lang.{Integer => JInt, Long => JLong}
Expand Down Expand Up @@ -83,73 +84,215 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat

test("should calculate correct results for first aggregator") {
  // Elements are fed in list order (5, then 8); the aggregator must report the first one.
  addElementsAndComputeResult(List(5, 8), FirstAggregator) shouldEqual 5
}

test("should calculate correct results for countWhen aggregator") {
  // The trailing `false` is kept in the input on purpose: it proves that only
  // `true` values contribute to the count.
  addElementsAndComputeResult(List(true, true, false), CountWhenAggregator) shouldEqual 2
}

test("should calculate correct results for average aggregator") {
  // Integer inputs 7 and 8 are expected to average to the fractional value 7.5.
  addElementsAndComputeResult(List(7, 8), AverageAggregator) shouldEqual 7.5
}

test("should calculate correct results for average aggregator on BigInt") {
  // BigInteger inputs are expected to produce a java.math.BigDecimal average.
  val inputs = List(new BigInteger("7"), new BigInteger("8"))
  addElementsAndComputeResult(inputs, AverageAggregator) shouldEqual new java.math.BigDecimal("7.5")
}

test("should calculate correct results for average aggregator on float") {
  // Float inputs 7.0f and 8.0f are expected to average to 7.5.
  addElementsAndComputeResult(List(7.0f, 8.0f), AverageAggregator) shouldEqual 7.5
}

test("should calculate correct results for average aggregator on BigDecimal") {
  // BigDecimal inputs 7 and 8 are expected to average to the BigDecimal 7.5.
  val inputs = List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8"))
  addElementsAndComputeResult(inputs, AverageAggregator) shouldEqual new java.math.BigDecimal("7.5")
}

test("AverageAggregator should calculate correct results for empty aggregation set") {
val agg = AverageAggregator
val result = agg.result(
agg.zero
test("some aggregators should produce null on single null input") {
  // Every aggregator listed here is expected to answer null when the only
  // element it ever receives is null.
  val aggregators = Table(
    "aggregator",
    AverageAggregator,
    SampleStandardDeviationAggregator,
    PopulationStandardDeviationAggregator,
    SampleVarianceAggregator,
    PopulationVarianceAggregator,
    MaxAggregator,
    MinAggregator,
    FirstAggregator,
    LastAggregator,
    SumAggregator
  )

  forAll(aggregators) { aggregator =>
    val outcome = addElementsAndComputeResult(List(null), aggregator)
    outcome shouldEqual null
  }
}

test("should calculate correct results for standard deviation and variance on doubles") {
  // For the values 1..5 the mean is 3 and the squared deviations sum to 10,
  // hence variance 10/5 = 2 (population) and 10/4 = 2.5 (sample).
  val samples = List(5.0, 4.0, 3.0, 2.0, 1.0)
  val expectations = Table(
    ("aggregator", "value"),
    (SampleStandardDeviationAggregator, Math.sqrt(2.5)),
    (PopulationStandardDeviationAggregator, Math.sqrt(2.0)),
    (SampleVarianceAggregator, 2.5),
    (PopulationVarianceAggregator, 2.0)
  )

  forAll(expectations) { (aggregator, expected) =>
    val computed = addElementsAndComputeResult(samples, aggregator).asInstanceOf[Double]
    // Compared with a tolerance because the result is a floating point number.
    computed shouldBe expected +- EPS_DOUBLE
  }
}

test("should calculate correct results for standard deviation and variance on integers") {
  // For the values 1..5 the mean is 3 and the squared deviations sum to 10,
  // hence variance 10/5 = 2 (population) and 10/4 = 2.5 (sample).
  val table = Table(
    ("aggregator", "value"),
    (SampleStandardDeviationAggregator, Math.sqrt(2.5)),
    (PopulationStandardDeviationAggregator, Math.sqrt(2)),
    (SampleVarianceAggregator, 2.5),
    (PopulationVarianceAggregator, 2.0)
  )

  forAll(table) { (agg, expectedResult) =>
    // Integer inputs are asserted to yield a Double result.
    val result = addElementsAndComputeResult(List(5, 4, 3, 2, 1), agg)
    result.asInstanceOf[Double] shouldBe expectedResult +- EPS_DOUBLE
  }
}

test("should calculate correct results for standard deviation and variance on BigInt") {
  // Same 1..5 data set as the Double variant, supplied as BigInteger values;
  // the result is asserted to be a java.math.BigDecimal.
  val samples = List("5", "4", "3", "2", "1").map(new BigInteger(_))
  val expectations = Table(
    ("aggregator", "value"),
    (SampleStandardDeviationAggregator, BigDecimal(Math.sqrt(2.5))),
    (PopulationStandardDeviationAggregator, BigDecimal(Math.sqrt(2))),
    (SampleVarianceAggregator, BigDecimal(2.5)),
    (PopulationVarianceAggregator, BigDecimal(2.0))
  )

  forAll(expectations) { (aggregator, expected) =>
    val computed = addElementsAndComputeResult(samples, aggregator).asInstanceOf[java.math.BigDecimal]
    // Wrapped into scala BigDecimal so the `+-` tolerance matcher can be used.
    BigDecimal(computed) shouldBe expected +- EPS_BIG_DECIMAL
  }
}

test("should calculate correct results for standard deviation and variance on float") {
  // Same 1..5 data set as the Double variant, supplied as Float values;
  // the result is asserted to be a Double.
  val samples = List(5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
  val expectations = Table(
    ("aggregator", "value"),
    (SampleStandardDeviationAggregator, Math.sqrt(2.5)),
    (PopulationStandardDeviationAggregator, Math.sqrt(2.0)),
    (SampleVarianceAggregator, 2.5),
    (PopulationVarianceAggregator, 2.0)
  )

  forAll(expectations) { (aggregator, expected) =>
    val computed = addElementsAndComputeResult(samples, aggregator).asInstanceOf[Double]
    computed shouldBe expected +- EPS_DOUBLE
  }
}

test("should calculate correct results for standard deviation and variance on BigDecimals") {
  // Same 1..5 data set as the Double variant, supplied as java.math.BigDecimal values.
  val samples = List("5", "4", "3", "2", "1").map(new java.math.BigDecimal(_))
  val expectations = Table(
    ("aggregator", "value"),
    (SampleStandardDeviationAggregator, BigDecimal(Math.sqrt(2.5))),
    (PopulationStandardDeviationAggregator, BigDecimal(Math.sqrt(2))),
    (SampleVarianceAggregator, BigDecimal(2.5)),
    (PopulationVarianceAggregator, BigDecimal(2.0))
  )

  forAll(expectations) { (aggregator, expected) =>
    val computed = addElementsAndComputeResult(samples, aggregator).asInstanceOf[java.math.BigDecimal]
    // Wrapped into scala BigDecimal so the `+-` tolerance matcher can be used.
    BigDecimal(computed) shouldBe expected +- EPS_BIG_DECIMAL
  }
}

test("some aggregators should ignore nulls ") {
  // The nulls are interleaved with the values 1..5; the expected results are
  // exactly those of the null-free data set (sum 15, max 5, min 1, mean 3,
  // population variance 2, sample variance 2.5).
  val samplesWithNulls = List(null, 5.0, 4.0, null, 3.0, 2.0, 1.0)
  val expectations = Table(
    ("aggregator", "value"),
    (SampleStandardDeviationAggregator, Math.sqrt(2.5)),
    (PopulationStandardDeviationAggregator, Math.sqrt(2.0)),
    (SampleVarianceAggregator, 2.5),
    (PopulationVarianceAggregator, 2.0),
    (SumAggregator, 15.0),
    (MaxAggregator, 5.0),
    (MinAggregator, 1.0),
    (AverageAggregator, 3.0)
  )

  forAll(expectations) { (aggregator, expected) =>
    val computed = addElementsAndComputeResult(samplesWithNulls, aggregator).asInstanceOf[Double]
    computed shouldBe expected +- EPS_DOUBLE
  }
}

test("some aggregators should produce null on empty set") {
  // With no elements added, each of these aggregators is expected to return null.
  val aggregators = Table(
    "aggregator",
    AverageAggregator,
    SampleStandardDeviationAggregator,
    PopulationStandardDeviationAggregator,
    SampleVarianceAggregator,
    PopulationVarianceAggregator,
    MaxAggregator,
    MinAggregator,
    FirstAggregator,
    LastAggregator,
    SumAggregator
  )

  forAll(aggregators) { aggregator =>
    addElementsAndComputeResult(List(), aggregator) shouldBe null
  }
}

test("should calculate correct results for standard deviation and variance on single element double set") {
  // Covers both the population and the sample variants: for a single element
  // all four aggregators are expected to return 0.
  val table = Table(
    "aggregator",
    SampleStandardDeviationAggregator,
    PopulationStandardDeviationAggregator,
    SampleVarianceAggregator,
    PopulationVarianceAggregator
  )

  forAll(table) { agg =>
    val result = addElementsAndComputeResult(List(1.0d), agg)
    result.asInstanceOf[Double] shouldBe 0
  }
}

test("should calculate correct results for population standard deviation on single element float set") {
  // A single Float element is expected to give a Double result of 0.
  val computed = addElementsAndComputeResult(List(1.0f), PopulationStandardDeviationAggregator)
    .asInstanceOf[Double]
  computed shouldBe 0
}

test("should calculate correct results for population standard deviation on single element BigDecimal set") {
  // A single BigDecimal element is expected to give a BigDecimal result of 0 (within tolerance).
  val singleElement = List(new java.math.BigDecimal("1.0"))
  val computed = addElementsAndComputeResult(singleElement, PopulationStandardDeviationAggregator)
    .asInstanceOf[java.math.BigDecimal]
  BigDecimal(computed) shouldBe BigDecimal(0) +- EPS_BIG_DECIMAL
}

test("should calculate correct results for population standard deviation on single element BigInteger set") {
  // A single BigInteger element is expected to give a BigDecimal result of 0 (within tolerance).
  val singleElement = List(new java.math.BigInteger("1"))
  val computed = addElementsAndComputeResult(singleElement, PopulationStandardDeviationAggregator)
    .asInstanceOf[java.math.BigDecimal]
  BigDecimal(computed) shouldBe BigDecimal(0) +- EPS_BIG_DECIMAL
}

test("should calculate correct results for last aggregator") {
  // Elements are fed in list order (5, then 8); the aggregator must report the last one.
  addElementsAndComputeResult(List(5, 8), LastAggregator) shouldEqual 8
}

test("should compute output and stored type for simple aggregators") {
Expand Down Expand Up @@ -344,5 +487,18 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
aggregator.mergeAggregates(rightElemState, leftElemState) shouldBe combinedState
}

/**
 * Feeds `elements` into `aggregator` in list order, starting from the
 * aggregator's zero state, and returns the aggregator's final result.
 *
 * Each element is cast to the aggregator's own `Element` type, so callers can
 * pass plain values (including `null`) without casting at the call site.
 *
 * @param elements   values to add, in the order they should be added
 * @param aggregator aggregator under test
 * @return the result computed from the state after all elements were added
 */
private def addElementsAndComputeResult[T](elements: List[T], aggregator: Aggregator): AnyRef = {
  val finalState = elements.foldLeft(aggregator.zero) { (accumulated, next) =>
    aggregator.addElement(next.asInstanceOf[aggregator.Element], accumulated)
  }
  aggregator.result(finalState)
}

// Empty class with no members. NOTE(review): its usages are outside this view —
// presumably an arbitrary, otherwise unknown type for typing tests; confirm.
class JustAnyClass
}

/** Comparison tolerances shared by the numeric assertions of the `AggregatesSpec` tests. */
object AggregatesSpec {
  // Absolute tolerance used with `+-` when results are compared as Double.
  val EPS_DOUBLE: Double = 0.000001
  // The same tolerance as a scala BigDecimal, for results compared as BigDecimal.
  val EPS_BIG_DECIMAL: BigDecimal = BigDecimal(new java.math.BigDecimal("0.000001"))
}
Loading

0 comments on commit f00c47a

Please sign in to comment.