Skip to content

Commit

Permalink
Recorder_Enumerator: Distinct, Union, Intersect, and Except support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Rog-Wilhelm committed Oct 3, 2023
1 parent b1a5160 commit 722ebc6
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 2 deletions.
47 changes: 45 additions & 2 deletions extra/recorder_enumerator/src/Config.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
namespace Dec.RecorderEnumerator
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
Expand All @@ -12,6 +13,15 @@ public static class Config
private static readonly Regex RecordableClosureRegex = new Regex(@"^<>c__DisplayClass([0-9]+)_([0-9]+)$", RegexOptions.Compiled);


private static HashSet<(string, string)> InternalRegexSupportOverride = new HashSet<(string, string)>
{
("System.Linq", "DistinctByIterator"),
("System.Linq", "ExceptIterator"),
("System.Linq", "ExceptByIterator"),
("System.Linq", "IntersectIterator"),
("System.Linq", "IntersectByIterator"),
};

public static Converter ConverterFactory(Type type)
{
if (type == SystemLinqEnumerable_RangeIterator_Converter.RelevantType)
Expand Down Expand Up @@ -87,13 +97,38 @@ public static Converter ConverterFactory(Type type)
return (Converter)Activator.CreateInstance(typeof(SystemLinqEnumerable_SelectMany_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[1]));
}

// Set-related

if (genericTypeDefinition == SystemLinqEnumerable_DistinctIterator_Converter.RelevantType)
{
return (Converter)Activator.CreateInstance(typeof(SystemLinqEnumerable_DistinctIterator_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0]));
}

if (genericTypeDefinition == SystemLinqEnumerable_UnionIterator2_Converter.RelevantType)
{
return (Converter)Activator.CreateInstance(typeof(SystemLinqEnumerable_UnionIterator2_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0]));
}

if (genericTypeDefinition == SystemLinqEnumerable_UnionIteratorN_Converter.RelevantType)
{
return (Converter)Activator.CreateInstance(typeof(SystemLinqEnumerable_UnionIteratorN_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0]));
}

// List enumerator

if (genericTypeDefinition == SystemCollections_List_Enumerator_Converter.RelevantType)
{
return (Converter)Activator.CreateInstance(typeof(SystemCollections_List_Enumerator_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0]));
}

// SystemLinq

if (genericTypeDefinition == SystemLinq_SingleLinkedNode_Converter.RelevantType)
{
return (Converter)Activator.CreateInstance(typeof(SystemLinq_SingleLinkedNode_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0]));
}


// Delegate

if (System_Delegate_Converter.IsGenericDelegate(genericTypeDefinition))
Expand Down Expand Up @@ -138,12 +173,20 @@ public static Converter ConverterFactory(Type type)

int tags = functions.Count(f => f.GetCustomAttribute<RecordableEnumerableAttribute>() != null);

if (tags == 0)
if (tags == functions.Length)
{

}
else if (InternalRegexSupportOverride.Contains((type.Namespace, functionName)))
{

}
else if (tags == 0)
{
Dbg.Err($"Attempting to serialize an enumerable {type} without a Dec.RecorderEnumerator.RecordableEnumerable applied to its function");
return null;
}
else if (tags != functions.Length)
else // tags != functionLength
{
Dbg.Err($"Attempting to serialize an enumerable {type} without a Dec.RecorderEnumerator.RecordableEnumerable applied to all functions with that name; sorry, it's gotta be all of them right now");
return null;
Expand Down
35 changes: 35 additions & 0 deletions extra/recorder_enumerator/src/SystemLinq.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
namespace Dec.RecorderEnumerator
{
using System;
using System.IO;
using System.Linq;
using System.Reflection;

public static class SystemLinq_SingleLinkedNode_Converter
{
internal static Type RelevantType = typeof(System.Linq.Enumerable).Assembly.GetType("System.Linq.SingleLinkedNode`1");
}

public class SystemLinq_SingleLinkedNode_Converter<Node, T> : ConverterFactoryDynamic
{
internal FieldInfo field_Item = typeof(Node).GetPrivateFieldInHierarchy("<Item>k__BackingField");
internal FieldInfo field_Linked = typeof(Node).GetPrivateFieldInHierarchy("<Linked>k__BackingField");

public override void Write(object input, Recorder recorder)
{
recorder.Shared().RecordPrivate(input, field_Item, "item");
recorder.Shared().RecordPrivate(input, field_Linked, "linked");
}

public override object Create(Recorder recorder)
{
return Activator.CreateInstance(typeof(Node), new object[] { default(T) });
}

public override void Read(ref object input, Recorder recorder)
{
// it's the same code, we only need this for the funky Create
Write(input, recorder);
}
}
}
112 changes: 112 additions & 0 deletions extra/recorder_enumerator/src/SystemLinqEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,116 @@ public override void Read(ref object input, Recorder recorder)
recorder.RecordPrivate(input, Field_State, "state");
}
}

public static class SystemLinqEnumerable_DistinctIterator_Converter
{
internal static Type RelevantType = typeof(System.Linq.Enumerable).GetNestedType("DistinctIterator`1", System.Reflection.BindingFlags.NonPublic);
}

