From baffd9972023e7589f0a5f721282a51b08c1e4f6 Mon Sep 17 00:00:00 2001 From: Ben Rog-Wilhelm Date: Fri, 6 Oct 2023 10:33:13 -0500 Subject: [PATCH] Recorder_Enumerator: Support Reverse, OrderBy, and related. --- extra/recorder_enumerator/src/Config.cs | 12 ++++ extra/recorder_enumerator/src/SystemLinq.cs | 28 +++++++++ .../src/SystemLinqEnumerable.cs | 32 ++++++++++ extra/recorder_enumerator/test/Linq.cs | 60 +++++++++++++++++++ 4 files changed, 132 insertions(+) diff --git a/extra/recorder_enumerator/src/Config.cs b/extra/recorder_enumerator/src/Config.cs index 0110ff8b..12646448 100644 --- a/extra/recorder_enumerator/src/Config.cs +++ b/extra/recorder_enumerator/src/Config.cs @@ -20,6 +20,7 @@ public static class Config ("System.Linq.Enumerable", "ExceptByIterator"), ("System.Linq.Enumerable", "IntersectIterator"), ("System.Linq.Enumerable", "IntersectByIterator"), + ("System.Linq.OrderedEnumerable`1", "GetEnumerator"), }; public static Converter ConverterFactory(Type type) @@ -114,6 +115,13 @@ public static Converter ConverterFactory(Type type) return (Converter)Activator.CreateInstance(typeof(SystemLinqEnumerable_UnionIteratorN_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0])); } + // Misc + + if (genericTypeDefinition == SystemLinqEnumerable_ReverseIterator_Converter.RelevantType) + { + return (Converter)Activator.CreateInstance(typeof(SystemLinqEnumerable_ReverseIterator_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0])); + } + // List enumerator if (genericTypeDefinition == SystemCollections_List_Enumerator_Converter.RelevantType) @@ -128,6 +136,10 @@ public static Converter ConverterFactory(Type type) return (Converter)Activator.CreateInstance(typeof(SystemLinq_SingleLinkedNode_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0])); } + if (genericTypeDefinition == SystemLinq_Buffer_Converter.RelevantType) + { + return (Converter)Activator.CreateInstance(typeof(SystemLinq_Buffer_Converter<,>).MakeGenericType(type, type.GenericTypeArguments[0])); + } // Delegate diff --git a/extra/recorder_enumerator/src/SystemLinq.cs b/extra/recorder_enumerator/src/SystemLinq.cs index 913c62e6..ac3d45b4 100644 --- a/extra/recorder_enumerator/src/SystemLinq.cs +++ b/extra/recorder_enumerator/src/SystemLinq.cs @@ -32,4 +32,32 @@ public override void Read(ref object input, Recorder recorder) Write(input, recorder); } } + + public static class SystemLinq_Buffer_Converter + { + internal static Type RelevantType = typeof(System.Linq.Enumerable).Assembly.GetType("System.Linq.Buffer`1"); + } + + public class SystemLinq_Buffer_Converter : ConverterFactoryDynamic + { + internal FieldInfo field_Item = typeof(Node).GetPrivateFieldInHierarchy("_items"); + internal FieldInfo field_Count = typeof(Node).GetPrivateFieldInHierarchy("_count"); + + public override void Write(object input, Recorder recorder) + { + recorder.Shared().RecordPrivate(input, field_Item, "item"); + recorder.RecordPrivate(input, field_Count, "count"); + } + + public override object Create(Recorder recorder) + { + return Activator.CreateInstance(typeof(Node), new object[] { Enumerable.Empty() }); + } + + public override void Read(ref object input, Recorder recorder) + { + // it's the same code, we only need this for the funky Create + Write(input, recorder); + } + } } diff --git a/extra/recorder_enumerator/src/SystemLinqEnumerable.cs b/extra/recorder_enumerator/src/SystemLinqEnumerable.cs index efc32ae8..f811bd90 100644 --- a/extra/recorder_enumerator/src/SystemLinqEnumerable.cs +++ b/extra/recorder_enumerator/src/SystemLinqEnumerable.cs @@ -151,4 +151,36 @@ public override void Read(ref object input, Recorder recorder) Write(input, recorder); } } + + public static class SystemLinqEnumerable_ReverseIterator_Converter + { + internal static Type RelevantType = typeof(System.Linq.Enumerable).GetNestedType("ReverseIterator`1", System.Reflection.BindingFlags.NonPublic); + } + + public class SystemLinqEnumerable_ReverseIterator_Converter : ConverterFactoryDynamic + { + internal FieldInfo field_Source = typeof(Iterator).GetPrivateFieldInHierarchy("_source"); + internal FieldInfo field_Buffer = typeof(Iterator).GetPrivateFieldInHierarchy("_buffer"); + internal FieldInfo field_State = typeof(Iterator).GetPrivateFieldInHierarchy("_state"); + internal FieldInfo field_Current = typeof(Iterator).GetPrivateFieldInHierarchy("_current"); + + public override void Write(object input, Recorder recorder) + { + recorder.Shared().RecordPrivate(input, field_Source, "source"); + recorder.Shared().RecordPrivate(input, field_Buffer, "buffer"); + recorder.RecordPrivate(input, field_State, "state"); + recorder.SharedIfPossible().RecordPrivate(input, field_Current, "current"); + } + + public override object Create(Recorder recorder) + { + return Activator.CreateInstance(typeof(Iterator), new object[] { null }); + } + + public override void Read(ref object input, Recorder recorder) + { + // it's the same code, we only need this for the funky Create + Write(input, recorder); + } + } } diff --git a/extra/recorder_enumerator/test/Linq.cs b/extra/recorder_enumerator/test/Linq.cs index 0abf9eda..123e26a5 100644 --- a/extra/recorder_enumerator/test/Linq.cs +++ b/extra/recorder_enumerator/test/Linq.cs @@ -455,5 +455,65 @@ public void ExceptByEnumeratorTest([ValuesExcept(RecorderMode.Validation)] Recor Assert.IsTrue(Util.AreEquivalentEnumerators(exceptEnumerator, result)); } #endif + + [Test] + [Dec.RecorderEnumerator.RecordableClosures] + public void OrderByEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode) + { + var source = Enumerable.Range(0, 20).OrderBy(i => i % 3).GetEnumerator(); + source.MoveNext(); + source.MoveNext(); + source.MoveNext(); + var result = DoRecorderRoundTrip(source, recorderMode); + Assert.IsTrue(Util.AreEquivalentEnumerators(source, result)); + } + + [Test] + [Dec.RecorderEnumerator.RecordableClosures] + public void OrderByDescendingEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode) + { + var source = Enumerable.Range(0, 20).OrderByDescending(i => i % 3).GetEnumerator(); + source.MoveNext(); + source.MoveNext(); + source.MoveNext(); + var result = DoRecorderRoundTrip(source, recorderMode); + Assert.IsTrue(Util.AreEquivalentEnumerators(source, result)); + } + + [Test] + [Dec.RecorderEnumerator.RecordableClosures] + public void ThenByEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode) + { + var source = Enumerable.Range(0, 20).OrderBy(i => i % 3).ThenBy(i => i).GetEnumerator(); + source.MoveNext(); + source.MoveNext(); + source.MoveNext(); + var result = DoRecorderRoundTrip(source, recorderMode); + Assert.IsTrue(Util.AreEquivalentEnumerators(source, result)); + } + + [Test] + [Dec.RecorderEnumerator.RecordableClosures] + public void ThenByDescendingEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode) + { + var source = Enumerable.Range(0, 20).OrderBy(i => i % 3).ThenByDescending(i => i).GetEnumerator(); + source.MoveNext(); + source.MoveNext(); + source.MoveNext(); + var result = DoRecorderRoundTrip(source, recorderMode); + Assert.IsTrue(Util.AreEquivalentEnumerators(source, result)); + } + + [Test] + [Dec.RecorderEnumerator.RecordableClosures] + public void ReverseEnumeratorTest([ValuesExcept(RecorderMode.Validation)] RecorderMode recorderMode) + { + var source = Enumerable.Range(0, 20).Reverse().GetEnumerator(); + source.MoveNext(); + source.MoveNext(); + source.MoveNext(); + var result = DoRecorderRoundTrip(source, recorderMode); + Assert.IsTrue(Util.AreEquivalentEnumerators(source, result)); + } } }