From 34a6dc3cf00f6019fb85dff912ba9d083d722aed Mon Sep 17 00:00:00 2001 From: jvbsl Date: Sun, 10 Dec 2023 16:51:57 +0100 Subject: [PATCH] Foreach fixed, unnecessary code removed * Fixes foreach to correctly dispose enumerator * Removed unnecessary code --- .../BaseGenerator.cs | 24 ++++++++++++------- .../Helper.cs | 8 ++++++- .../MethodResolver.cs | 4 ---- .../Serializers/ListSerializer.cs | 2 +- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/NonSucking.Framework.Serialization.Advanced/BaseGenerator.cs b/NonSucking.Framework.Serialization.Advanced/BaseGenerator.cs index ff46630..c82d69b 100644 --- a/NonSucking.Framework.Serialization.Advanced/BaseGenerator.cs +++ b/NonSucking.Framework.Serialization.Advanced/BaseGenerator.cs @@ -184,7 +184,7 @@ public void EmitForEach(Type enumerableType, Action getEnumerable if (getCurrent is null) throw new ArgumentException( $"Not a valid enumerator(no valid Current get method found): '{enumeratorType.FullName}'"); - var dispose = Helper.GetMethodIncludingInterfaces(enumeratorType, "Dispose", BindingFlags.Public | BindingFlags.Instance | BindingFlags.FlattenHierarchy); + var dispose = Helper.GetMethodInInterfaces(enumeratorType, "Dispose", BindingFlags.Public | BindingFlags.Instance | BindingFlags.FlattenHierarchy); if (dispose is null) throw new ArgumentException( $"Not a valid enumerator(no valid Dispose method found): '{enumeratorType.FullName}'"); @@ -213,14 +213,20 @@ public void EmitForEach(Type enumerableType, Action getEnumerable }, (gen, exitLabel) => { - // if (!enumeratorVariable.LocalType.IsValueType) - // { - // gen.EmitLoadLocRef(enumeratorVariable); - // gen.IL.Emit(OpCodes.Brfalse, exitLabel); - // } - // - // gen.EmitLoadLocRef(enumeratorVariable); - // gen.IL.Emit(OpCodes.Callvirt, dispose); + bool isValueType = enumeratorVariable.LocalType.IsValueType; + if (!isValueType) + { + gen.EmitLoadLocRef(enumeratorVariable); + gen.Il.Emit(OpCodes.Brfalse, exitLabel); + } + + gen.EmitLoadLocRef(enumeratorVariable); + + if (isValueType) + { + gen.Il.Emit(OpCodes.Constrained, enumeratorVariable.LocalType); + } + gen.Il.Emit(OpCodes.Callvirt, dispose); }); } diff --git a/NonSucking.Framework.Serialization.Advanced/Helper.cs b/NonSucking.Framework.Serialization.Advanced/Helper.cs index 8a40fb6..126943a 100644 --- a/NonSucking.Framework.Serialization.Advanced/Helper.cs +++ b/NonSucking.Framework.Serialization.Advanced/Helper.cs @@ -258,9 +258,15 @@ internal static bool MatchIdentifierWithPropName(string identifier, string param var res = type.GetMethod(name, bindingFlags); if (res is not null) return res; + + return GetMethodInInterfaces(type, name, bindingFlags); + } + + internal static MethodInfo? GetMethodInInterfaces(Type type, string name, BindingFlags bindingFlags) + { foreach (var i in type.GetInterfaces()) { - res = i.GetMethod(name, bindingFlags); + var res = i.GetMethod(name, bindingFlags); if (res is not null) return res; } diff --git a/NonSucking.Framework.Serialization.Advanced/MethodResolver.cs b/NonSucking.Framework.Serialization.Advanced/MethodResolver.cs index 5c2912c..5e844a1 100644 --- a/NonSucking.Framework.Serialization.Advanced/MethodResolver.cs +++ b/NonSucking.Framework.Serialization.Advanced/MethodResolver.cs @@ -49,12 +49,8 @@ public static implicit operator AssemblyNameKey(AssemblyName name) static MethodResolver() { // Default resolve extension methods - Stopwatch st = new(); - st.Start(); AnalyzeExtensionMethodsRecurse(Assembly.GetCallingAssembly()); AnalyzeExtensionMethodsRecurse(Assembly.GetEntryAssembly()); - st.Stop(); - Console.WriteLine($"Took {st.ElapsedMilliseconds}ms to load"); } internal static IEnumerable GetRegisteredExtensionMethods(string name) { diff --git a/NonSucking.Framework.Serialization.Advanced/Serializers/ListSerializer.cs b/NonSucking.Framework.Serialization.Advanced/Serializers/ListSerializer.cs index d3e79fb..d083b8d 100644 --- a/NonSucking.Framework.Serialization.Advanced/Serializers/ListSerializer.cs +++ b/NonSucking.Framework.Serialization.Advanced/Serializers/ListSerializer.cs @@ -264,7 +264,7 @@ private static int GetInheritanceDepth(Type? type, int lastAmount = -1) if (type.GetInterfaces().Any(x => x == typeof(IEnumerable))) return lastAmount; type = type.BaseType; - lastAmount = ++lastAmount; + ++lastAmount; } return lastAmount;