public class SystemLinqEnumerable_DistinctIterator_Converter<Iterator, T> : ConverterFactoryDynamic
{
internal FieldInfo field_Source = typeof(Iterator).GetField("_source", BindingFlags.NonPublic | BindingFlags.Instance);
internal FieldInfo field_Comparer = typeof(Iterator).GetField("_comparer", BindingFlags.NonPublic | BindingFlags.Instance);
internal FieldInfo field_Set = typeof(Iterator).GetField("_set", BindingFlags.NonPublic | BindingFlags.Instance);
internal FieldInfo field_Enumerator = typeof(Iterator).GetField("_enumerator", BindingFlags.NonPublic | BindingFlags.Instance);
internal FieldInfo field_State = typeof(Iterator).GetField("_state", BindingFlags.NonPublic | BindingFlags.Instance);
internal FieldInfo field_Current = typeof(Iterator).GetField("_current", BindingFlags.NonPublic | BindingFlags.Instance);

public override void Write(object input, Recorder recorder)
{
recorder.Shared().RecordPrivate(input, field_Source, "source");
recorder.Shared().RecordPrivate(input, field_Comparer, "comparer");
recorder.Shared().RecordPrivate(input, field_Set, "selector");
recorder.Shared().RecordPrivate(input, field_Enumerator, "enumerator");
recorder.RecordPrivate(input, field_State, "state");
recorder.SharedIfPossible<T>().RecordPrivate(input, field_Current, "current");
}

public override object Create(Recorder recorder)
{
return Activator.CreateInstance(typeof(Iterator), new object[] { null, null });
}

public override void Read(ref object input, Recorder recorder)
{
// it's the same code, we only need this for the funky Create
Write(input, recorder);
}
}

public static class SystemLinqEnumerable_UnionIterator2_Converter
{
internal static Type RelevantType = typeof(System.Linq.Enumerable).GetNestedType("UnionIterator2`1", System.Reflection.BindingFlags.NonPublic);
}

public class SystemLinqEnumerable_UnionIterator2_Converter<Iterator, T> : ConverterFactoryDynamic
{
internal FieldInfo field_First = typeof(Iterator).GetPrivateFieldInHierarchy("_first");
internal FieldInfo field_Second = typeof(Iterator).GetPrivateFieldInHierarchy("_second");
internal FieldInfo field_Comparer = typeof(Iterator).GetPrivateFieldInHierarchy("_comparer");
internal FieldInfo field_Enumerator = typeof(Iterator).GetPrivateFieldInHierarchy("_enumerator");
internal FieldInfo field_Set = typeof(Iterator).GetPrivateFieldInHierarchy("_set");
internal FieldInfo field_State = typeof(Iterator).GetPrivateFieldInHierarchy("_state");
internal FieldInfo field_Current = typeof(Iterator).GetPrivateFieldInHierarchy("_current");

public override void Write(object input, Recorder recorder)
{
recorder.Shared().RecordPrivate(input, field_First, "first");
recorder.Shared().RecordPrivate(input, field_Second, "second");
recorder.Shared().RecordPrivate(input, field_Comparer, "comparer");
recorder.Shared().RecordPrivate(input, field_Enumerator, "enumerator");
recorder.Shared().RecordPrivate(input, field_Set, "selector");
recorder.RecordPrivate(input, field_State, "state");
recorder.SharedIfPossible<T>().RecordPrivate(input, field_Current, "current");
}

public override object Create(Recorder recorder)
{
return Activator.CreateInstance(typeof(Iterator), new object[] { null, null, null });
}

public override void Read(ref object input, Recorder recorder)
{
// it's the same code, we only need this for the funky Create
Write(input, recorder);
}
}

public static class SystemLinqEnumerable_UnionIteratorN_Converter
{
internal static Type RelevantType = typeof(System.Linq.Enumerable).GetNestedType("UnionIteratorN`1", System.Reflection.BindingFlags.NonPublic);
}

public class SystemLinqEnumerable_UnionIteratorN_Converter<Iterator, T> : ConverterFactoryDynamic
{
internal FieldInfo field_Sources = typeof(Iterator).GetPrivateFieldInHierarchy("_sources");
internal FieldInfo field_HeadIndex = typeof(Iterator).GetPrivateFieldInHierarchy("_headIndex");
internal FieldInfo field_Comparer = typeof(Iterator).GetPrivateFieldInHierarchy("_comparer");
internal FieldInfo field_Enumerator = typeof(Iterator).GetPrivateFieldInHierarchy("_enumerator");
internal FieldInfo field_Set = typeof(Iterator).GetPrivateFieldInHierarchy("_set");
internal FieldInfo field_State = typeof(Iterator).GetPrivateFieldInHierarchy("_state");
internal FieldInfo field_Current = typeof(Iterator).GetPrivateFieldInHierarchy("_current");

public override void Write(object input, Recorder recorder)
{
recorder.Shared().RecordPrivate(input, field_Sources, "sources");
recorder.RecordPrivate(input, field_HeadIndex, "headIndex");
recorder.Shared().RecordPrivate(input, field_Comparer, "comparer");
recorder.Shared().RecordPrivate(input, field_Enumerator, "enumerator");
recorder.Shared().RecordPrivate(input, field_Set, "selector");
recorder.RecordPrivate(input, field_State, "state");
recorder.SharedIfPossible<T>().RecordPrivate(input, field_Current, "current");
}

public override object Create(Recorder recorder)
{
return Activator.CreateInstance(typeof(Iterator), new object[] { null, 0, null });
}

public override void Read(ref object input, Recorder recorder)
{
// it's the same code, we only need this for the funky Create
Write(input, recorder);
}
}
}
15 changes: 15 additions & 0 deletions extra/recorder_enumerator/src/Util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,20 @@ internal static Recorder.Parameters SharedIfPossible<T>(this Recorder recorder)
return recorder.WithFactory(null);
}
}

