Skip to content

Commit

Permalink
Improvements to value type LINQ querying. Closes GH-3450
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremydmiller committed Oct 9, 2024
1 parent c82ea9e commit aef686f
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 9 deletions.
19 changes: 15 additions & 4 deletions src/Marten/Linq/Members/StringValueTypeMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@ public StringValueTypeMember(IQueryableMember parent, Casing casing, MemberInfo
{
_valueSource = valueTypeInfo.ValueAccessor<T, string>();
var converter = valueTypeInfo.CreateConverter<T, string>();
_selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs<IScalarSelectClause>(
TypedLocator, converter,
valueTypeInfo.OuterType,
typeof(string));

if (typeof(T).IsClass)
{
_selector = typeof(ClassValueTypeSelectClause<,>).CloseAndBuildAs<IScalarSelectClause>(
TypedLocator, converter,
valueTypeInfo.OuterType,
typeof(string));
}
else
{
_selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs<IScalarSelectClause>(
TypedLocator, converter,
valueTypeInfo.OuterType,
typeof(string));
}
}

public override void PlaceValueInDictionaryForContainment(Dictionary<string, object> dict, ConstantExpression constant)
Expand Down
24 changes: 19 additions & 5 deletions src/Marten/Linq/Members/ValueTypeMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,23 @@ public ValueTypeMember(IQueryableMember parent, Casing casing, MemberInfo member
{
_valueSource = valueTypeInfo.ValueAccessor<TOuter, TInner>();
var converter = valueTypeInfo.CreateConverter<TOuter, TInner>();
_selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs<IScalarSelectClause>(
TypedLocator, converter,
valueTypeInfo.OuterType,
valueTypeInfo.SimpleType);

if (typeof(TOuter).IsClass)
{
_selector = typeof(ClassValueTypeSelectClause<,>).CloseAndBuildAs<IScalarSelectClause>(
TypedLocator, converter,
valueTypeInfo.OuterType,
valueTypeInfo.SimpleType);
}
else
{
_selector = typeof(ValueTypeSelectClause<,>).CloseAndBuildAs<IScalarSelectClause>(
TypedLocator, converter,
valueTypeInfo.OuterType,
valueTypeInfo.SimpleType);
}


}

public override void PlaceValueInDictionaryForContainment(Dictionary<string, object> dict,
Expand All @@ -56,7 +69,8 @@ public override ISqlFragment CreateComparison(string op, ConstantExpression cons
return op == "=" ? new IsNullFilter(this) : new IsNotNullFilter(this);
}

var value = _valueSource(constant.Value.As<TOuter>());
var value = constant.Value is TInner ? (TInner)constant.Value : _valueSource(constant.Value.As<TOuter>());

var def = new CommandParameter(Expression.Constant(value));
return new MemberComparisonFilter(this, def, op);
}
Expand Down
102 changes: 102 additions & 0 deletions src/Marten/Linq/SqlGeneration/ValueTypeSelectClause.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,105 @@ public override string ToString()
return $"Data from {FromObject}";
}
}

public class ClassValueTypeSelectClause<TOuter, TInner>: ISelectClause, IScalarSelectClause, IModifyableFromObject,
ISelector<TOuter>
{
public ClassValueTypeSelectClause(string memberName, Func<TInner, TOuter> converter)
{
MemberName = memberName;
Converter = converter;
}

public Func<TInner, TOuter> Converter { get; }

public string MemberName { get; set; }

public ISelectClause CloneToOtherTable(string tableName)
{
return new ClassValueTypeSelectClause<TOuter, TInner>(MemberName, Converter)
{
FromObject = tableName, MemberName = MemberName
};
}

public void ApplyOperator(string op)
{
MemberName = $"{op}({MemberName})";
}

public ISelectClause CloneToDouble()
{
throw new NotSupportedException();
}

public Type SelectedType => typeof(TOuter);

public string FromObject { get; set; }

public void Apply(ICommandBuilder sql)
{
if (MemberName.IsNotEmpty())
{
sql.Append("select ");
sql.Append(MemberName);
sql.Append(" as data from ");
}

sql.Append(FromObject);
sql.Append(" as d");
}

public string[] SelectFields()
{
return new[] { MemberName };
}

public ISelector BuildSelector(IMartenSession session)
{
return this;
}

public IQueryHandler<TResult> BuildHandler<TResult>(IMartenSession session, ISqlFragment statement,
ISqlFragment currentStatement)
{
if (typeof(TResult).CanBeCastTo<IEnumerable<TOuter>>())
{
return (IQueryHandler<TResult>)new ListQueryHandler<TOuter>(statement, this);
}

return (IQueryHandler<TResult>)new ListQueryHandler<TOuter?>(statement, this);
}

public ISelectClause UseStatistics(QueryStatistics statistics)
{
return new StatsSelectClause<TOuter>(this, statistics);
}

public TOuter Resolve(DbDataReader reader)
{
if (reader.IsDBNull(0))
{
return default(TOuter);
}

var inner = reader.GetFieldValue<TInner>(0);
return Converter(inner);
}

public async Task<TOuter> ResolveAsync(DbDataReader reader, CancellationToken token)
{
if (await reader.IsDBNullAsync(0, token).ConfigureAwait(false))
{
return default(TOuter);
}

var inner = await reader.GetFieldValueAsync<TInner>(0, token).ConfigureAwait(false);
return Converter(inner);
}

public override string ToString()
{
return $"Data from {FromObject}";
}
}
62 changes: 62 additions & 0 deletions src/ValueTypeTests/Bugs/querying_by_value_types.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using System;
using System.Threading.Tasks;
using Marten;
using Marten.Testing.Harness;
using Vogen;
using Shouldly;

namespace ValueTypeTests.Bugs;

public class querying_by_value_types : BugIntegrationContext
{
[Fact]
public async Task run_queries()
{
StoreOptions(opts =>
{
opts.RegisterValueType(typeof(EmailAddress));
opts.RegisterValueType(typeof(Age));
});

var customer = new Customer
{
Email = EmailAddress.From("[email protected]"),
Age = Age.From(25)
};

theSession.Store(customer);

await theSession.SaveChangesAsync();

var loadedCustomer = await theSession.LoadAsync<Customer>(customer.Id);

loadedCustomer.Email.ShouldNotBeNull();
loadedCustomer.Email.Value.ShouldBe("[email protected]");


var queryByAge = await theSession.Query<Customer>()
.FirstOrDefaultAsync(x => x.Age == 25);

queryByAge.ShouldNotBeNull();

var queryByEmail = await theSession.Query<Customer>()
.FirstOrDefaultAsync(x => x.Email == customer.Email);

queryByEmail.ShouldNotBeNull();
}
}

[ValueObject<string>]
public partial record EmailAddress;

[ValueObject<int>]
public partial record Age;

public class Customer
{
public Guid Id { get; set; }

public required EmailAddress Email { get; init; }

public required Age Age { get; init; }
}

0 comments on commit aef686f

Please sign in to comment.