Skip to content

Commit

Permalink
feat(templates): Check access token exp claim before using it in Boil…
Browse files Browse the repository at this point in the history
…erplate #9186 (#9211)
  • Loading branch information
ysmoradi authored Nov 12, 2024
1 parent 42040e6 commit bb6639c
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,14 @@ private async void AuthenticationStateChanged(Task<AuthenticationState> task)
try
{
var user = (await task).User;
TelemetryContext.UserId = user.IsAuthenticated() ? user.GetUserId() : null;
TelemetryContext.UserSessionId = user.IsAuthenticated() ? user.GetSessionId() : null;
var isAuthenticated = user.IsAuthenticated();
TelemetryContext.UserId = isAuthenticated ? user.GetUserId() : null;
TelemetryContext.UserSessionId = isAuthenticated ? user.GetSessionId() : null;

var data = TelemetryContext.ToDictionary();

//#if (appInsights == true)
if (user.IsAuthenticated())
if (isAuthenticated)
{
_ = appInsights.SetAuthenticatedUserContext(user.GetUserId().ToString());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
using System.Text;
using Boilerplate.Shared.Dtos.Identity;
using Boilerplate.Shared.Dtos.Identity;
using Boilerplate.Shared.Controllers.Identity;

namespace Boilerplate.Client.Core.Services;

public partial class AuthenticationManager : AuthenticationStateProvider
{
/// <summary>
/// To prevent multiple simultaneous refresh token requests.
/// </summary>
private readonly SemaphoreSlim semaphore = new(1, maxCount: 1);

[AutoInject] private Cookie cookie = default!;
[AutoInject] private IJSRuntime jsRuntime = default!;
[AutoInject] private IStorageService storageService = default!;
Expand All @@ -14,7 +18,6 @@ public partial class AuthenticationManager : AuthenticationStateProvider
[AutoInject] private IPrerenderStateService prerenderStateService;
[AutoInject] private IExceptionHandler exceptionHandler = default!;
[AutoInject] private IIdentityController identityController = default!;
[AutoInject] private JsonSerializerOptions jsonSerializerOptions = default!;

/// <summary>
/// Sign in and return whether the user requires two-factor authentication.
Expand Down Expand Up @@ -81,40 +84,45 @@ public override async Task<AuthenticationState> GetAuthenticationStateAsync()

if (string.IsNullOrEmpty(access_token) && inPrerenderSession is false)
{
string? refresh_token = await storageService.GetItem("refresh_token");

if (string.IsNullOrEmpty(refresh_token) is false)
try
{
// We refresh the access_token to ensure a seamless user experience, preventing unnecessary 'NotAuthorized' page redirects and improving overall UX.
// This method is triggered after 401 and 403 server responses in AuthDelegationHandler,
// as well as when accessing pages without the required permissions in NotAuthorizedPage, ensuring that any recent claims granted to the user are promptly reflected.

try
await semaphore.WaitAsync();
access_token = await tokenProvider.GetAccessToken();
if (string.IsNullOrEmpty(access_token)) // Check again after acquiring the lock.
{
var refreshTokenResponse = await identityController.Refresh(new() { RefreshToken = refresh_token }, CancellationToken.None);
await StoreTokens(refreshTokenResponse!);
access_token = refreshTokenResponse!.AccessToken;
}
catch (UnauthorizedException) // refresh_token is either invalid or expired.
{
await storageService.RemoveItem("refresh_token");
string? refresh_token = await storageService.GetItem("refresh_token");

if (string.IsNullOrEmpty(refresh_token) is false)
{
// We refresh the access_token to ensure a seamless user experience, preventing unnecessary 'NotAuthorized' page redirects and improving overall UX.
// This method is triggered after 401 and 403 server responses in AuthDelegationHandler,
// as well as when accessing pages without the required permissions in NotAuthorizedPage, ensuring that any recent claims granted to the user are promptly reflected.

try
{
var refreshTokenResponse = await identityController.Refresh(new() { RefreshToken = refresh_token }, CancellationToken.None);
await StoreTokens(refreshTokenResponse!);
access_token = refreshTokenResponse!.AccessToken;
}
catch (UnauthorizedException) // refresh_token is either invalid or expired.
{
await storageService.RemoveItem("refresh_token");
}
}
}
}
finally
{
semaphore.Release();
}
}

if (string.IsNullOrEmpty(access_token))
{
return NotSignedIn();
}

var identity = new ClaimsIdentity(claims: ParseTokenClaims(access_token), authenticationType: "Bearer", nameType: "name", roleType: "role");

return new AuthenticationState(new ClaimsPrincipal(identity));
return new AuthenticationState(tokenProvider.ParseAccessToken(access_token, validateExpiry: false /* For better UX in order to minimize Routes.razor's Authorizing loading duration. */));
}
catch (Exception exp)
{
exceptionHandler.Handle(exp); // Do not throw exceptions in GetAuthenticationStateAsync. This will fault CascadingAuthenticationState's state unless NotifyAuthenticationStateChanged is called again.
return NotSignedIn();
return new AuthenticationState(tokenProvider.Anonymous());
}
}

Expand All @@ -141,67 +149,4 @@ await cookie.Set(new()
});
}
}

