diff --git a/src/Marten/Linq/Members/StringValueTypeMember.cs b/src/Marten/Linq/Members/StringValueTypeMember.cs index 2280c27d8f..17f4a789da 100644 --- a/src/Marten/Linq/Members/StringValueTypeMember.cs +++ b/src/Marten/Linq/Members/StringValueTypeMember.cs @@ -23,10 +23,21 @@ public StringValueTypeMember(IQueryableMember parent, Casing casing, MemberInfo { _valueSource = valueTypeInfo.ValueAccessor(); var converter = valueTypeInfo.CreateConverter(); - _selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs( - TypedLocator, converter, - valueTypeInfo.OuterType, - typeof(string)); + + if (typeof(T).IsClass) + { + _selector = typeof(ClassValueTypeSelectClause<,>).CloseAndBuildAs( + TypedLocator, converter, + valueTypeInfo.OuterType, + typeof(string)); + } + else + { + _selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs( + TypedLocator, converter, + valueTypeInfo.OuterType, + typeof(string)); + } } public override void PlaceValueInDictionaryForContainment(Dictionary dict, ConstantExpression constant) diff --git a/src/Marten/Linq/Members/ValueTypeMember.cs b/src/Marten/Linq/Members/ValueTypeMember.cs index 8ae902e2ec..8c48403e89 100644 --- a/src/Marten/Linq/Members/ValueTypeMember.cs +++ b/src/Marten/Linq/Members/ValueTypeMember.cs @@ -33,10 +33,23 @@ public ValueTypeMember(IQueryableMember parent, Casing casing, MemberInfo member { _valueSource = valueTypeInfo.ValueAccessor(); var converter = valueTypeInfo.CreateConverter(); - _selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs( - TypedLocator, converter, - valueTypeInfo.OuterType, - valueTypeInfo.SimpleType); + + if (typeof(TOuter).IsClass) + { + _selector = typeof(ClassValueTypeSelectClause<,>).CloseAndBuildAs( + TypedLocator, converter, + valueTypeInfo.OuterType, + valueTypeInfo.SimpleType); + } + else + { + _selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs( + TypedLocator, converter, + valueTypeInfo.OuterType, + valueTypeInfo.SimpleType); + } + + } public override void PlaceValueInDictionaryForContainment(Dictionary dict, @@ -56,7 +69,8 @@ public override ISqlFragment CreateComparison(string op, ConstantExpression cons return op == "=" ? new IsNullFilter(this) : new IsNotNullFilter(this); } - var value = _valueSource(constant.Value.As()); + var value = constant.Value is TInner ? (TInner)constant.Value : _valueSource(constant.Value.As()); + var def = new CommandParameter(Expression.Constant(value)); return new MemberComparisonFilter(this, def, op); } diff --git a/src/Marten/Linq/SqlGeneration/ValueTypeSelectClause.cs b/src/Marten/Linq/SqlGeneration/ValueTypeSelectClause.cs index a51c7f5b7c..0609792a6d 100644 --- a/src/Marten/Linq/SqlGeneration/ValueTypeSelectClause.cs +++ b/src/Marten/Linq/SqlGeneration/ValueTypeSelectClause.cs @@ -124,3 +124,105 @@ public override string ToString() return $"Data from {FromObject}"; } } + +public class ClassValueTypeSelectClause: ISelectClause, IScalarSelectClause, IModifyableFromObject, + ISelector +{ + public ClassValueTypeSelectClause(string memberName, Func converter) + { + MemberName = memberName; + Converter = converter; + } + + public Func Converter { get; } + + public string MemberName { get; set; } + + public ISelectClause CloneToOtherTable(string tableName) + { + return new ClassValueTypeSelectClause(MemberName, Converter) + { + FromObject = tableName, MemberName = MemberName + }; + } + + public void ApplyOperator(string op) + { + MemberName = $"{op}({MemberName})"; + } + + public ISelectClause CloneToDouble() + { + throw new NotSupportedException(); + } + + public Type SelectedType => typeof(TOuter); + + public string FromObject { get; set; } + + public void Apply(ICommandBuilder sql) + { + if (MemberName.IsNotEmpty()) + { + sql.Append("select "); + sql.Append(MemberName); + sql.Append(" as data from "); + } + + sql.Append(FromObject); + sql.Append(" as d"); + } + + public string[] SelectFields() + { + return new[] { MemberName }; + } + + public ISelector BuildSelector(IMartenSession session) + { + return this; + } + + public IQueryHandler BuildHandler(IMartenSession session, ISqlFragment statement, + ISqlFragment currentStatement) + { + if (typeof(TResult).CanBeCastTo>()) + { + return (IQueryHandler)new ListQueryHandler(statement, this); + } + + return (IQueryHandler)new ListQueryHandler(statement, this); + } + + public ISelectClause UseStatistics(QueryStatistics statistics) + { + return new StatsSelectClause(this, statistics); + } + + public TOuter Resolve(DbDataReader reader) + { + if (reader.IsDBNull(0)) + { + return default(TOuter); + } + + var inner = reader.GetFieldValue(0); + return Converter(inner); + } + + public async Task ResolveAsync(DbDataReader reader, CancellationToken token) + { + if (await reader.IsDBNullAsync(0, token).ConfigureAwait(false)) + { + return default(TOuter); + } + + var inner = await reader.GetFieldValueAsync(0, token).ConfigureAwait(false); + return Converter(inner); + } + + public override string ToString() + { + return $"Data from {FromObject}"; + } +} diff --git a/src/ValueTypeTests/Bugs/querying_by_value_types.cs b/src/ValueTypeTests/Bugs/querying_by_value_types.cs new file mode 100644 index 0000000000..ce3697fd8e --- /dev/null +++ b/src/ValueTypeTests/Bugs/querying_by_value_types.cs @@ -0,0 +1,62 @@ +using System; +using System.Threading.Tasks; +using Marten; +using Marten.Testing.Harness; +using Vogen; +using Shouldly; + +namespace ValueTypeTests.Bugs; + +public class querying_by_value_types : BugIntegrationContext +{ + [Fact] + public async Task run_queries() + { + StoreOptions(opts => + { + opts.RegisterValueType(typeof(EmailAddress)); + opts.RegisterValueType(typeof(Age)); + }); + + var customer = new Customer + { + Email = EmailAddress.From("example@me.com"), + Age = Age.From(25) + }; + + theSession.Store(customer); + + await theSession.SaveChangesAsync(); + + var loadedCustomer = await theSession.LoadAsync(customer.Id); + + loadedCustomer.Email.ShouldNotBeNull(); + loadedCustomer.Email.Value.ShouldBe("example@me.com"); + + + var queryByAge = await theSession.Query() + .FirstOrDefaultAsync(x => x.Age == 25); + + queryByAge.ShouldNotBeNull(); + + var queryByEmail = await theSession.Query() + .FirstOrDefaultAsync(x => x.Email == customer.Email); + + queryByEmail.ShouldNotBeNull(); + } +} + +[ValueObject] +public partial record EmailAddress; + +[ValueObject] +public partial record Age; + +public class Customer +{ + public Guid Id { get; set; } + + public required EmailAddress Email { get; init; } + + public required Age Age { get; init; } +}