diff --git a/MicroRuleEngine.Core.Tests/ExampleUsage.cs b/MicroRuleEngine.Core.Tests/ExampleUsage.cs index 42482a6..918eb92 100644 --- a/MicroRuleEngine.Core.Tests/ExampleUsage.cs +++ b/MicroRuleEngine.Core.Tests/ExampleUsage.cs @@ -4,6 +4,7 @@ using System.Linq.Expressions; using Microsoft.VisualStudio.TestTools.UnitTesting; using MicroRuleEngine.Core.Tests.Models; +using Newtonsoft.Json; namespace MicroRuleEngine.Tests { @@ -245,6 +246,11 @@ public void RegexIsMatch()//Had to add a Regex evaluator to make it feel 'Comple Assert.IsFalse(passes); } + public class OrderParent + { + public Order PlacedOrder { get; set; } + } + public static Order GetOrder() { Order order = new Order() @@ -278,5 +284,216 @@ public static Order GetOrder() }; return order; } + + [TestMethod] + public void EnumerableFilterAndAggregation() + { + + + + Order order = GetOrder(); + + if (order.Items.Where(x => x.ItemCode.StartsWith("M")) + .Sum(x => x.Cost) > 6) + { + + } + + + Rule rule = new Rule + { + MemberName = "Items", + EnumerableFilter = new Rule + { + MemberName = "ItemCode", + Operator = "StartsWith", + Inputs = new []{"M"} + }, + EnumerableValueExpression = new Selector + { + MemberName = "Cost", + Operator = "Sum" + }, + Operator = "GreaterThan", + TargetValue = 5 + }; + + MRE engine = new MRE(); + var compiledRule = engine.CompileRule(rule); + bool passes = compiledRule(order); + Assert.IsTrue(passes); + + order.Items[0].Cost = 4m; + passes = compiledRule(order); + Assert.IsFalse(passes); + } + + [TestMethod] + public void CountAggregation() + { + + Order order = GetOrder(); + + Rule rule = new Rule + { + MemberName = "Items", + EnumerableValueExpression = new Selector + { + Operator = "Count" + }, + Operator = "GreaterThan", + TargetValue = 3 + }; + + MRE engine = new MRE(); + var compiledRule = engine.CompileRule(rule); + bool passes = compiledRule(order); + Assert.IsFalse(passes); + + order.Items.Add(new Item()); + order.Items.Add(new Item()); + + passes = compiledRule(order); + + Assert.IsTrue(passes); + + } + + [TestMethod] + public void EnumerableAggregationOnChild() + { + + + + Order order = GetOrder(); + + var orderParent = new OrderParent() {PlacedOrder = order}; + + + Rule rule = new Rule + { + MemberName = "PlacedOrder.Items", + EnumerableValueExpression = new Selector + { + Operator = "Count" + }, + Operator = "GreaterThan", + TargetValue = 3 + }; + + MRE engine = new MRE(); + var compiledRule = engine.CompileRule(rule); + bool passes = compiledRule(orderParent); + Assert.IsFalse(passes); + + order.Items.Add(new Item()); + order.Items.Add(new Item()); + + passes = compiledRule(orderParent); + + Assert.IsTrue(passes); + } + + [TestMethod] + public void SerializeThenDeserialize() + { + + + + Order order = GetOrder(); + + var orderParent = new OrderParent() {PlacedOrder = order}; + + + Rule rule = new Rule + { + MemberName = "PlacedOrder.Items", + EnumerableValueExpression = new Selector + { + Operator = "Count" + }, + Operator = "GreaterThan", + TargetValue = 3 + }; + + var jsonString = JsonConvert.SerializeObject(rule); + + var deserializedRule = JsonConvert.DeserializeObject(jsonString); + + + + MRE engine = new MRE(); + var compiledRule = engine.CompileRule(deserializedRule); + bool passes = compiledRule(orderParent); + Assert.IsFalse(passes); + + order.Items.Add(new Item()); + order.Items.Add(new Item()); + + passes = compiledRule(orderParent); + + Assert.IsTrue(passes); + } + + [TestMethod] + public void SerializeThenDeserializeComplexRules() + { + + + + Order order = GetOrder(); + + var orderParent = new OrderParent() {PlacedOrder = order}; + + + Rule rule = new Rule + { + Operator = "AndAlso", + Rules = new List + { + + new Rule + { + MemberName = "PlacedOrder.Items", + EnumerableValueExpression = new Selector + { + Operator = "Count" + }, + Operator = "GreaterThan", + TargetValue = 3 + }, + + new Rule + { + MemberName = "PlacedOrder.Items", + EnumerableValueExpression = new Selector + { + MemberName = "Cost", + Operator = "Sum" + }, + Operator = "GreaterThan", + TargetValue = 5 + } + } + }; + + var jsonString = JsonConvert.SerializeObject(rule); + + var deserializedRule = JsonConvert.DeserializeObject(jsonString); + + + + MRE engine = new MRE(); + var compiledRule = engine.CompileRule(deserializedRule); + bool passes = compiledRule(orderParent); + Assert.IsFalse(passes); + + order.Items.Add(new Item()); + order.Items.Add(new Item()); + + passes = compiledRule(orderParent); + + Assert.IsTrue(passes); + } } } diff --git a/MicroRuleEngine/MRE.cs b/MicroRuleEngine/MRE.cs index 8117313..c042466 100644 --- a/MicroRuleEngine/MRE.cs +++ b/MicroRuleEngine/MRE.cs @@ -27,8 +27,8 @@ public class MRE private static readonly Tuple>[] _enumrMethodsByName = new Tuple>[] { - Tuple.Create("Any", new Lazy(() => GetLinqMethod("Any", 2))), - Tuple.Create("All", new Lazy(() => GetLinqMethod("All", 2))), + Tuple.Create("Any", new Lazy(() => GetLinqMethod("Any", 2))), + Tuple.Create("All", new Lazy(() => GetLinqMethod("All", 2))), }; private static readonly Lazy _miIntTryParse = new Lazy(() => typeof(Int32).GetMethod("TryParse", new Type[] { typeof(string), Type.GetType("System.Int32&") })); @@ -42,7 +42,7 @@ public class MRE private static readonly Lazy _miDecimalTryParse = new Lazy(() => typeof(Decimal).GetMethod("TryParse", new Type[] { typeof(string), Type.GetType("System.Decimal&") })); - public Func CompileRule(Rule r) + public FuncCompileRule(Rule r) { var paramUser = Expression.Parameter(typeof(T)); Expression expr = GetExpressionForRule(typeof(T), r, paramUser); @@ -140,12 +140,23 @@ protected static Expression BinaryExpression(IEnumerable expressions private static readonly Regex _regexIndexed = new Regex(@"(?'Collection'\w+)\[(?:(?'Index'\d+)|(?:['""](?'Key'\w+)[""']))\]", RegexOptions.Compiled); - private static Expression GetProperty(Expression param, string propname) + + private static Expression GetProperty(Expression param, + string propname) + { + return GetProperty(param, + propname, + out _); + } + + private static Expression GetProperty(Expression param, string propname, out PropertyInfo propertyInfo) { Expression propExpression = param; String[] childProperties = propname.Split('.'); var propertyType = param.Type; + propertyInfo = null; + foreach (var childprop in childProperties) { var isIndexed = _regexIndexed.Match(childprop); @@ -157,6 +168,7 @@ private static Expression GetProperty(Expression param, string propname) if (collectionProp == null) throw new RulesException( $"Cannot find collection property {collectionname} in class {propertyType.Name} (\"{propname}\")"); + propertyInfo = collectionProp; var collexpr = Expression.PropertyOrField(propExpression, collectionname); Expression expIndex; @@ -193,7 +205,8 @@ private static Expression GetProperty(Expression param, string propname) throw new RulesException( $"Cannot find property {childprop} in class {propertyType.Name} (\"{propname}\")"); propExpression = Expression.PropertyOrField(propExpression, childprop); - propertyType = property.PropertyType; + propertyType = property.PropertyType; + propertyInfo = property; } } @@ -267,6 +280,85 @@ private static Expression BuildExpr(Type type, Rule rule, Expression param, bool propExpression = GetProperty(param, rule.MemberName); propType = propExpression.Type; } + + if(typeof(IEnumerable).IsAssignableFrom(propType)) + { + if (rule.EnumerableFilter != null) + { + var elementType = ElementType(propType); + var lambdaParam = Expression.Parameter(elementType, + "lambdaParam"); + propExpression = Expression.Call(GetLinqMethod("Where", + 2) + .MakeGenericMethod(elementType), + propExpression, + Expression.Lambda(BuildExpr(elementType, + rule.EnumerableFilter, + lambdaParam), + lambdaParam)); + + + } + + if(rule.EnumerableValueExpression != null) + { + Type elementType = ElementType(propType); + ParameterExpression parameter; + + PropertyInfo property; + Expression selector = null; + + if(string.IsNullOrEmpty(rule.EnumerableValueExpression.MemberName)) + { + GetProperty(param, rule.MemberName, out property); + } + else + { + property = elementType.GetProperty(rule.EnumerableValueExpression.MemberName); + parameter = Expression.Parameter(elementType, "s"); + selector = Expression.Lambda(Expression.MakeMemberAccess(parameter, + property), + parameter); + } + + + + + MethodInfo generationMethod = GetLinqMethod(rule.EnumerableValueExpression.Operator, + 2, + property.PropertyType); + + + + if(generationMethod == null) + { + generationMethod = GetLinqMethod(rule.EnumerableValueExpression.Operator, rule.TargetValue.GetType()); + } + if(generationMethod == null) + { + generationMethod = GetLinqMethod(rule.EnumerableValueExpression.Operator); + } + + if (generationMethod == null) + { + throw new + RulesException($"Unable to find Linq method for {rule.EnumerableValueExpression.Operator}"); + } + + var m = generationMethod.MakeGenericMethod(elementType); + + propExpression = selector == null + ? Expression.Call(m, + propExpression) + : Expression.Call(m, + propExpression, + selector); + + + propType = propExpression.Type; + } + } + if (useTryCatch) { propExpression = Expression.TryCatch( @@ -275,6 +367,9 @@ private static Expression BuildExpr(Type type, Rule rule, Expression param, bool ); } + + + // is the operator a known .NET operator? ExpressionType tBinary; @@ -397,6 +492,25 @@ private static MethodInfo GetLinqMethod(string name, int numParameter) .FirstOrDefault(m => m.Name == name && m.GetParameters().Length == numParameter); } + private static MethodInfo GetLinqMethod(string name, int numParameter, Type returnType) + { + return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .FirstOrDefault((m => m.Name == name && m.GetParameters().Length == numParameter && m.ReturnType == returnType)); + } + + private static MethodInfo GetLinqMethod(string name, Type returnType) + { + return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .FirstOrDefault((m => m.Name == name && m.ReturnType == returnType)); + } + + private static MethodInfo GetLinqMethod(string name) + { + return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .FirstOrDefault((m => m.Name == name)); + } + + private static Expression GetDataRowField(Expression prm, string member, string typeName) { @@ -696,19 +810,27 @@ public static List Operators(Type type, bool addLogicOperators = false } [DataContract] - public class Rule + public class Selector + { + [DataMember] public string MemberName { get; set; } + [DataMember] public string Operator { get; set; } + } + + + [DataContract] + public class Rule : Selector { public Rule() { Inputs = Enumerable.Empty(); } - - [DataMember] public string MemberName { get; set; } - [DataMember] public string Operator { get; set; } + [DataMember] public object TargetValue { get; set; } - [DataMember] public IList Rules { get; set; } - [DataMember] public IEnumerable Inputs { get; set; } + [DataMember] public Rule EnumerableFilter { get; set; } + [DataMember] public Selector EnumerableValueExpression { get; set; } + [DataMember] public IList Rules { get; set; } + [DataMember] public IEnumerable Inputs { get; set; } public static Rule operator |(Rule lhs, Rule rhs) {