private static AuthenticationState NotSignedIn()
{
return new AuthenticationState(new ClaimsPrincipal(new ClaimsIdentity()));
}

private IEnumerable<Claim> ParseTokenClaims(string access_token)
{
var parsedClaims = ParseJwt(access_token);

var claims = new List<Claim>();
foreach (var keyValue in parsedClaims)
{
if (keyValue.Value.ValueKind == JsonValueKind.Array)
{
foreach (var element in keyValue.Value.EnumerateArray())
{
claims.Add(new Claim(keyValue.Key, element.ToString() ?? string.Empty));
}
}
else
{
claims.Add(new Claim(keyValue.Key, keyValue.Value.ToString() ?? string.Empty));
}
}

return claims;
}

private Dictionary<string, JsonElement> ParseJwt(string access_token)
{
// Split the token to get the payload
string base64UrlPayload = access_token.Split('.')[1];

// Convert the payload from Base64Url format to Base64
string base64Payload = ConvertBase64UrlToBase64(base64UrlPayload);

// Decode the Base64 string to get a JSON string
string jsonPayload = Encoding.UTF8.GetString(Convert.FromBase64String(base64Payload));

// Deserialize the JSON string to a dictionary
var claims = JsonSerializer.Deserialize(jsonPayload, jsonSerializerOptions.GetTypeInfo<Dictionary<string, JsonElement>>())!;

return claims;
}