internal static FieldInfo GetPrivateFieldInHierarchy(this Type type, string name)
{
while (type != null)
{
var field = type.GetField(name, BindingFlags.Instance | BindingFlags.NonPublic);
if (field != null)
{
return field;
}
type = type.BaseType;
}

return null;
}
}
}
115 changes: 115 additions & 0 deletions extra/recorder_enumerator/test/Linq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -340,5 +340,120 @@ public void GroupJoinEnumeratorTest([ValuesExcept(RecorderMode.Validation)] Reco

Assert.IsTrue(Util.AreEquivalentEnumerators(groupJoinEnumerator, result));
}

[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void DistinctEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var source = Enumerable.Range(0, 40).Select(i => i % 5).Distinct().GetEnumerator();
source.MoveNext();
source.MoveNext();
source.MoveNext();
var result = DoRecorderRoundTrip(source, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(source, result));
}

#if NET6_0_OR_GREATER
[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void DistinctByEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var source = Enumerable.Range(0, 40).Select(i => i % 5).DistinctBy(x => x % 4).GetEnumerator();
source.MoveNext();
source.MoveNext();
source.MoveNext();
var result = DoRecorderRoundTrip(source, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(source, result));
}
#endif

[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void UnionEnumerator2Test([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var first = Enumerable.Range(0, 20).Select(x => x * 2);
var second = Enumerable.Range(0, 20).Select(x => x * 3);
var unionEnumerator = first.Union(second).GetEnumerator();
unionEnumerator.MoveNext();
unionEnumerator.MoveNext();
unionEnumerator.MoveNext();
var result = DoRecorderRoundTrip(unionEnumerator, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(unionEnumerator, result));
}

[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void UnionEnumerator3Test([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var first = Enumerable.Range(0, 20).Select(x => x * 2);
var second = Enumerable.Range(0, 20).Select(x => x * 3);
var third = Enumerable.Range(0, 20).Select(x => x * 5);
var unionEnumerator = first.Union(second).Union(third).GetEnumerator();
unionEnumerator.MoveNext();
unionEnumerator.MoveNext();
unionEnumerator.MoveNext();
var result = DoRecorderRoundTrip(unionEnumerator, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(unionEnumerator, result));
}

[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void IntersectEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var first = Enumerable.Range(0, 20);
var second = Enumerable.Range(15, 20);
var intersectEnumerator = first.Intersect(second).GetEnumerator();
intersectEnumerator.MoveNext();
intersectEnumerator.MoveNext();
intersectEnumerator.MoveNext();
var result = DoRecorderRoundTrip(intersectEnumerator, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(intersectEnumerator, result));
}

#if NET6_0_OR_GREATER
[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void IntersectByEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var first = Enumerable.Range(0, 20);
var second = Enumerable.Range(15, 20);
var intersectEnumerator = first.IntersectBy(second, x => x % 17).GetEnumerator();
intersectEnumerator.MoveNext();
intersectEnumerator.MoveNext();
intersectEnumerator.MoveNext();
var result = DoRecorderRoundTrip(intersectEnumerator, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(intersectEnumerator, result));
}
#endif

[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void ExceptEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var first = Enumerable.Range(0, 20);
var second = Enumerable.Range(15, 20);
var exceptEnumerator = first.Except(second).GetEnumerator();
exceptEnumerator.MoveNext();
exceptEnumerator.MoveNext();
exceptEnumerator.MoveNext();
var result = DoRecorderRoundTrip(exceptEnumerator, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(exceptEnumerator, result));
}

#if NET6_0_OR_GREATER
[Test]
[Dec.RecorderEnumerator.RecordableClosures]
public void ExceptByEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode)
{
var first = Enumerable.Range(0, 20);
var second = Enumerable.Range(15, 20);
var exceptEnumerator = first.ExceptBy(second, x => x % 3).GetEnumerator();
exceptEnumerator.MoveNext();
exceptEnumerator.MoveNext();
exceptEnumerator.MoveNext();
var result = DoRecorderRoundTrip(exceptEnumerator, recorderMode);
Assert.IsTrue(Util.AreEquivalentEnumerators(exceptEnumerator, result));
}
#endif
}
}

0 comments on commit 722ebc6

Please sign in to comment.