Skip to content

Commit

Permalink
Support funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
pakrym committed Dec 29, 2023
1 parent c215b17 commit 694aa23
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using Xunit;
using Jab;

Expand Down
68 changes: 68 additions & 0 deletions src/Jab.FunctionalTests.Common/ContainerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,74 @@ public void CanGetMultipleOpenGenericScoped()
partial class CanGetMultipleOpenGenericScopedContainer
{
}

[Fact]
public void SupportsImplicitFunc()
{
SupportsImplicitFuncFactoryContainer c = new();
var transientFunc = c.GetService<Func<IService>>();
var transientFunc2 = c.GetService<Func<IService>>();
var transientService1 = transientFunc();
var transientService2 = transientFunc();

var scope1 = c.CreateScope();
var scopedFunc = scope1.GetService<Func<IService1>>();
var scopedFunc2 = scope1.GetService<Func<IService1>>();
var scopedService1 = scopedFunc();
var scopedService2 = scopedFunc();

var scope2 = c.CreateScope();
var scopedFunc3 = scope2.GetService<Func<IService1>>();
var scopedService3 = scopedFunc3();

var singletonFunc = c.GetService<Func<IService2>>();
var singletonFunc2 = c.GetService<Func<IService2>>();

var singletonService1 = singletonFunc();
var singletonService2 = singletonFunc2();

Assert.Equal(2, c.TransientCount);
Assert.Equal(2, c.ScopedCount);
Assert.Equal(1, c.SingletonCount);

Assert.Same(singletonFunc, singletonFunc2);
Assert.Same(transientFunc, transientFunc2);
Assert.Same(scopedFunc, scopedFunc2);
Assert.NotSame(scopedFunc2, scopedFunc3);

Assert.Same(singletonService1, singletonService2);
Assert.Same(scopedService1, scopedService2);
Assert.NotSame(scopedService1, scopedService3);

Assert.NotSame(transientService1, transientService2);
}

[ServiceProvider(RootServices = new [] { typeof(Func<IService1>) })]
[Transient(typeof(IService), Factory=nameof(TransientFactory))]
[Scoped(typeof(IService1), Factory=nameof(ScopedFactory))]
[Singleton(typeof(IService2), Factory=nameof(SingletonFactory))]
internal partial class SupportsImplicitFuncFactoryContainer
{
internal int TransientCount = 0;
internal int ScopedCount = 0;
internal int SingletonCount = 0;

internal ServiceImplementation TransientFactory()
{
TransientCount++;
return new();
}
internal ServiceImplementation ScopedFactory()
{
ScopedCount++;
return new();
}
internal ServiceImplementation SingletonFactory()
{
SingletonCount++;
return new();
}
}

#region Non-generic member factory with parameters
[Fact]
Expand Down
2 changes: 1 addition & 1 deletion src/Jab/ConstructorCallSite.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