private static string ConvertBase64UrlToBase64(string base64Url)
{
base64Url = base64Url.Replace('-', '+').Replace('_', '/');

// Adjust base64Url string length for padding
switch (base64Url.Length % 4)
{
case 2:
base64Url += "==";
break;
case 3:
base64Url += "=";
break;
}

return base64Url;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,92 @@
namespace Boilerplate.Client.Core.Services.Contracts;
using System.Text;

namespace Boilerplate.Client.Core.Services.Contracts;

public interface IAuthTokenProvider
{
Task<string?> GetAccessToken();

public ClaimsPrincipal Anonymous() => new(new ClaimsIdentity());

public ClaimsPrincipal ParseAccessToken(string? access_token, bool validateExpiry)
{
if (string.IsNullOrEmpty(access_token) is true)
return Anonymous();

var claims = ReadClaims(access_token, validateExpiry);

if (claims is null)
return Anonymous();

var identity = new ClaimsIdentity(claims: claims, authenticationType: "Bearer", nameType: "name", roleType: "role");

var claimPrinciple = new ClaimsPrincipal(identity);

return claimPrinciple;
}

private IEnumerable<Claim>? ReadClaims(string access_token, bool validateExpiry)
{
var parsedClaims = DeserializeAccessToken(access_token);

if (validateExpiry && long.TryParse(parsedClaims["exp"].ToString(), out var expSeconds))
{
var expirationDate = DateTimeOffset.FromUnixTimeSeconds(expSeconds);
if (expirationDate <= DateTimeOffset.UtcNow)
return null;
}

var claims = new List<Claim>();
foreach (var keyValue in parsedClaims)
{
if (keyValue.Value.ValueKind == JsonValueKind.Array)
{
foreach (var element in keyValue.Value.EnumerateArray())
{
claims.Add(new Claim(keyValue.Key, element.ToString() ?? string.Empty));
}
}
else
{
claims.Add(new Claim(keyValue.Key, keyValue.Value.ToString() ?? string.Empty));
}
}

return claims;
}

private Dictionary<string, JsonElement> DeserializeAccessToken(string access_token)
{
// Split the token to get the payload
string base64UrlPayload = access_token.Split('.')[1];

// Convert the payload from Base64Url format to Base64
string base64Payload = ConvertBase64UrlToBase64(base64UrlPayload);

// Decode the Base64 string to get a JSON string
string jsonPayload = Encoding.UTF8.GetString(Convert.FromBase64String(base64Payload));

// Deserialize the JSON string to a dictionary
var claims = JsonSerializer.Deserialize(jsonPayload, AppJsonContext.Default.Options.GetTypeInfo<Dictionary<string, JsonElement>>())!;

return claims;
}

private string ConvertBase64UrlToBase64(string base64Url)
{
base64Url = base64Url.Replace('-', '+').Replace('_', '/');

// Adjust base64Url string length for padding
switch (base64Url.Length % 4)
{
case 2:
base64Url += "==";
break;
case 3:
base64Url += "=";
break;
}

return base64Url;
}
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,44 @@
using System.Net.Http.Headers;
using Boilerplate.Shared.Controllers.Identity;

namespace Boilerplate.Client.Core.Services.HttpMessageHandlers;

public partial class AuthDelegatingHandler(IAuthTokenProvider tokenProvider,
IJSRuntime jsRuntime,
IServiceProvider serviceProvider,
IStorageService storageService,
IServiceProvider serviceProvider,
IStorageService storageService,
HttpMessageHandler handler)
: DelegatingHandler(handler)
{
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
if (request.Headers.Authorization is null)
var isRefreshTokenRequest = request.RequestUri?.LocalPath?.Contains(IIdentityController.RefreshUri, StringComparison.InvariantCultureIgnoreCase) is true;

try
{
var access_token = await tokenProvider.GetAccessToken();
if (access_token is not null)
if (request.Headers.Authorization is null && isRefreshTokenRequest is false)
{
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", access_token);
var access_token = await tokenProvider.GetAccessToken();
if (access_token is not null)
{
if (tokenProvider.ParseAccessToken(access_token, validateExpiry: true).IsAuthenticated() is false)
throw new UnauthorizedException();

request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", access_token);
}
}
}

try
{
return await base.SendAsync(request, cancellationToken);
}
catch (KnownException _) when (_ is ForbiddenException or UnauthorizedException)
{
// Let's update the access token by refreshing it when a refresh token is available.
// Following this procedure, the newly acquired access token may now include the necessary roles or claims.

if (AppPlatform.IsBlazorHybrid is false && jsRuntime.IsInitialized() is false)
if (AppPlatform.IsBlazorHybrid is false && jsRuntime.IsInitialized() is false)
throw; // We don't have access to refresh_token during pre-rendering.

if (request.RequestUri?.LocalPath?.Contains("api/Identity/Refresh", StringComparison.InvariantCultureIgnoreCase) is true)
if (isRefreshTokenRequest)
throw; // To prevent refresh token loop

var refresh_token = await storageService.GetItem("refresh_token");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public interface IIdentityController : IAppController
[HttpPost]
Task ResetPassword(ResetPasswordRequestDto request, CancellationToken cancellationToken);

public const string RefreshUri = "api/Identity/Refresh";
[HttpPost]
Task<TokenResponseDto> Refresh(RefreshRequestDto request, CancellationToken cancellationToken) => default!;

Expand Down

0 comments on commit bb6639c

Please sign in to comment.