Skip to content

Commit

Permalink
TopK optimization to not keep all values, instead have max of k value…
Browse files Browse the repository at this point in the history
…s per window
  • Loading branch information
arunkm committed Dec 6, 2019
1 parent ada8d3a commit 5f60996
Show file tree
Hide file tree
Showing 9 changed files with 549 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,28 @@ Expression<Func<Func<SortedDictionary<T, long>>, MinMaxState<T>>> template
public Expression<Func<MinMaxState<T>>> InitialState() => initialState;

private static readonly Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> acc
= (set, timestamp, input) => new MinMaxState<T> { savedValues = set.savedValues.Add(input) };
= (set, timestamp, input) => Apply(set, s => s.savedValues.Add(input));
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Accumulate() => acc;

private static readonly Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> dec
= (set, timestamp, input) => new MinMaxState<T> { savedValues = set.savedValues.Remove(input) };
= (set, timestamp, input) => Apply(set, s => s.savedValues.Remove(input));
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Deaccumulate() => dec;

private static readonly Expression<Func<MinMaxState<T>, MinMaxState<T>, MinMaxState<T>>> diff
= (leftSet, rightSet) => new MinMaxState<T> { savedValues = leftSet.savedValues.RemoveAll(rightSet.savedValues) };
= (leftSet, rightSet) => Apply(leftSet, s => s.savedValues.RemoveAll(rightSet.savedValues));
public Expression<Func<MinMaxState<T>, MinMaxState<T>, MinMaxState<T>>> Difference() => diff;

private static readonly Expression<Func<MinMaxState<T>, MinMaxState<T>, MinMaxState<T>>> sum
= (leftSet, rightSet) => new MinMaxState<T> { savedValues = leftSet.savedValues.AddAll(rightSet.savedValues) };
= (leftSet, rightSet) => Apply(leftSet, s => s.savedValues.AddAll(rightSet.savedValues));
public Expression<Func<MinMaxState<T>, MinMaxState<T>, MinMaxState<T>>> Sum() => sum;

public abstract Expression<Func<MinMaxState<T>, T>> ComputeResult();

private static MinMaxState<T> Apply(MinMaxState<T> state, Action<MinMaxState<T>> op)
{
op(state);
return state;
}
}

internal sealed class MaxAggregate<T> : MinMaxAggregateBase<T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,27 @@ Expression<Func<Func<SortedDictionary<T, long>>, SortedMultiSet<T>>> template
public Expression<Func<SortedMultiSet<T>>> InitialState() => initialState;

private static readonly Expression<Func<SortedMultiSet<T>, long, T, SortedMultiSet<T>>> acc
= (set, timestamp, input) => set.Add(input);
= (set, timestamp, input) => Apply(set, s => s.Add(input));
public Expression<Func<SortedMultiSet<T>, long, T, SortedMultiSet<T>>> Accumulate() => acc;

private static readonly Expression<Func<SortedMultiSet<T>, long, T, SortedMultiSet<T>>> dec
= (set, timestamp, input) => set.Remove(input);
= (set, timestamp, input) => Apply(set, s => s.Remove(input));
public Expression<Func<SortedMultiSet<T>, long, T, SortedMultiSet<T>>> Deaccumulate() => dec;

private static readonly Expression<Func<SortedMultiSet<T>, SortedMultiSet<T>, SortedMultiSet<T>>> diff
= (leftSet, rightSet) => leftSet.RemoveAll(rightSet);
= (leftSet, rightSet) => Apply(leftSet, s => s.RemoveAll(rightSet));
public Expression<Func<SortedMultiSet<T>, SortedMultiSet<T>, SortedMultiSet<T>>> Difference() => diff;

private static readonly Expression<Func<SortedMultiSet<T>, SortedMultiSet<T>, SortedMultiSet<T>>> sum
= (leftSet, rightSet) => leftSet.AddAll(rightSet);
= (leftSet, rightSet) => Apply(leftSet, s => s.AddAll(rightSet));
public Expression<Func<SortedMultiSet<T>, SortedMultiSet<T>, SortedMultiSet<T>>> Sum() => sum;

public abstract Expression<Func<SortedMultiSet<T>, R>> ComputeResult();