internal record ConstructorCallSite : ServiceCallSite
{
public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair<IParameterSymbol, ServiceCallSite>[] optionalParameters, ServiceLifetime lifetime, int? reverseIndex, bool? isDisposable)
public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair<IParameterSymbol, ServiceCallSite>[] optionalParameters, ServiceLifetime lifetime, bool? isDisposable)
: base(identity, implementationType, lifetime, isDisposable)
{
Parameters = parameters;
Expand Down
31 changes: 20 additions & 11 deletions src/Jab/ContainerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@ private void GenerateCallSiteWithCache(CodeWriter codeWriter, string rootReferen
if (serviceCallSite.Lifetime != ServiceLifetime.Transient)
{
var cacheLocation = GetCacheLocation(serviceCallSite.Identity);
codeWriter.Line($"if ({cacheLocation} == null)");
codeWriter.Line($"lock (this)");
using (codeWriter.Scope($"if ({cacheLocation} == null)"))
var locking = serviceCallSite is not FuncCallSite;
if (locking)
{
GenerateCallSite(
codeWriter,
rootReference,
serviceCallSite,
(w, v) =>
{
w.Line($"{cacheLocation} = {v};");
});
codeWriter.Line($"if ({cacheLocation} == null)");
codeWriter.Line($"lock (this)");
}
GenerateCallSite(
codeWriter,
rootReference,
serviceCallSite,
(w, v) =>
{
w.Line($"{cacheLocation} ??= {v};");
});

if (serviceCallSite.ImplementationType.IsValueType)
{
Expand Down Expand Up @@ -146,6 +147,14 @@ private void GenerateCallSite(CodeWriter codeWriter, string rootReference, Servi
w.Append($")");
});
break;

case FuncCallSite funcCallSite:
valueCallback(codeWriter, w =>
{
w.Append($"() => ");
WriteResolutionCall(codeWriter, funcCallSite.Inner.Identity, "this");
});
break;
case MemberCallSite memberCallSite:
valueCallback(codeWriter, w =>
{
Expand Down
18 changes: 18 additions & 0 deletions src/Jab/FuncCallSite.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace Jab;

internal record FuncCallSite : ServiceCallSite
{
public FuncCallSite(ServiceIdentity identity, ServiceCallSite inner)
: base(identity, identity.Type, GetFuncLifetime(inner.Lifetime), false)
{
Inner = inner;
}

public ServiceCallSite Inner { get; }

private static ServiceLifetime GetFuncLifetime(ServiceLifetime innerLifetime) => innerLifetime switch
{
ServiceLifetime.Scoped => ServiceLifetime.Scoped,
_ => ServiceLifetime.Singleton
};
}
3 changes: 3 additions & 0 deletions src/Jab/KnownTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ internal class KnownTypes

private const string IAsyncDisposableMetadataName = "System.IAsyncDisposable";
private const string IEnumerableMetadataName = "System.Collections.Generic.IEnumerable`1";
private const string FuncMetadataName = "System.Func`1";
private const string IServiceProviderMetadataName = "System.IServiceProvider";
private const string IServiceScopeMetadataName = "Microsoft.Extensions.DependencyInjection.IServiceScope";
private const string IKeyedServiceProviderMetadataName = "Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider";
Expand All @@ -59,6 +60,7 @@ internal class KnownTypes
"Microsoft.Extensions.DependencyInjection.IServiceProviderIsService";

public INamedTypeSymbol IEnumerableType { get; }
public INamedTypeSymbol FuncType { get; }
public INamedTypeSymbol IServiceProviderType { get; }
public INamedTypeSymbol CompositionRootAttributeType { get; }
public INamedTypeSymbol TransientAttributeType { get; }
Expand Down Expand Up @@ -102,6 +104,7 @@ static INamedTypeSymbol GetTypeFromCompilationByMetadataNameOrThrow(Compilation
?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found");

IEnumerableType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IEnumerableMetadataName);
FuncType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, FuncMetadataName);
IServiceProviderType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IServiceProviderMetadataName);
IServiceScopeType = compilation.GetTypeByMetadataName(IServiceScopeMetadataName);
IAsyncDisposableType = compilation.GetTypeByMetadataName(IAsyncDisposableMetadataName);
Expand Down
33 changes: 32 additions & 1 deletion src/Jab/ServiceProviderBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol)
return TryCreateSpecial(serviceType, name, context) ??
TryCreateExact(serviceType, name, null, context) ??
TryCreateEnumerable(serviceType, name, context) ??
TryCreateFunc(serviceType, name, context) ??
TryCreateGeneric(serviceType, name, context);
}
finally
Expand Down Expand Up @@ -432,6 +433,37 @@ static ServiceLifetime GetCommonLifetime(IEnumerable<ServiceCallSite> callSites)
return null;
}


private ServiceCallSite? TryCreateFunc(ITypeSymbol serviceType, string? name, ServiceResolutionContext context)
{
if (serviceType is INamedTypeSymbol { IsGenericType: true } genericType &&
SymbolEqualityComparer.Default.Equals(genericType.ConstructedFrom, _knownTypes.FuncType))
{
var identity = new ServiceIdentity(genericType, name, null);

if (context.CallSiteCache.TryGet(identity, out var callSite))
{
return callSite;
}

var innerType = genericType.TypeArguments[0];
var inner = GetCallSite(innerType, name, context);

if (inner == null)
{
return null;
}

callSite = new FuncCallSite(identity, inner);

context.CallSiteCache.Add(callSite);

return callSite;
}

return null;
}

private ServiceCallSite? TryCreateExact(
ITypeSymbol serviceType,
string? name,
Expand Down Expand Up @@ -612,7 +644,6 @@ private ServiceCallSite CreateConstructorCallSite(
parameters.ToArray(),
namedParameters.ToArray(),
registration.Lifetime,
identity.ReverseIndex,
// TODO: this can be optimized to avoid check for all the types
isDisposable: null
);
Expand Down

0 comments on commit 694aa23

Please sign in to comment.