diff --git a/Directory.Build.targets b/Directory.Build.targets index a4ce6b9..ab2cab4 100644 --- a/Directory.Build.targets +++ b/Directory.Build.targets @@ -3,7 +3,7 @@ 8.0.1 8.0.0 7.1.2 - 7.0.6 + 7.0.8 diff --git a/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticateResultCache.cs b/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticateResultCache.cs new file mode 100644 index 0000000..d2149db --- /dev/null +++ b/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticateResultCache.cs @@ -0,0 +1,14 @@ +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. + +using Microsoft.AspNetCore.Authentication; +using System.Collections.Generic; + +/// +/// Per-request cache so that if SignInAsync is used, we won't re-read the old/cached AuthenticateResult from the handler. +/// This requires this service to be added as scoped to the DI system. +/// Be VERY CAREFUL to not accidentally capture this service for longer than the appropriate DI scope - e.g., in an HttpClient. +/// +internal class AuthenticateResultCache: Dictionary +{ +} \ No newline at end of file diff --git a/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticationSessionUserTokenStore.cs b/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticationSessionUserTokenStore.cs index 382b289..dc4d62b 100755 --- a/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticationSessionUserTokenStore.cs +++ b/src/Duende.AccessTokenManagement.OpenIdConnect/AuthenticationSessionUserTokenStore.cs @@ -4,10 +4,11 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Http; using System; -using System.Collections.Generic; using System.Security.Claims; using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.DependencyInjection; +using IdentityModel; namespace Duende.AccessTokenManagement.OpenIdConnect { @@ -20,10 +21,6 @@ public class AuthenticationSessionUserAccessTokenStore : IUserTokenStore private readonly IStoreTokensInAuthenticationProperties _tokensInProps; private readonly ILogger _logger; - // per-request cache so that if SignInAsync is used, we won't re-read the old/cached AuthenticateResult from the handler - // this requires this service to be added as scoped to the DI system - private readonly Dictionary _cache = new Dictionary(); - /// /// ctor /// @@ -46,10 +43,14 @@ public async Task GetTokenAsync( UserTokenRequestParameters? parameters = null) { parameters ??= new(); - + // Resolve the cache here because it needs to have a per-request + // lifetime. Sometimes the store itself is captured for longer than + // that inside an HttpClient. + var cache = _contextAccessor.HttpContext?.RequestServices.GetRequiredService(); + // check the cache in case the cookie was re-issued via StoreTokenAsync // we use String.Empty as the key for a null SignInScheme - if (!_cache.TryGetValue(parameters.SignInScheme ?? String.Empty, out var result)) + if (!cache!.TryGetValue(parameters.SignInScheme ?? String.Empty, out var result)) { result = await _contextAccessor!.HttpContext!.AuthenticateAsync(parameters.SignInScheme).ConfigureAwait(false); } @@ -80,9 +81,14 @@ public async Task StoreTokenAsync( { parameters ??= new(); + // Resolve the cache here because it needs to have a per-request + // lifetime. Sometimes the store itself is captured for longer than + // that inside an HttpClient. + var cache = _contextAccessor.HttpContext?.RequestServices.GetRequiredService(); + // check the cache in case the cookie was re-issued via StoreTokenAsync // we use String.Empty as the key for a null SignInScheme - if (!_cache.TryGetValue(parameters.SignInScheme ?? String.Empty, out var result)) + if (!cache!.TryGetValue(parameters.SignInScheme ?? String.Empty, out var result)) { result = await _contextAccessor.HttpContext!.AuthenticateAsync(parameters.SignInScheme)!.ConfigureAwait(false); } @@ -103,7 +109,7 @@ public async Task StoreTokenAsync( // add to the cache so if GetTokenAsync is called again, we will use the updated property values // we use String.Empty as the key for a null SignInScheme - _cache[parameters.SignInScheme ?? String.Empty] = AuthenticateResult.Success(new AuthenticationTicket(transformedPrincipal, result.Properties, scheme!)); + cache[parameters.SignInScheme ?? String.Empty] = AuthenticateResult.Success(new AuthenticationTicket(transformedPrincipal, result.Properties, scheme!)); } /// @@ -125,4 +131,4 @@ protected virtual Task FilterPrincipalAsync(ClaimsPrincipal pri return Task.FromResult(principal); } } -} \ No newline at end of file +} diff --git a/src/Duende.AccessTokenManagement.OpenIdConnect/OpenIdConnectTokenManagementServiceCollectionExtensions.cs b/src/Duende.AccessTokenManagement.OpenIdConnect/OpenIdConnectTokenManagementServiceCollectionExtensions.cs index b8226e0..e344f96 100644 --- a/src/Duende.AccessTokenManagement.OpenIdConnect/OpenIdConnectTokenManagementServiceCollectionExtensions.cs +++ b/src/Duende.AccessTokenManagement.OpenIdConnect/OpenIdConnectTokenManagementServiceCollectionExtensions.cs @@ -2,9 +2,11 @@ // Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. using System; +using System.Collections.Generic; using System.Net.Http; using Duende.AccessTokenManagement; using Duende.AccessTokenManagement.OpenIdConnect; +using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging; @@ -46,8 +48,10 @@ public static IServiceCollection AddOpenIdConnectAccessTokenManagement(this ISer // context, and we register different ones in blazor services.TryAddScoped(); - // scoped since it will be caching per-request authentication results services.TryAddScoped(); + + // scoped since it will be caching per-request authentication results + services.AddScoped(); return services; } diff --git a/test/Tests/Framework/ApiHost.cs b/test/Tests/Framework/ApiHost.cs index ea12cd8..322f76b 100644 --- a/test/Tests/Framework/ApiHost.cs +++ b/test/Tests/Framework/ApiHost.cs @@ -3,6 +3,7 @@ using Duende.IdentityServer.Models; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; namespace Duende.AccessTokenManagement.Tests; @@ -37,6 +38,7 @@ private void ConfigureServices(IServiceCollection services) options.Audience = _identityServerHost.Url("/resources"); options.MapInboundClaims = false; options.BackchannelHttpHandler = _identityServerHost.Server.CreateHandler(); + options.TokenValidationParameters.NameClaimType = "sub"; }); } @@ -62,43 +64,12 @@ private void Configure(IApplicationBuilder app) app.UseEndpoints(endpoints => { - // endpoints.Map("/{**catch-all}", async context => - // { - // // capture body if present - // var body = default(string); - // if (context.Request.HasJsonContentType()) - // { - // using (var sr = new StreamReader(context.Request.Body)) - // { - // body = await sr.ReadToEndAsync(); - // } - // } - // - // // capture request headers - // var requestHeaders = new Dictionary>(); - // foreach (var header in context.Request.Headers) - // { - // var values = new List(header.Value.Select(v => v)); - // requestHeaders.Add(header.Key, values); - // } - // - // var response = new ApiResponse( - // context.Request.Method, - // context.Request.Path.Value, - // context.User.FindFirst(("sub"))?.Value, - // context.User.FindFirst(("client_id"))?.Value, - // context.User.Claims.Select(x => new ClaimRecord(x.Type, x.Value)).ToArray()) - // { - // Body = body, - // RequestHeaders = requestHeaders - // }; - // - // context.Response.StatusCode = ApiStatusCodeToReturn ?? 200; - // ApiStatusCodeToReturn = null; - // - // context.Response.ContentType = "application/json"; - // await context.Response.WriteAsync(JsonSerializer.Serialize(response)); - // }); + endpoints.Map("/{**catch-all}", (HttpContext context) => + { + return new TokenEchoResponse( + context.User.Identity?.Name ?? "missing sub", + context.Request.Headers.Authorization.First() ?? "missing token"); + }); }); } } \ No newline at end of file diff --git a/test/Tests/Framework/AppHost.cs b/test/Tests/Framework/AppHost.cs index 2923294..cca5837 100644 --- a/test/Tests/Framework/AppHost.cs +++ b/test/Tests/Framework/AppHost.cs @@ -10,6 +10,7 @@ using IdentityModel; using Duende.AccessTokenManagement.OpenIdConnect; using RichardSzalay.MockHttp; +using System.Net.Http.Json; namespace Duende.AccessTokenManagement.Tests; @@ -17,7 +18,7 @@ public class AppHost : GenericHost { private readonly IdentityServerHost _identityServerHost; private readonly ApiHost _apiHost; - private readonly string _clientId; + public string ClientId; private readonly Action? _configureUserTokenManagementOptions; public AppHost( @@ -30,7 +31,7 @@ public AppHost( { _identityServerHost = identityServerHost; _apiHost = apiHost; - _clientId = clientId; + ClientId = clientId; _configureUserTokenManagementOptions = configureUserTokenManagementOptions; OnConfigureServices += ConfigureServices; OnConfigure += Configure; @@ -58,7 +59,7 @@ private void ConfigureServices(IServiceCollection services) { options.Authority = _identityServerHost.Url(); - options.ClientId = _clientId; + options.ClientId = ClientId; options.ClientSecret = "secret"; options.ResponseType = "code"; options.ResponseMode = "query"; @@ -68,7 +69,7 @@ private void ConfigureServices(IServiceCollection services) options.SaveTokens = true; options.Scope.Clear(); - var client = _identityServerHost.Clients.Single(x => x.ClientId == _clientId); + var client = _identityServerHost.Clients.Single(x => x.ClientId == ClientId); foreach (var scope in client.AllowedScopes) { options.Scope.Add(scope); @@ -107,6 +108,10 @@ private void ConfigureServices(IServiceCollection services) } }); + services.AddUserAccessTokenHttpClient("callApi", configureClient: client => { + client.BaseAddress = new Uri(_apiHost.Url()); + }) + .ConfigurePrimaryHttpMessageHandler(() => _apiHost.HttpMessageHandler); } private void Configure(IApplicationBuilder app) @@ -136,6 +141,13 @@ await context.ChallengeAsync(new AuthenticationProperties await context.Response.WriteAsJsonAsync(token); }); + endpoints.MapGet("/call_api", async (IHttpClientFactory factory, HttpContext context) => + { + var http = factory.CreateClient("callApi"); + var response = await http.GetAsync("test"); + return await response.Content.ReadFromJsonAsync(); + }); + endpoints.MapGet("/user_token_with_resource/{resource}", async (string resource, HttpContext context) => { var token = await context.GetUserAccessTokenAsync(new UserTokenRequestParameters @@ -204,4 +216,6 @@ public async Task LogoutAsync(string? sid = null) response = await BrowserClient.GetAsync(Url(response.Headers.Location.ToString())); return response; } -} \ No newline at end of file +} + +public record TokenEchoResponse(string sub, string token); \ No newline at end of file diff --git a/test/Tests/Framework/GenericHost.cs b/test/Tests/Framework/GenericHost.cs index 0a7264a..5e558ad 100644 --- a/test/Tests/Framework/GenericHost.cs +++ b/test/Tests/Framework/GenericHost.cs @@ -32,10 +32,9 @@ public GenericHost(string baseAddress = "https://server") public TestServer Server { get; private set; } = default!; public TestBrowserClient BrowserClient { get; set; } = default!; public HttpClient HttpClient { get; set; } = default!; - + public HttpMessageHandler HttpMessageHandler { get; set; } = default!; public TestLoggerProvider Logger { get; set; } = new TestLoggerProvider(); - public T Resolve() where T : notnull { @@ -84,6 +83,7 @@ public async Task InitializeAsync() Server = host.GetTestServer(); BrowserClient = new TestBrowserClient(Server.CreateHandler()); HttpClient = Server.CreateClient(); + HttpMessageHandler = Server.CreateHandler(); } public event Action OnConfigureServices = services => { }; diff --git a/test/Tests/UserTokenManagementTests.cs b/test/Tests/UserTokenManagementTests.cs index 42751e4..6d0ced5 100644 --- a/test/Tests/UserTokenManagementTests.cs +++ b/test/Tests/UserTokenManagementTests.cs @@ -183,10 +183,18 @@ public async Task Missing_initial_refresh_token_and_expired_access_token_should_ [Fact] public async Task Short_token_lifetime_should_trigger_refresh() { + // This test makes an initial token request using code flow and then + // refreshes the token a couple of times. + + // We mock the expiration of the first few token responses to be short + // enough that we will automatically refresh immediately when attempting + // to use the tokens, while the final response gets a long refresh time, + // allowing us to verify that the token is not refreshed. + var mockHttp = new MockHttpMessageHandler(); AppHost.IdentityServerHttpHandler = mockHttp; - // short token lifetime should trigger refresh on 1st use + // Respond to code flow with a short token lifetime so that we trigger refresh on 1st use var initialTokenResponse = new { id_token = IdentityServerHost.CreateIdToken("1", "web"), @@ -195,13 +203,11 @@ public async Task Short_token_lifetime_should_trigger_refresh() expires_in = 10, refresh_token = "initial_refresh_token", }; - - // response for re-deeming code mockHttp.When("/connect/token") .WithFormData("grant_type", "authorization_code") .Respond("application/json", JsonSerializer.Serialize(initialTokenResponse)); - // short token lifetime should trigger refresh on 1st use + // Respond to refresh with a short token lifetime so that we trigger another refresh on 2nd use var refreshTokenResponse = new { access_token = "refreshed1_access_token", @@ -209,14 +215,12 @@ public async Task Short_token_lifetime_should_trigger_refresh() expires_in = 10, refresh_token = "refreshed1_refresh_token", }; - - // response for refresh 1 mockHttp.When("/connect/token") .WithFormData("grant_type", "refresh_token") .WithFormData("refresh_token", "initial_refresh_token") .Respond("application/json", JsonSerializer.Serialize(refreshTokenResponse)); - // short token lifetime should trigger refresh on 2nd use + // Respond to second refresh with a long token lifetime so that we don't trigger another refresh on 3rd use var refreshTokenResponse2 = new { access_token = "refreshed2_access_token", @@ -224,8 +228,6 @@ public async Task Short_token_lifetime_should_trigger_refresh() expires_in = 3600, refresh_token = "refreshed2_refresh_token", }; - - // response for refresh 1 mockHttp.When("/connect/token") .WithFormData("grant_type", "refresh_token") .WithFormData("refresh_token", "refreshed1_refresh_token") @@ -397,4 +399,29 @@ public async Task Refresh_responses_without_refresh_token_use_old_refresh_token( token.IsError.ShouldBeFalse(); token.RefreshToken.ShouldBe("initial_refresh_token"); } + + [Fact] + public async Task Multiple_users_have_distinct_tokens_across_refreshes() + { + // setup host + AppHost.ClientId = "web.short"; + await AppHost.InitializeAsync(); + await AppHost.LoginAsync("alice"); + + var firstResponse = await AppHost.BrowserClient.GetAsync(AppHost.Url("/call_api")); + var firstToken = await firstResponse.Content.ReadFromJsonAsync(); + var secondResponse = await AppHost.BrowserClient.GetAsync(AppHost.Url("/call_api")); + var secondToken = await secondResponse.Content.ReadFromJsonAsync(); + firstToken.ShouldNotBeNull(); + secondToken.ShouldNotBeNull(); + secondToken.sub.ShouldBe(firstToken.sub); + secondToken.token.ShouldNotBe(firstToken.token); + + await AppHost.LoginAsync("bob"); + var thirdResponse = await AppHost.BrowserClient.GetAsync(AppHost.Url("/call_api")); + var thirdToken = await thirdResponse.Content.ReadFromJsonAsync(); + thirdToken.ShouldNotBeNull(); + thirdToken.sub.ShouldNotBe(secondToken.sub); + thirdToken.token.ShouldNotBe(firstToken.token); + } } \ No newline at end of file