diff --git a/src/FSharpTypes/FSharpTypes.fsproj b/src/FSharpTypes/FSharpTypes.fsproj index 186eabb173..2edfa524b2 100644 --- a/src/FSharpTypes/FSharpTypes.fsproj +++ b/src/FSharpTypes/FSharpTypes.fsproj @@ -1,13 +1,17 @@  - net8.0 true 8.0 + net8.0;net7.0;net9.0 + + + + diff --git a/src/FSharpTypes/Library.fs b/src/FSharpTypes/Library.fs index 498c8ab017..24822290ab 100644 --- a/src/FSharpTypes/Library.fs +++ b/src/FSharpTypes/Library.fs @@ -1,6 +1,9 @@ module FSharpTypes open System +open System.Linq.Expressions +open Marten.Testing.Documents +open Microsoft.FSharp.Linq.RuntimeHelpers type OrderId = Id of Guid @@ -14,3 +17,24 @@ type RecordTypeOrderId = { Part1: string; Part2: string } type ArbitraryClass() = member this.Value = "ok" + +let rec stripFSharpFunc (expression: Expression) = + match expression with + | :? MethodCallExpression as callExpression when callExpression.Method.Name = "ToFSharpFunc" -> + stripFSharpFunc callExpression.Arguments.[0] + | _ -> expression + +let toLinqExpression expr = + expr + |> LeafExpressionConverter.QuotationToExpression + |> stripFSharpFunc + |> unbox>> +let greaterThanWithFsharpDateOption = + <@ fun (o1: Target) -> o1.FSharpDateTimeOffsetOption >= Some DateTimeOffset.UtcNow @> |> toLinqExpression +let lesserThanWithFsharpDateOption = <@ (fun (o1: Target) -> o1.FSharpDateTimeOffsetOption <= Some DateTimeOffset.UtcNow ) @> |> toLinqExpression +let greaterThanWithFsharpDecimalOption = <@ (fun (o1: Target) -> o1.FSharpDecimalOption >= Some 5m ) @> |> toLinqExpression +let lesserThanWithFsharpDecimalOption = <@ (fun (o1: Target) -> o1.FSharpDecimalOption <= Some 5m ) @> |> toLinqExpression +let greaterThanWithFsharpStringOption = <@ (fun (o1: Target) -> o1.FSharpStringOption >= Some "MyString" ) @> |> toLinqExpression +let lesserThanWithFsharpStringOption = <@ (fun (o1: Target) -> o1.FSharpStringOption <= Some "MyString" ) @> |> toLinqExpression + + diff --git a/src/LinqTests/Acceptance/Support/DefaultQueryFixture.cs b/src/LinqTests/Acceptance/Support/DefaultQueryFixture.cs index c508c97db6..873e926025 100644 --- a/src/LinqTests/Acceptance/Support/DefaultQueryFixture.cs +++ b/src/LinqTests/Acceptance/Support/DefaultQueryFixture.cs @@ -1,3 +1,4 @@ +using System.Text.Json.Serialization; using Marten; using Marten.Services; using Marten.Testing.Documents; @@ -11,7 +12,6 @@ public DefaultQueryFixture() { Store = ProvisionStore("linq_querying"); - DuplicatedFieldStore = ProvisionStore("duplicate_fields", o => { o.Schema.For() @@ -25,6 +25,30 @@ public DefaultQueryFixture() .Duplicate(x => x.NumberArray); }); + FSharpFriendlyStore = ProvisionStore("fsharp_linq_querying", options => + { + options.RegisterFSharpOptionValueTypes(); + var serializerOptions = JsonFSharpOptions.Default().WithUnwrapOption().ToJsonSerializerOptions(); + options.UseSystemTextJsonForSerialization(serializerOptions); + }, isFsharpTest: true); + + FSharpFriendlyStoreWithDuplicatedField = ProvisionStore("fsharp_duplicated_fields", options => + { + options.Schema.For() + .Duplicate(x => x.Number) + .Duplicate(x => x.Long) + .Duplicate(x => x.String) + .Duplicate(x => x.Date) + .Duplicate(x => x.Double) + .Duplicate(x => x.Flag) + .Duplicate(x => x.Color) + .Duplicate(x => x.NumberArray); + + options.RegisterFSharpOptionValueTypes(); + var serializerOptions = JsonFSharpOptions.Default().WithUnwrapOption().ToJsonSerializerOptions(); + options.UseSystemTextJsonForSerialization(serializerOptions); + }, isFsharpTest: true); + SystemTextJsonStore = ProvisionStore("stj_linq", o => { o.Serializer(); @@ -35,5 +59,9 @@ public DefaultQueryFixture() public DocumentStore DuplicatedFieldStore { get; set; } + public DocumentStore FSharpFriendlyStore { get; set; } + public DocumentStore FSharpFriendlyStoreWithDuplicatedField { get; set; } + public DocumentStore Store { get; set; } } + diff --git a/src/LinqTests/Acceptance/Support/LinqTestContext.cs b/src/LinqTests/Acceptance/Support/LinqTestContext.cs index 1c5d80a8de..e4fcef9afd 100644 --- a/src/LinqTests/Acceptance/Support/LinqTestContext.cs +++ b/src/LinqTests/Acceptance/Support/LinqTestContext.cs @@ -103,17 +103,27 @@ public static IEnumerable GetDescriptions() return _descriptions.Select(x => new object[] { x }); } - protected async Task assertTestCase(string description, IDocumentStore store) + protected async Task assertTestCaseWithDocuments(string description, IDocumentStore store, Target[] documents) { var index = _descriptions.IndexOf(description); var testCase = testCases[index]; await using var session = store.QuerySession(); - var logger = new TestOutputMartenLogger(TestOutput); + session.Logger = logger; - await testCase.Compare(session, Fixture.Documents, logger); + await testCase.Compare(session, documents, logger); + } + + protected Task assertTestCase(string description, IDocumentStore store) + { + return assertTestCaseWithDocuments(description, store, Fixture.Documents); + } + + protected Task assertFSharpTestCase(string description, IDocumentStore store) + { + return assertTestCaseWithDocuments(description, store, Fixture.FSharpDocuments); } } diff --git a/src/LinqTests/Acceptance/Support/TargetSchemaFixture.cs b/src/LinqTests/Acceptance/Support/TargetSchemaFixture.cs index 8248c38575..b0c83b1086 100644 --- a/src/LinqTests/Acceptance/Support/TargetSchemaFixture.cs +++ b/src/LinqTests/Acceptance/Support/TargetSchemaFixture.cs @@ -9,7 +9,12 @@ namespace LinqTests.Acceptance.Support; public abstract class TargetSchemaFixture: IDisposable { + /* + * Newtonsoft.Json does not support saving Discriminated Unions unwrapped (included f# options) which causes serialization-related errors. + * We must therefore only include F# data in F#-related tests to avoid false negatives. + */ public readonly Target[] Documents = Target.GenerateRandomData(1000).ToArray(); + public readonly Target[] FSharpDocuments = Target.GenerateRandomData(1000, includeFSharpUnionTypes: true).ToArray(); private readonly IList _stores = new List(); @@ -21,7 +26,7 @@ public void Dispose() } } - internal DocumentStore ProvisionStore(string schema, Action configure = null) + internal DocumentStore ProvisionStore(string schema, Action configure = null, bool isFsharpTest = false) { var store = DocumentStore.For(x => { @@ -33,7 +38,15 @@ internal DocumentStore ProvisionStore(string schema, Action config store.Advanced.Clean.CompletelyRemoveAll(); - store.BulkInsert(Documents); + + if (isFsharpTest) + { + store.BulkInsert(FSharpDocuments); + } + else + { + store.BulkInsert(Documents); + } _stores.Add(store); diff --git a/src/LinqTests/Acceptance/where_clauses_fsharp.cs b/src/LinqTests/Acceptance/where_clauses_fsharp.cs new file mode 100644 index 0000000000..9498390f97 --- /dev/null +++ b/src/LinqTests/Acceptance/where_clauses_fsharp.cs @@ -0,0 +1,48 @@ +using System; +using System.Threading.Tasks; +using LinqTests.Acceptance.Support; +using Microsoft.FSharp.Core; +using Xunit.Abstractions; + +namespace LinqTests.Acceptance; + +public class where_clauses_fsharp: LinqTestContext +{ + public where_clauses_fsharp(DefaultQueryFixture fixture, ITestOutputHelper output) : base(fixture) + { + TestOutput = output; + } + + static where_clauses_fsharp() + { + + @where(x => x.FSharpBoolOption == FSharpOption.Some(true)); + @where(x => x.FSharpBoolOption == FSharpOption.Some(false)); + @where(x => x.FSharpDateOption == FSharpOption.Some(DateTime.Now)); + @where(x => x.FSharpIntOption == FSharpOption.Some(300)); + @where(x => x.FSharpStringOption == FSharpOption.Some("My String")); + @where(x => x.FSharpLongOption == FSharpOption.Some(5_000_000)); + + //Comparing options is not a valid syntax in C#, we therefore define these expressions in F# + @where(FSharpTypes.greaterThanWithFsharpDateOption); + @where(FSharpTypes.lesserThanWithFsharpDateOption); + @where(FSharpTypes.greaterThanWithFsharpStringOption); + @where(FSharpTypes.lesserThanWithFsharpStringOption); + @where(FSharpTypes.greaterThanWithFsharpDecimalOption); + @where(FSharpTypes.lesserThanWithFsharpDecimalOption); + } + + [Theory] + [MemberData(nameof(GetDescriptions))] + public Task run_query(string description) + { + return assertFSharpTestCase(description, Fixture.FSharpFriendlyStore); + } + + [Theory] + [MemberData(nameof(GetDescriptions))] + public Task with_duplicated_fields(string description) + { + return assertFSharpTestCase(description, Fixture.FSharpFriendlyStoreWithDuplicatedField); + } +} diff --git a/src/LinqTests/LinqTests.csproj b/src/LinqTests/LinqTests.csproj index 088ee5055a..a2855f07db 100644 --- a/src/LinqTests/LinqTests.csproj +++ b/src/LinqTests/LinqTests.csproj @@ -6,15 +6,17 @@ + + - + all runtime; build; native; contentfiles; analyzers @@ -55,9 +57,6 @@ Documents\StringDoc.cs - - Documents\Target.cs - Documents\TargetIntId.cs diff --git a/src/LinqTests/Operators/is_one_of_operator.cs b/src/LinqTests/Operators/is_one_of_operator.cs index cb4b89c7dc..7e113e865b 100644 --- a/src/LinqTests/Operators/is_one_of_operator.cs +++ b/src/LinqTests/Operators/is_one_of_operator.cs @@ -2,15 +2,18 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Text.Json.Serialization; using Marten; using Marten.Testing.Documents; using Marten.Testing.Harness; +using Microsoft.FSharp.Core; using Shouldly; namespace LinqTests.Operators; public class is_one_of_operator: IntegrationContext { + public static TheoryData>>> SupportedIsOneOfWithIntArray = new() { @@ -95,6 +98,36 @@ public class is_one_of_operator: IntegrationContext validStrings => x => !x.String.In(validStrings) }; + public static TheoryData[], Expression>>> SupportedIsOneOfWithFsharpGuidOptionArray = + new() + { + validGuids => x => x.FSharpGuidOption.IsOneOf(validGuids), + validGuids => x => x.FSharpGuidOption.In(validGuids) + }; + + public static TheoryData[], Expression>>> SupportedNotIsOneOfWithFsharpGuidOptionArray = + new() + { + validGuids => x => !x.FSharpGuidOption.IsOneOf(validGuids), + validGuids => x => !x.FSharpGuidOption.In(validGuids) + }; + + public static TheoryData>, Expression>>> SupportedIsOneOfWithFsharpGuidOptionList = + new() + { + validGuids => x => x.FSharpGuidOption.IsOneOf(validGuids), + validGuids => x => x.FSharpGuidOption.In(validGuids) + }; + + public static TheoryData>, Expression>>> SupportedNotIsOneOfWithFsharpGuidOptionList = + new() + { + validGuids => x => !x.FSharpGuidOption.IsOneOf(validGuids), + validGuids => x => !x.FSharpGuidOption.In(validGuids) + }; + + + [Theory] [MemberData(nameof(SupportedIsOneOfWithIntArray))] public void can_query_against_integers(Func>> isOneOf) => @@ -110,9 +143,20 @@ public void can_query_against_guids(Func>> public void can_query_against_strings(Func>> isOneOf) => can_query_against_array(isOneOf, x => x.String); + [Theory] + [MemberData(nameof(SupportedIsOneOfWithFsharpGuidOptionArray))] + public void can_query_against_fsharp_guid_option_array(Func[], Expression>> isOneOf) => + can_query_against_array(isOneOf, x => x.FSharpGuidOption); + + [Theory] + [MemberData(nameof(SupportedIsOneOfWithFsharpGuidOptionArray))] + public void can_query_against_fsharp_guid_option_array_with_unwrapped_guid(Func[], Expression>> isOneOf) => + can_query_against_array(isOneOf, x => x.FSharpGuidOption); + private void can_query_against_array(Func>> isOneOf, Func select) { - var targets = Target.GenerateRandomData(100).ToArray(); + + var targets = Target.GenerateRandomData(100, true).ToArray(); theStore.BulkInsert(targets); var validValues = targets.Select(select).Distinct().Take(3).ToArray(); @@ -147,12 +191,17 @@ public void can_query_against_guids_with_not_operator(Func>> notIsOneOf) => can_query_against_array_with_not_operator(notIsOneOf, x => x.String); + [Theory] + [MemberData(nameof(SupportedNotIsOneOfWithFsharpGuidOptionArray))] + public void can_query_against_fsharp_guid_option_with_not_operator(Func[], Expression>> notIsOneOf) + => can_query_against_array_with_not_operator(notIsOneOf, x => x.FSharpGuidOption); + private void can_query_against_array_with_not_operator( Func>> notIsOneOf, Func select ) { - var targets = Target.GenerateRandomData(100).ToArray(); + var targets = Target.GenerateRandomData(100, true).ToArray(); theStore.BulkInsert(targets); var validValues = targets.Select(select).Distinct().Take(3).ToArray(); @@ -184,9 +233,14 @@ public void can_query_against_guids_list(Func, Expression, Expression>> isOneOf) => can_query_against_list(isOneOf, x => x.String); + [Theory] + [MemberData(nameof(SupportedIsOneOfWithFsharpGuidOptionList))] + public void can_query_against_fsharp_guid_option_list(Func>, Expression>> isOneOf) + => can_query_against_list(isOneOf, x => x.FSharpGuidOption); + private void can_query_against_list(Func, Expression>> isOneOf, Func select) { - var targets = Target.GenerateRandomData(100).ToArray(); + var targets = Target.GenerateRandomData(100, true).ToArray(); theStore.BulkInsert(targets); var validValues = targets.Select(select).Distinct().Take(3).ToList(); @@ -223,12 +277,18 @@ public void can_query_against_strings_with_not_operator_list( Func, Expression>> notIsOneOf) => can_query_against_list_with_not_operator(notIsOneOf, x => x.String); + [Theory] + [MemberData(nameof(SupportedNotIsOneOfWithFsharpGuidOptionList))] + public void can_query_against_fsharp_guid_option_with_not_operator_list( + Func>, Expression>> notIsOneOf) => + can_query_against_list_with_not_operator(notIsOneOf, x => x.FSharpGuidOption); + private void can_query_against_list_with_not_operator( Func, Expression>> notIsOneOf, Func select ) { - var targets = Target.GenerateRandomData(100).ToArray(); + var targets = Target.GenerateRandomData(100, true).ToArray(); theStore.BulkInsert(targets); var validValues = targets.Select(select).Distinct().Take(3).ToList(); @@ -249,5 +309,13 @@ Func select public is_one_of_operator(DefaultStoreFixture fixture): base(fixture) { + StoreOptions(_ => + { + //_.Logger(new ConsoleMartenLogger()); + _.RegisterValueType(typeof(FSharpOption)); + _.DisableNpgsqlLogging = false; + var serializerOptions = JsonFSharpOptions.Default().WithUnwrapOption().ToJsonSerializerOptions(); + _.UseSystemTextJsonForSerialization(serializerOptions); + }); } } diff --git a/src/LinqTestsTypes/LinqTestsTypes.csproj b/src/LinqTestsTypes/LinqTestsTypes.csproj new file mode 100644 index 0000000000..19310ac724 --- /dev/null +++ b/src/LinqTestsTypes/LinqTestsTypes.csproj @@ -0,0 +1,20 @@ + + + + net7.0 + enable + enable + + + + + Target.cs + + + + + + + + + diff --git a/src/Marten.Testing/Documents/Target.cs b/src/Marten.Testing/Documents/Target.cs index cca6377e4a..1c9dd017ac 100644 --- a/src/Marten.Testing/Documents/Target.cs +++ b/src/Marten.Testing/Documents/Target.cs @@ -1,8 +1,9 @@ -using System; +using System; using System.Collections.Generic; using System.Linq; using System.Text.Json.Serialization; using JasperFx.Core; +using Microsoft.FSharp.Core; #nullable enable namespace Marten.Testing.Documents; @@ -31,21 +32,33 @@ public class Target "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten" }; - public static IEnumerable GenerateRandomData(int number) + public static IEnumerable GenerateRandomData(int number, bool includeFSharpUnionTypes = false) { var i = 0; while (i < number) { - yield return Random(true); + yield return Random(true, includeFSharpUnionTypes); i++; } } - public static Target Random(bool deep = false) + public static Target Random(bool deep = false, bool includeFSharpUnionTypes = false) { var target = new Target(); target.String = _strings[_random.Next(0, 10)]; + + if (includeFSharpUnionTypes) + { + target.FSharpGuidOption = new FSharpOption(Guid.NewGuid()); + target.FSharpIntOption = new FSharpOption(_random.Next(0, 10)); + target.FSharpDateOption = new FSharpOption(DateTime.Now); + target.FSharpDateTimeOffsetOption = new FSharpOption(new DateTimeOffset(DateTime.UtcNow)); + target.FSharpDecimalOption = new FSharpOption(_random.Next(0, 10)); + target.FSharpLongOption = new FSharpOption(_random.Next(0, 10)); + target.FSharpStringOption = new FSharpOption(_strings[_random.Next(0, 10)]); + } + target.PaddedString = " " + target.String + " "; target.AnotherString = _otherStrings[_random.Next(0, 10)]; target.Number = _random.Next(); @@ -98,6 +111,7 @@ public static Target Random(bool deep = false) target.HowLong = TimeSpan.FromSeconds(target.Long); target.Date = DateTime.Today.AddDays(_random.Next(-10000, 10000)); + target.DateOffset = new DateTimeOffset(DateTime.Today.AddDays(_random.Next(-10000, 10000))); if (value > 15) { @@ -141,6 +155,16 @@ public Target() public long Long { get; set; } public string String { get; set; } + + public FSharpOption FSharpGuidOption { get; set; } + public FSharpOption FSharpIntOption { get; set; } + public FSharpOption FSharpBoolOption { get; set; } + public FSharpOption FSharpLongOption { get; set; } + public FSharpOption FSharpDecimalOption { get; set; } + public FSharpOption FSharpStringOption { get; set; } + public FSharpOption FSharpDateOption { get; set; } + public FSharpOption FSharpDateTimeOffsetOption { get; set; } + public string AnotherString { get; set; } public string[] StringArray { get; set; } @@ -194,6 +218,11 @@ public Target() public TimeSpan HowLong { get; set; } } +public class FSharpTarget: Target +{ + +} + public class Address { public Address() diff --git a/src/Marten.Testing/Harness/IntegrationContext.cs b/src/Marten.Testing/Harness/IntegrationContext.cs index 32d4d68357..63c3771166 100644 --- a/src/Marten.Testing/Harness/IntegrationContext.cs +++ b/src/Marten.Testing/Harness/IntegrationContext.cs @@ -5,6 +5,7 @@ using JasperFx.CodeGeneration; using Marten.Events; using Marten.Internal.CodeGeneration; +using Microsoft.FSharp.Core; using Weasel.Core; using Weasel.Postgresql; using Xunit; diff --git a/src/Marten.sln b/src/Marten.sln index 0a0ec93f4c..6208a8914e 100644 --- a/src/Marten.sln +++ b/src/Marten.sln @@ -109,6 +109,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "EventAppenderPerfTester", " EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StressTests", "StressTests\StressTests.csproj", "{C9D33381-3AD3-4005-B854-F04F10EA837F}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LinqTestsTypes", "LinqTestsTypes\LinqTestsTypes.csproj", "{2B5A28C6-2369-4554-B131-C42907E8BA83}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -259,6 +261,10 @@ Global {C9D33381-3AD3-4005-B854-F04F10EA837F}.Debug|Any CPU.Build.0 = Debug|Any CPU {C9D33381-3AD3-4005-B854-F04F10EA837F}.Release|Any CPU.ActiveCfg = Release|Any CPU {C9D33381-3AD3-4005-B854-F04F10EA837F}.Release|Any CPU.Build.0 = Release|Any CPU + {2B5A28C6-2369-4554-B131-C42907E8BA83}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2B5A28C6-2369-4554-B131-C42907E8BA83}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2B5A28C6-2369-4554-B131-C42907E8BA83}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2B5A28C6-2369-4554-B131-C42907E8BA83}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -297,6 +303,8 @@ Global {7F2B0CE1-1365-4AE2-9823-B6DBB1A42C5F} = {91D87D73-EC07-4067-8A64-26A2E4F6BC83} {29E06861-11C7-4917-BA91-162D15538029} = {91D87D73-EC07-4067-8A64-26A2E4F6BC83} {C9D33381-3AD3-4005-B854-F04F10EA837F} = {91D87D73-EC07-4067-8A64-26A2E4F6BC83} + {B1F935FC-55DC-418B-A5DC-6049A5C06871} = {91D87D73-EC07-4067-8A64-26A2E4F6BC83} + {2B5A28C6-2369-4554-B131-C42907E8BA83} = {91D87D73-EC07-4067-8A64-26A2E4F6BC83} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {52B7158F-0A24-47D9-9CF7-3FA94170041A} diff --git a/src/Marten/Linq/Members/FSharpOptionValueTypeMember.cs b/src/Marten/Linq/Members/FSharpOptionValueTypeMember.cs new file mode 100644 index 0000000000..ae64da7ef8 --- /dev/null +++ b/src/Marten/Linq/Members/FSharpOptionValueTypeMember.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using Marten.Internal; +using Marten.Linq.SqlGeneration; +using Microsoft.FSharp.Core; + +namespace Marten.Linq.Members; + +public class FSharpOptionValueTypeMember : ValueTypeMember, TOption>, IComparableMember +{ + public FSharpOptionValueTypeMember(IQueryableMember parent, Casing casing, MemberInfo member, ValueTypeInfo valueTypeInfo) : base(parent, casing, member, valueTypeInfo) + { + + } + +} diff --git a/src/Marten/Linq/Members/StringValueTypeMember.cs b/src/Marten/Linq/Members/StringValueTypeMember.cs index 17f4a789da..1193004b4f 100644 --- a/src/Marten/Linq/Members/StringValueTypeMember.cs +++ b/src/Marten/Linq/Members/StringValueTypeMember.cs @@ -14,7 +14,7 @@ namespace Marten.Linq.Members; -public class StringValueTypeMember: StringMember, IValueTypeMember +public class StringValueTypeMember: StringMember, IValueTypeMember { private readonly Func _valueSource; private readonly IScalarSelectClause _selector; @@ -61,7 +61,7 @@ public override ISqlFragment CreateComparison(string op, ConstantExpression cons return new MemberComparisonFilter(this, def, op); } - public object ConvertFromWrapperArray(object values) + public IEnumerable ConvertFromWrapperArray(IEnumerable values) { if (values is IEnumerable e) { diff --git a/src/Marten/Linq/Members/StrongTypedIdMember.cs b/src/Marten/Linq/Members/StrongTypedIdMember.cs index 56dbb72e19..3e2d823fb2 100644 --- a/src/Marten/Linq/Members/StrongTypedIdMember.cs +++ b/src/Marten/Linq/Members/StrongTypedIdMember.cs @@ -19,7 +19,7 @@ internal interface IStrongTypedIdGeneration ISelectClause BuildSelectClause(string fromObject); } -internal class StrongTypedIdMember: IdMember, IValueTypeMember +internal class StrongTypedIdMember: IdMember, IValueTypeMember { private readonly IStrongTypedIdGeneration _idGeneration; private readonly Func _innerValue; @@ -31,7 +31,7 @@ public StrongTypedIdMember(MemberInfo member, IStrongTypedIdGeneration idGenerat } - public object ConvertFromWrapperArray(object values) + public IEnumerable ConvertFromWrapperArray(IEnumerable values) { if (values is IEnumerable e) { diff --git a/src/Marten/Linq/Members/ValueTypeMember.cs b/src/Marten/Linq/Members/ValueTypeMember.cs index 8c48403e89..c6928067d1 100644 --- a/src/Marten/Linq/Members/ValueTypeMember.cs +++ b/src/Marten/Linq/Members/ValueTypeMember.cs @@ -16,13 +16,13 @@ namespace Marten.Linq.Members; -public interface IValueTypeMember: IQueryableMember +public interface IValueTypeMember: IQueryableMember { - object ConvertFromWrapperArray(object values); + IEnumerable ConvertFromWrapperArray(IEnumerable values); ISelectClause BuildSelectClause(string fromObject); } -public class ValueTypeMember: SimpleCastMember, IValueTypeMember +public class ValueTypeMember: SimpleCastMember, IValueTypeMember { private readonly Func _valueSource; private readonly IScalarSelectClause _selector; @@ -49,7 +49,6 @@ public ValueTypeMember(IQueryableMember parent, Casing casing, MemberInfo member valueTypeInfo.SimpleType); } - } public override void PlaceValueInDictionaryForContainment(Dictionary dict, @@ -75,7 +74,7 @@ public override ISqlFragment CreateComparison(string op, ConstantExpression cons return new MemberComparisonFilter(this, def, op); } - public object ConvertFromWrapperArray(object values) + public IEnumerable ConvertFromWrapperArray(IEnumerable values) { if (values is IEnumerable e) { diff --git a/src/Marten/Linq/Parsing/Methods/IsOneOf.cs b/src/Marten/Linq/Parsing/Methods/IsOneOf.cs index 8ce338d84e..17bc00cd1e 100644 --- a/src/Marten/Linq/Parsing/Methods/IsOneOf.cs +++ b/src/Marten/Linq/Parsing/Methods/IsOneOf.cs @@ -5,6 +5,7 @@ using JasperFx.Core.Reflection; using Marten.Linq.Members; using Marten.Linq.SqlGeneration.Filters; +using Marten.Util; using Weasel.Postgresql; using Weasel.Postgresql.SqlGeneration; @@ -31,13 +32,19 @@ public ISqlFragment Parse(IQueryableMemberCollection memberCollection, IReadOnly { return new EnumIsOneOfWhereFragment(values, options.Serializer().EnumStorage, locator); } - else if (queryableMember is IValueTypeMember valueTypeMember) + else if (queryableMember.IsGenericInterfaceImplementation(typeof(IValueTypeMember<,>))) { - return new IsOneOfFilter(queryableMember, new CommandParameter(valueTypeMember.ConvertFromWrapperArray(values))); + /* Unwrapping is required for nullable value types of the form: System.Nullable`1[ValueTypeTests.StrongTypedId.Issue2Id][] + otherwise we get exceptions such as: Object of type 'System.Nullable`1[ValueTypeTests.StrongTypedId.Issue2Id][]' cannot be converted to type 'System.Collections.Generic.IEnumerable`1[ValueTypeTests.StrongTypedId.Issue2Id]' + */ + var unwrappedValues = values.UnwrapIEnumerableOfNullables(); + var commandParameter = queryableMember.CallGenericInterfaceMethod(typeof(IValueTypeMember<,>), "ConvertFromWrapperArray", unwrappedValues); + return new IsOneOfFilter(queryableMember, new CommandParameter(commandParameter)); } return new IsOneOfFilter(queryableMember, new CommandParameter(values)); } + } internal class IsOneOfFilter: ISqlFragment diff --git a/src/Marten/Linq/Parsing/SelectorVisitor.cs b/src/Marten/Linq/Parsing/SelectorVisitor.cs index 4b8935ea09..232189e728 100644 --- a/src/Marten/Linq/Parsing/SelectorVisitor.cs +++ b/src/Marten/Linq/Parsing/SelectorVisitor.cs @@ -6,6 +6,7 @@ using JasperFx.Core.Reflection; using Marten.Linq.Members; using Marten.Linq.SqlGeneration; +using Marten.Util; namespace Marten.Linq.Parsing; @@ -105,8 +106,8 @@ public void ToScalar(Expression selectClauseSelector) else { _statement.SelectClause = - member is IValueTypeMember valueTypeMember - ? valueTypeMember.BuildSelectClause(_statement.FromObject) + member.IsGenericInterfaceImplementation(typeof(IValueTypeMember<,>)) + ? (ISelectClause)member.CallGenericInterfaceMethod(typeof(IValueTypeMember<,>), "BuildSelectClause", _statement.FromObject) : typeof(DataSelectClause<>).CloseAndBuildAs(_statement.FromObject, member.RawLocator, member.MemberType); diff --git a/src/Marten/Linq/SqlGeneration/IScalarSelectClause.cs b/src/Marten/Linq/SqlGeneration/IScalarSelectClause.cs index 11ae5c03f4..93dd8ee99f 100644 --- a/src/Marten/Linq/SqlGeneration/IScalarSelectClause.cs +++ b/src/Marten/Linq/SqlGeneration/IScalarSelectClause.cs @@ -1,7 +1,7 @@ #nullable enable namespace Marten.Linq.SqlGeneration; -internal interface IScalarSelectClause +public interface IScalarSelectClause { string MemberName { get; } void ApplyOperator(string op); diff --git a/src/Marten/Marten.csproj b/src/Marten/Marten.csproj index a8f6ac0d36..9bcd7f98b5 100644 --- a/src/Marten/Marten.csproj +++ b/src/Marten/Marten.csproj @@ -48,6 +48,7 @@ + diff --git a/src/Marten/StoreOptions.Identity.cs b/src/Marten/StoreOptions.Identity.cs index 911caa8ff3..ac55123924 100644 --- a/src/Marten/StoreOptions.Identity.cs +++ b/src/Marten/StoreOptions.Identity.cs @@ -10,6 +10,7 @@ using Marten.Linq.Members; using Marten.Schema.Identity; using Marten.Schema.Identity.Sequences; +using Microsoft.FSharp.Core; using CombGuidIdGeneration = Marten.Schema.Identity.CombGuidIdGeneration; namespace Marten; @@ -108,6 +109,15 @@ public ValueTypeInfo RegisterValueType(Type type) { valueProperty = type.GetProperties().Where(x => x.Name != "Tag").SingleOrDefaultIfMany(); } + else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(FSharpOption<>)) + { + var innerType = type.GetGenericArguments().Single(); + valueProperty = type.GetProperty("Value"); + var optionBuilder = type.GetMethod("Some", BindingFlags.Static | BindingFlags.Public); + var valueType = new ValueTypeInfo(type, innerType, valueProperty, optionBuilder); + ValueTypes.Add(valueType); + return valueType; + } else { valueProperty = type.GetProperties().SingleOrDefaultIfMany(); @@ -139,5 +149,24 @@ public ValueTypeInfo RegisterValueType(Type type) "Unable to determine either a builder static method or a constructor to use"); } + public void RegisterFSharpOptionValueTypes() + { + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + RegisterValueType(typeof(FSharpOption)); + } + internal List ValueTypes { get; } = new(); } diff --git a/src/Marten/StoreOptions.MemberFactory.cs b/src/Marten/StoreOptions.MemberFactory.cs index 77ec30a92e..c39f218871 100644 --- a/src/Marten/StoreOptions.MemberFactory.cs +++ b/src/Marten/StoreOptions.MemberFactory.cs @@ -11,6 +11,7 @@ using Marten.Linq.Members.ValueCollections; using Marten.Linq.Parsing; using Marten.Schema.Identity; +using Microsoft.FSharp.Core; using Newtonsoft.Json.Linq; using Weasel.Postgresql; @@ -148,15 +149,26 @@ public bool TryResolve(IQueryableMember parent, StoreOptions options, MemberInfo out IQueryableMember? member) { var valueType = options.ValueTypes.FirstOrDefault(x => x.OuterType == memberType); + if (valueType == null) { member = default; return false; } - var baseType = valueType.SimpleType == typeof(string) - ? typeof(StringValueTypeMember<>).MakeGenericType(memberType) - : typeof(ValueTypeMember<,>).MakeGenericType(memberType, valueType.SimpleType); + Type baseType; + if (valueType.OuterType.IsGenericType && valueType.OuterType.GetGenericTypeDefinition() == typeof(FSharpOption<>)) + { + baseType = typeof(FSharpOptionValueTypeMember<>).MakeGenericType(valueType.SimpleType); + } + else if (valueType.SimpleType == typeof(string)) + { + baseType = typeof(StringValueTypeMember<>).MakeGenericType(memberType); + } + else + { + baseType = typeof(ValueTypeMember<,>).MakeGenericType(memberType, valueType.SimpleType); + } member = (IQueryableMember)Activator.CreateInstance(baseType, parent, options.Serializer().Casing, memberInfo, valueType); diff --git a/src/Marten/Util/GenericsExtensions.cs b/src/Marten/Util/GenericsExtensions.cs new file mode 100644 index 0000000000..23b72246a5 --- /dev/null +++ b/src/Marten/Util/GenericsExtensions.cs @@ -0,0 +1,77 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Marten.Util; + +internal static class GenericsExtensions +{ + public static bool IsGenericInterfaceImplementation(this object obj, Type openInterfaceType) + { + return obj.GetType().GetInterfaces().Any(x => x.IsGenericType && openInterfaceType.IsAssignableFrom(x.GetGenericTypeDefinition())); + } + + public static object UnwrapIEnumerableOfNullables(this object obj) + { + var type = obj.GetType(); + if (type.GetInterfaces().Any(i => + i.IsGenericType && + i.GetGenericTypeDefinition() == typeof(IEnumerable<>) && + i.GetGenericArguments()[0].IsGenericType && + i.GetGenericArguments()[0].GetGenericTypeDefinition() == typeof(Nullable<>))) + { + // Get the underlying type of the Nullable + var strongIdtype = type.GetInterfaces() + .First(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + .GetGenericArguments()[0] + .GetGenericArguments()[0]; + + // Cast to IEnumerable> and filter for non-null values + var unwrappedValues = ((IEnumerable)obj) + .Cast() + .Where(x => x != null) + .Select(x => x.GetType().GetProperty("Value").GetValue(x)); + + // Create an array of the underlying type + obj = Array.CreateInstance(strongIdtype, unwrappedValues.Count()); + var stronglyTypedValues = unwrappedValues.Select(x => + { + var constructor = strongIdtype.GetConstructor(new[] { x.GetType() }); + object strongTypedId; + if (constructor != null) + { + strongTypedId = constructor.Invoke(new[] { x }); + } + else + { + // Use static builder method if no constructor is found. User can name the builder method anyway they want. + var fromMethod = strongIdtype.GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.ReturnType == strongIdtype && + m.GetParameters().Length == 1 && + m.GetParameters()[0].ParameterType == x.GetType() + ); + if (fromMethod == null) + { + throw new InvalidOperationException($"Type {strongIdtype} does not have a constructor or a static builder method."); + } + strongTypedId = fromMethod.Invoke(null, new[] { x }); + } + return strongTypedId; + }).ToArray(); + Array.Copy(stronglyTypedValues.ToArray(), (Array)obj, unwrappedValues.Count()); + + return obj; + } + + return obj; + } + + public static object CallGenericInterfaceMethod(this object obj, Type openInterfaceType, string methodName, params object[] parameters) + { + return obj.GetType().GetMethod(methodName)?.Invoke(obj, parameters); + } + +}