Skip to content

Commit

Permalink
[NU-1921] Add median aggregator #7321
Browse files Browse the repository at this point in the history
Co-authored-by: Pawel Czajka <[email protected]>
  • Loading branch information
paw787878 and Pawel Czajka authored Dec 18, 2024
1 parent f00c47a commit ace3d3c
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
* [#7332](https://github.com/TouK/nussknacker/pull/7332) Handle scenario names with spaces when performing migration tests, they were ignored
* [#7346](https://github.com/TouK/nussknacker/pull/7346) OpenAPI enricher: ability to configure common secret for any security scheme
* [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators
* [#7321](https://github.com/TouK/nussknacker/pull/7321) Added Median aggregator

## 1.18

Expand Down
1 change: 1 addition & 0 deletions docs/scenarios_authoring/AggregatesInTimeWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c
* StddevSamp - computes sample standard deviation
* VarPop - computes population variance
* VarSamp - computes sample variance
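* Median - computes the median of the aggregated values; null values are ignored, and for an even number of values the result is the mean of the two middle values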
* ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result.

If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,24 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{
AverageAggregator,
CountWhenAggregator,
FirstAggregator,
LastAggregator,
ListAggregator,
MapAggregator,
MaxAggregator,
MedianAggregator,
MinAggregator,
OptionAggregator,
PopulationStandardDeviationAggregator,
PopulationVarianceAggregator,
SampleStandardDeviationAggregator,
SampleVarianceAggregator,
SetAggregator,
SumAggregator
}
import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap

import java.lang.{Integer => JInt, Long => JLong}
Expand Down Expand Up @@ -127,7 +144,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
MinAggregator,
FirstAggregator,
LastAggregator,
SumAggregator
SumAggregator,
MedianAggregator
)) { agg =>
addElementsAndComputeResult(List(null), agg) shouldEqual null
}
Expand All @@ -148,6 +166,62 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
}
}

test("should calculate correct results for median aggregator on integers") {
  // Two integer inputs: the expected median lies halfway between them and is reported as a Double.
  val median = addElementsAndComputeResult(List(7, 8), MedianAggregator)

  median shouldBe a[Double]
  median shouldEqual 7.5
}

test("should calculate correct results for median aggregator on integers on single value") {
  // A single integer input is expected back as its own median, reported as a Double.
  val median = addElementsAndComputeResult(List(7), MedianAggregator)

  median shouldBe a[Double]
  median shouldEqual 7
}

test("should calculate correct results for median aggregator on BigInt") {
  // BigInteger inputs are expected to yield a BigDecimal median.
  val median = addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), MedianAggregator)

  median shouldEqual new java.math.BigDecimal("7.5")
}

test("should calculate correct results for median aggregator on floats") {
  // Float inputs: the expected median is reported as a Double.
  val median = addElementsAndComputeResult(List(7.0f, 8.0f), MedianAggregator)

  median shouldBe a[Double]
  median shouldEqual 7.5
}

test("should calculate correct results for median aggregator on BigDecimals") {
  // BigDecimal inputs are expected to yield a BigDecimal median.
  val seven = new java.math.BigDecimal("7")
  val eight = new java.math.BigDecimal("8")

  val median = addElementsAndComputeResult(List(seven, eight), MedianAggregator)

  median shouldEqual new java.math.BigDecimal("7.5")
}

test("should ignore nulls for median aggregator") {
  // Nulls interleaved with the real values must not influence the expected median.
  val seven = new java.math.BigDecimal("7")
  val eight = new java.math.BigDecimal("8")

  val median = addElementsAndComputeResult(List(null, seven, null, eight), MedianAggregator)

  median shouldEqual new java.math.BigDecimal("7.5")
}

test("MedianAggregator test on odd length list") {
  // Nine unsorted values; sorted they are 1, 2, 3, 4, 5, 60, 70, 80, 90, so the middle one is 5.
  val median = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5, 90), MedianAggregator)

  median shouldEqual 5
}

test("MedianAggregator test on even length list") {
  // Eight unsorted values; sorted they are 1, 2, 3, 4, 5, 60, 70, 80, so the two middle ones are 4 and 5.
  val median = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5), MedianAggregator)

  median shouldEqual 4.5
}

test("should calculate correct results for standard deviation and variance on integers") {
val table = Table(
("aggregator", "value"),
Expand Down Expand Up @@ -230,7 +304,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
( SumAggregator, 15.0),
( MaxAggregator, 5.0),
( MinAggregator, 1.0),
( AverageAggregator, 3.0)
( AverageAggregator, 3.0),
( MedianAggregator, 3.0)
)

forAll(table) { (agg, expectedResult) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal])
validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal])

validateOk("#AGG.median", """#input.eId""", Typed[Double])
validateOk("#AGG.median", """1""", Typed[Double])
validateOk("#AGG.median", """1.5""", Typed[Double])

validateOk("#AGG.median", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal])
validateOk("#AGG.median", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal])

validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]])
validateOk(
"#AGG.map({f1: #AGG.sum, f2: #AGG.set})",
Expand All @@ -106,6 +113,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number")
validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean")
validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number")
validateError("#AGG.median", "#input.str", "Invalid aggregate type: String, should be: Number")

validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number")
validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number")
Expand Down Expand Up @@ -170,6 +178,17 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
aggregateVariables shouldBe List(1.0d, 1.5, 3.5)
}

test("median aggregate") {
  val id = "1"

  // Three records for the same key; presumably the third constructor argument is the aggregated eId.
  val records = List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b"))
  val model   = modelData(records)
  val scenario = sliding("#AGG.median", "#input.eId", emitWhenEventLeft = false)

  val emittedMedians = runCollectOutputAggregate[Number](id, model, scenario)
  emittedMedians shouldBe List(1.0d, 1.5, 3.5)
}

test("standard deviation and average aggregates") {
val table = Table(
("aggregate", "secondValue"),
Expand Down Expand Up @@ -455,6 +474,19 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true
}

test("emit aggregate for extra window when no data come for median aggregator for return type double") {
  val id = "1"

  val model    = modelData(List(TestRecordHours(id, 0, 1, "a")))
  val scenario = tumbling("#AGG.median", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

  val emittedMedians = runCollectOutputAggregate[Number](id, model, scenario)

  // Two emissions are expected: the median of the single record, then NaN for the extra window.
  emittedMedians.length shouldEqual 2
  emittedMedians.head shouldEqual 1.0
  emittedMedians(1).asInstanceOf[Double].isNaN shouldBe true
}

test(
"emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double"
) {
Expand Down Expand Up @@ -492,6 +524,18 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null)
}

test("emit aggregate for extra window when no data come for median aggregator for return type BigDecimal") {
  val id = "1"

  val model = modelData(List(TestRecordHours(id, 0, 1, "a")))
  val scenario = tumbling(
    "#AGG.median",
    """T(java.math.BigDecimal).ONE""",
    emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow
  )

  val emittedMedians = runCollectOutputAggregate[Number](id, model, scenario)

  // Two emissions are expected: BigDecimal 1 for the record's window, then null for the extra window.
  emittedMedians shouldEqual List(new java.math.BigDecimal("1"), null)
}

test(
"emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal"
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class AggregateHelper implements Serializable {
new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"),
new FixedExpressionValue("#AGG.varPop", "VarPop"),
new FixedExpressionValue("#AGG.varSamp", "VarSamp"),
new FixedExpressionValue("#AGG.median", "Median"),
new FixedExpressionValue("#AGG.min", "Min"),
new FixedExpressionValue("#AGG.max", "Max"),
new FixedExpressionValue("#AGG.sum", "Sum"),
Expand All @@ -54,6 +55,7 @@ public class AggregateHelper implements Serializable {
private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$;
private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$;
private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$;
private static final Aggregator MEDIAN = aggregates.MedianAggregator$.MODULE$;
private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10);

// Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used.
Expand All @@ -80,6 +82,8 @@ public class AggregateHelper implements Serializable {
public Aggregator varPop = VAR_POP;
public Aggregator varSamp = VAR_SAMP;

public Aggregator median = MEDIAN;

public Aggregator approxCardinality = APPROX_CARDINALITY;

public Aggregator map(@ParamName("parts") Map<String, Aggregator> parts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ import cats.data.{NonEmptyList, Validated}
import cats.instances.list._
import org.apache.flink.api.common.typeinfo.TypeInfo
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{
ForLargeFloatingNumbersOperation,
}
import pl.touk.nussknacker.engine.api.typed.typing._
import pl.touk.nussknacker.engine.api.typed.{NumberTypeUtils, typing}
import pl.touk.nussknacker.engine.flink.api.typeinfo.caseclass.CaseClassTypeInfoFactory
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median.MedianHelper
import pl.touk.nussknacker.engine.util.Implicits._
import pl.touk.nussknacker.engine.util.MathUtils
import pl.touk.nussknacker.engine.util.validated.ValidatedSyntax._

import java.util
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

/*
Expand Down Expand Up @@ -69,6 +73,26 @@ object aggregates {

}

/**
 * Aggregator that collects every non-null element into a buffer and derives the median
 * from the collected values by delegating to `MedianHelper.calculateMedian`.
 */
object MedianAggregator extends Aggregator with LargeFloatingNumberAggregate {

  override type Aggregate = ListBuffer[Number]

  override type Element = Number

  override def zero: Aggregate = ListBuffer.empty[Number]

  // Null elements are skipped; any other element is appended to the passed buffer, which is then returned.
  override def addElement(el: Element, agg: Aggregate): Aggregate =
    if (el != null) agg += el else agg

  // Concatenation yields a new buffer holding the elements of agg1 followed by those of agg2.
  override def mergeAggregates(agg1: Aggregate, agg2: Aggregate): Aggregate = agg1 ++ agg2

  // An empty aggregate has no median; in that case null is returned.
  override def result(finalAggregate: Aggregate): AnyRef = {
    val median = MedianHelper.calculateMedian(finalAggregate.toList)
    median.orNull
  }

  // The stored type is a ListBuffer parameterised with the input type.
  override def computeStoredType(input: TypingResult): Validated[String, TypingResult] = {
    val bufferType = Typed.genericTypeClass[ListBuffer[_]](List(input))
    Valid(bufferType)
  }

}

object ListAggregator extends Aggregator {

override type Aggregate = List[AnyRef]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median

import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{
ForLargeFloatingNumbersOperation,
}
import pl.touk.nussknacker.engine.util.MathUtils

import scala.annotation.tailrec
import scala.util.Random

/**
 * Median calculation over a list of numbers, based on random-pivot selection
 * (https://en.wikipedia.org/wiki/Quickselect).
 */
object MedianHelper {
  // Pivot source; created with a constant seed (42).
  private val rand = new Random(42)

  /**
   * Computes the median of the given numbers.
   *
   * @param numbers values to take the median of
   * @return `None` for an empty list; for an odd count the middle element converted with
   *         `MathUtils.convertToPromotedType` under `ForLargeFloatingNumbersOperation`; for an even
   *         count the `MathUtils.largeFloatingSum` of the two middle elements divided by 2 via
   *         `MathUtils.divideWithDefaultBigDecimalScale`
   */
  def calculateMedian(numbers: List[Number]): Option[Number] = {
    val count = numbers.size
    if (count == 0) {
      None
    } else {
      // For an odd count this is the middle index; for an even count it is the upper of the two middle indices.
      val upperMiddle = count / 2
      val median: Number =
        if (count % 2 == 1) {
          MathUtils.convertToPromotedType(quickSelect(numbers, upperMiddle))(ForLargeFloatingNumbersOperation)
        } else {
          // Both middle elements could be found in one recursion, but two selections keep the code simple.
          val lowerMiddleValue = quickSelect(numbers, upperMiddle - 1)
          val upperMiddleValue = quickSelect(numbers, upperMiddle)
          MathUtils.divideWithDefaultBigDecimalScale(MathUtils.largeFloatingSum(lowerMiddleValue, upperMiddleValue), 2)
        }
      Some(median)
    }
  }

  /**
   * Returns the element that would sit at position `indexToTake` if `numbers` were ordered
   * according to `MathUtils.compare`.
   *
   * Requires a non-empty list and `0 <= indexToTake < numbers.size`.
   */
  @tailrec
  private def quickSelect(numbers: List[Number], indexToTake: Int): Number = {
    val total = numbers.size
    require(numbers.nonEmpty)
    require(indexToTake >= 0)
    require(indexToTake < total)

    val pivot = numbers(rand.nextInt(total))
    // Three-way split around the pivot, keyed by the sign (-1, 0, 1) of the comparison with it.
    val bySign = numbers.groupBy(candidate => Integer.signum(MathUtils.compare(candidate, pivot)))
    val below  = bySign.getOrElse(-1, Nil)
    val same   = bySign.getOrElse(0, Nil)
    val above  = bySign.getOrElse(1, Nil)

    val belowCount     = below.size
    val upToPivotCount = belowCount + same.size

    indexToTake match {
      case idx if idx < belowCount     => quickSelect(below, idx)
      case idx if idx < upToPivotCount => same(idx - belowCount)
      case idx                         => quickSelect(above, idx - upToPivotCount)
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ object sampleTransformers {
new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"),
new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"),
new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"),
new LabeledExpression(label = "Median", expression = "#AGG.median"),
new LabeledExpression(label = "List", expression = "#AGG.list"),
new LabeledExpression(label = "Set", expression = "#AGG.set"),
new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality")
Expand Down Expand Up @@ -102,6 +103,7 @@ object sampleTransformers {
new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"),
new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"),
new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"),
new LabeledExpression(label = "Median", expression = "#AGG.median"),
new LabeledExpression(label = "List", expression = "#AGG.list"),
new LabeledExpression(label = "Set", expression = "#AGG.set"),
new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality")
Expand Down Expand Up @@ -163,6 +165,7 @@ object sampleTransformers {
new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"),
new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"),
new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"),
new LabeledExpression(label = "Median", expression = "#AGG.median"),
new LabeledExpression(label = "List", expression = "#AGG.list"),
new LabeledExpression(label = "Set", expression = "#AGG.set"),
new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14930,6 +14930,15 @@
}
}
],
"median": [
{
"name": "median",
"signature": {
"noVarArgs": [],
"result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"}
}
}
],
"first": [
{
"name": "first",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ trait MathUtils {
case n1: java.math.BigDecimal => n1.negate()
}

private def compare(n1: Number, n2: Number): Int = {
@Hidden
def compare(n1: Number, n2: Number): Int = {
withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandler[Int] {
override def onBytes(n1: java.lang.Byte, n2: java.lang.Byte): Int = n1.compareTo(n2)
override def onShorts(n1: java.lang.Short, n2: java.lang.Short): Int = n1.compareTo(n2)
Expand Down Expand Up @@ -286,7 +287,8 @@ trait MathUtils {
}
}

private def convertToPromotedType(
@Hidden
def convertToPromotedType(
n: Number
)(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = {
// In some cases type can be promoted to other class e.g. Byte is promoted to Int for sum
Expand Down

0 comments on commit ace3d3c

Please sign in to comment.