private static SortedMultiSet<T> Apply(SortedMultiSet<T> state, Action<SortedMultiSet<T>> op)
{
op(state);
return state;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,37 @@
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq.Expressions;
using Microsoft.StreamProcessing.Internal;

namespace Microsoft.StreamProcessing.Aggregates
{
internal sealed class TopKAggregate<T> : SortedMultisetAggregateBase<T, List<RankedEvent<T>>>
internal sealed class TopKAggregate<T> : ISummableAggregate<T, ITopKState<T>, List<RankedEvent<T>>>
{
private readonly Comparison<T> compiledRankComparer;
private readonly int k;

public TopKAggregate(int k, QueryContainer container) : this(k, ComparerExpression<T>.Default, container) { }
public TopKAggregate(int k, IComparerExpression<T> rankComparer, QueryContainer container, bool isHoppingWindow)
: this(k, rankComparer, ComparerExpression<T>.Default, container, isHoppingWindow) { }

public TopKAggregate(int k, IComparerExpression<T> rankComparer, QueryContainer container)
: this(k, rankComparer, ComparerExpression<T>.Default, container) { }

public TopKAggregate(int k, IComparerExpression<T> rankComparer, IComparerExpression<T> overallComparer, QueryContainer container)
: base(ThenOrderBy(Reverse(rankComparer), overallComparer), container)
public TopKAggregate(int k, IComparerExpression<T> rankComparer, IComparerExpression<T> overallComparer,
QueryContainer container, bool isHoppingWindow)
{
Contract.Requires(rankComparer != null);
Contract.Requires(overallComparer != null);
Contract.Requires(k > 0);
this.compiledRankComparer = Reverse(rankComparer).GetCompareExpr().Compile();
this.k = k;

Expression<Func<Func<SortedDictionary<T, long>>, ITopKState<T>>> template;
if (isHoppingWindow)
template = (g) => new HoppingTopKState<T>(k, compiledRankComparer, g);
else
template = (g) => new SimpleTopKState<T>(g);

var combinedComparer = ThenOrderBy(Reverse(rankComparer), overallComparer);
var generator = combinedComparer.CreateSortedDictionaryGenerator<T, long>(container);
var replaced = template.ReplaceParametersInBody(generator);
initialState = Expression.Lambda<Func<ITopKState<T>>>(replaced);
}

private static IComparerExpression<T> Reverse(IComparerExpression<T> comparer)
Expand All @@ -53,10 +63,11 @@ private static IComparerExpression<T> ThenOrderBy(IComparerExpression<T> compare
return new ComparerExpression<T>(newExpression);
}

public override Expression<Func<SortedMultiSet<T>, List<RankedEvent<T>>>> ComputeResult() => set => GetTopK(set);
public Expression<Func<ITopKState<T>, List<RankedEvent<T>>>> ComputeResult() => set => GetTopK(set);

private List<RankedEvent<T>> GetTopK(SortedMultiSet<T> set)
private List<RankedEvent<T>> GetTopK(ITopKState<T> state)
{
var set = state.GetSortedValues();
int count = (int)Math.Min(this.k, set.TotalCount);
var result = new List<RankedEvent<T>>(count);
int nextRank = 1;
Expand All @@ -82,5 +93,30 @@ private List<RankedEvent<T>> GetTopK(SortedMultiSet<T> set)

return result;
}

internal static ITopKState<T> Apply(ITopKState<T> state, Action<ITopKState<T>> op)
{
op(state);
return state;
}

private readonly Expression<Func<ITopKState<T>>> initialState;
public Expression<Func<ITopKState<T>>> InitialState() => initialState;

private static readonly Expression<Func<ITopKState<T>, long, T, ITopKState<T>>> acc
= (state, timestamp, input) => Apply(state, s => s.Add(input, timestamp));
public Expression<Func<ITopKState<T>, long, T, ITopKState<T>>> Accumulate() => acc;

private static readonly Expression<Func<ITopKState<T>, long, T, ITopKState<T>>> dec
= (state, timestamp, input) => Apply(state, s => s.Remove(input, timestamp));
public Expression<Func<ITopKState<T>, long, T, ITopKState<T>>> Deaccumulate() => dec;

private static readonly Expression<Func<ITopKState<T>, ITopKState<T>, ITopKState<T>>> diff
= (leftState, rightState) => Apply(leftState, s => s.RemoveAll(rightState));
public Expression<Func<ITopKState<T>, ITopKState<T>, ITopKState<T>>> Difference() => diff;

private static readonly Expression<Func<ITopKState<T>, ITopKState<T>, ITopKState<T>>> sum
= (leftState, rightState) => Apply(leftState, s => s.AddAll(rightState));
public Expression<Func<ITopKState<T>, ITopKState<T>, ITopKState<T>>> Sum() => sum;
}
}
236 changes: 236 additions & 0 deletions Sources/Core/Microsoft.StreamProcessing/Aggregates/TopKState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.StreamProcessing.Aggregates
{
/// <summary>
/// State used by TopK Aggregate
/// </summary>
/// <typeparam name="T"></typeparam>
public interface ITopKState<T>
{
/// <summary>
/// Add a single entry
/// </summary>
/// <param name="input"></param>
/// <param name="timestamp"></param>
void Add(T input, long timestamp);

/// <summary>
/// Adds all entries from other
/// </summary>
/// <param name="other"></param>
void AddAll(ITopKState<T> other);

/// <summary>
/// Removes the specified entry
/// </summary>
/// <param name="input"></param>
/// <param name="timestamp"></param>
void Remove(T input, long timestamp);

/// <summary>
/// Removes entries from other
/// </summary>
/// <param name="other"></param>
void RemoveAll(ITopKState<T> other);

/// <summary>
/// Gets the values as sorted set
/// </summary>
/// <returns></returns>
SortedMultiSet<T> GetSortedValues();

/// <summary>
/// Returns total number of values in the set
/// </summary>
long Count { get; }
}

internal class SimpleTopKState<T> : ITopKState<T>
{
private SortedMultiSet<T> values;

public SimpleTopKState(Func<SortedDictionary<T, long>> generator)
{
values = new SortedMultiSet<T>(generator);
}

public long Count => values.TotalCount;

public virtual void Add(T input, long timestamp)
{
values.Add(input);
}

public void AddAll(ITopKState<T> other)
{
values.AddAll(other.GetSortedValues());
}

public SortedMultiSet<T> GetSortedValues()
{
return values;
}

public void Remove(T input, long timestamp)
{
values.Remove(input);
}

public void RemoveAll(ITopKState<T> other)
{
values.RemoveAll(other.GetSortedValues());
}
}

internal class HoppingTopKState<T> : ITopKState<T>
{
public long currentTimestamp;

public SortedMultiSet<T> previousValues;
public SortedMultiSet<T> currentValues;

public int k;

public Comparison<T> rankComparer;

public HoppingTopKState(int k, Comparison<T> rankComparer, Func<SortedDictionary<T, long>> generator)
{
this.k = k;
this.rankComparer = rankComparer;
this.currentValues = new SortedMultiSet<T>(generator);
this.previousValues = new SortedMultiSet<T>(generator);
}

public void Add(T input, long timestamp)
{
if (timestamp > currentTimestamp)
{
MergeCurrentToPrevious();
currentTimestamp = timestamp;
}
else if (timestamp < currentTimestamp)
{
throw new ArgumentException("Invalid timestamp");
}

currentValues.Add(input);

var toRemove = currentValues.TotalCount - k;
if (toRemove > 0)
{
var min = currentValues.GetMinItem();
if (toRemove == min.Count)
{
currentValues.Remove(min.Item, min.Count);
}
else if (toRemove > min.Count)
{
throw new InvalidOperationException("CurrentValues has more items than required");
}
}
}

public void Remove(T input, long timestamp)
{
if (timestamp < currentTimestamp)
{
previousValues.Remove(input);
}
else if (timestamp == currentTimestamp)
{
currentValues.Remove(input);
}
else
{
throw new ArgumentException("Invalid timestamp");
}
}

public void RemoveAll(ITopKState<T> other)
{
if (other.Count != 0)
{
if (other is HoppingTopKState<T> otherTopK)
{
if (otherTopK.currentTimestamp >= currentTimestamp)
{
throw new ArgumentException("Cannot remove entries with current or future timestamp");
}
previousValues.RemoveAll(otherTopK.currentValues);
previousValues.RemoveAll(otherTopK.previousValues);
}
else
{
throw new InvalidOperationException("Cannot remove non-HoppingTopKState from HoppingTopKState");
}
}
}

public SortedMultiSet<T> GetSortedValues()
{
if (previousValues.IsEmpty)
return currentValues;
else
{
MergeCurrentToPrevious();
return previousValues;
}
}

private void MergeCurrentToPrevious()
{
if (!currentValues.IsEmpty)
{
// Swap so we merge small onto larger
if (previousValues.UniqueCount < currentValues.UniqueCount)
{
var temp = previousValues;
previousValues = currentValues;
currentValues = temp;
}

if (!currentValues.IsEmpty)
{
previousValues.AddAll(currentValues);
currentValues.Clear();
}
}
}

public void AddAll(ITopKState<T> other)
{
if (other is HoppingTopKState<T> otherTopK)
{
if (otherTopK.currentTimestamp == currentTimestamp)
{
currentValues.AddAll(otherTopK.currentValues);
while (currentValues.TotalCount > k)
currentValues.Remove(currentValues.First());
}
else if (otherTopK.currentTimestamp < currentTimestamp)
{
previousValues.RemoveAll(otherTopK.currentValues);
previousValues.RemoveAll(otherTopK.previousValues);
}
else
{
MergeCurrentToPrevious();
currentValues.AddAll(otherTopK.currentValues);
currentTimestamp = otherTopK.currentTimestamp;
}
}
else
{
throw new InvalidOperationException("Cannot add non-HoppingTopKState from HoppingTopKState");
}
}

public long Count
{
get => this.currentValues.TotalCount + previousValues.TotalCount;
}
}
}
Loading

0 comments on commit 5f60996

Please sign in to comment.