diff --git a/samples/ASBTaskQueue/Greetings/Adaptors/Data/AzureAdAuthenticationDbConnectionInterceptor.cs b/samples/ASBTaskQueue/Greetings/Adaptors/Data/AzureAdAuthenticationDbConnectionInterceptor.cs new file mode 100644 index 0000000000..500ea85ece --- /dev/null +++ b/samples/ASBTaskQueue/Greetings/Adaptors/Data/AzureAdAuthenticationDbConnectionInterceptor.cs @@ -0,0 +1,106 @@ +using System; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Identity; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace Greetings.Adaptors.Data +{ + public class AzureAdAuthenticationDbConnectionInterceptor : DbConnectionInterceptor + { + // See https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/services-support-managed-identities#azure-sql + private static readonly string[] _azureSqlScopes = new[] { "https://database.windows.net//.default" }; + private const int _cacheLifeTime = 5; + + private static readonly TokenCredential _credential = new ChainedTokenCredential( + new ManagedIdentityCredential(), + new VisualStudioCredential()); + + private static AccessToken _token; + private static SemaphoreSlim _semaphoreToken = new SemaphoreSlim(1, 1); + + public override InterceptionResult ConnectionOpening( + DbConnection connection, + ConnectionEventData eventData, + InterceptionResult result) + { + var sqlConnection = (SqlConnection)connection; + if (DoesConnectionNeedAccessToken(sqlConnection)) + sqlConnection.AccessToken = GetAccessToken(); + + return base.ConnectionOpening(connection, eventData, result); + } + + public override async ValueTask ConnectionOpeningAsync( + DbConnection connection, + ConnectionEventData eventData, + InterceptionResult result, + CancellationToken cancellationToken = default) + { + var sqlConnection = (SqlConnection)connection; + if (DoesConnectionNeedAccessToken(sqlConnection)) + sqlConnection.AccessToken = await GetAccessTokenAsync(); + + return await base.ConnectionOpeningAsync(connection, eventData, result, cancellationToken); + } + + private static bool DoesConnectionNeedAccessToken(SqlConnection connection) + { + // + // Only try to get a token from AAD if + // - We connect to an Azure SQL instance; and + // - The connection doesn't specify a username. + // + var connectionStringBuilder = new SqlConnectionStringBuilder(connection.ConnectionString); + + return connectionStringBuilder.DataSource.Contains("database.windows.net", StringComparison.OrdinalIgnoreCase) && string.IsNullOrEmpty(connectionStringBuilder.UserID); + } + + private string GetAccessToken() + { + _semaphoreToken.Wait(); + try + { + //If the Token has more than 5 minutes Validity + if (DateTime.UtcNow.AddMinutes(_cacheLifeTime) <= _token.ExpiresOn.UtcDateTime) + return _token.Token; + + var tokenRequestContext = new TokenRequestContext(_azureSqlScopes); + var token = _credential.GetToken(tokenRequestContext, CancellationToken.None); + + _token = token; + + return token.Token; + } + finally + { + _semaphoreToken.Release(); + } + } + + private async Task GetAccessTokenAsync() + { + await _semaphoreToken.WaitAsync(); + try + { + //If the Token has more than 5 minutes Validity + if (DateTime.UtcNow.AddMinutes(_cacheLifeTime) <= _token.ExpiresOn.UtcDateTime) + return _token.Token; + + var tokenRequestContext = new TokenRequestContext(_azureSqlScopes); + var token = await _credential.GetTokenAsync(tokenRequestContext, CancellationToken.None); + + _token = token; + + return token.Token; + } + finally + { + _semaphoreToken.Release(); + } + } + } +} diff --git a/samples/ASBTaskQueue/Greetings/Greetings.csproj b/samples/ASBTaskQueue/Greetings/Greetings.csproj index 5090f30c04..64a2e29418 100644 --- a/samples/ASBTaskQueue/Greetings/Greetings.csproj +++ b/samples/ASBTaskQueue/Greetings/Greetings.csproj @@ -13,6 +13,7 @@ + diff --git a/samples/ASBTaskQueue/GreetingsSender.Web/Controllers/HomeController.cs b/samples/ASBTaskQueue/GreetingsSender.Web/Controllers/HomeController.cs index 718b2144f2..57d81247e0 100644 --- a/samples/ASBTaskQueue/GreetingsSender.Web/Controllers/HomeController.cs +++ b/samples/ASBTaskQueue/GreetingsSender.Web/Controllers/HomeController.cs @@ -54,6 +54,8 @@ public async Task DepositMessage() var greetingAsync = new GreetingAsyncEvent("Deposit Hello from the web"); var greeting = new GreetingEvent("Deposit Hello from the web"); + await _commandProcessor.DepositPostAsync(greetingAsync); + _context.Greetings.Add(greeting); _context.GreetingsAsync.Add(greetingAsync); await _context.SaveChangesAsync(); diff --git a/samples/ASBTaskQueue/GreetingsSender.Web/Program.cs b/samples/ASBTaskQueue/GreetingsSender.Web/Program.cs index aa197cc307..13867aa2c2 100644 --- a/samples/ASBTaskQueue/GreetingsSender.Web/Program.cs +++ b/samples/ASBTaskQueue/GreetingsSender.Web/Program.cs @@ -24,6 +24,7 @@ builder.Services.AddDbContext(o => { o.UseSqlServer(dbConnString); + //o.AddInterceptors(new AzureAdAuthenticationDbConnectionInterceptor()); }); //Services @@ -67,7 +68,7 @@ var services = serviceScope.ServiceProvider; var dbContext = services.GetService(); - dbContext.Database.EnsureCreated(); + //dbContext.Database.EnsureCreated(); } else { diff --git a/src/Paramore.Brighter.MsSql.EntityFrameworkCore/MsSqlEntityFrameworkCoreConnectionProvider.cs b/src/Paramore.Brighter.MsSql.EntityFrameworkCore/MsSqlEntityFrameworkCoreConnectionProvider.cs index b2f286884a..e4bfba844f 100644 --- a/src/Paramore.Brighter.MsSql.EntityFrameworkCore/MsSqlEntityFrameworkCoreConnectionProvider.cs +++ b/src/Paramore.Brighter.MsSql.EntityFrameworkCore/MsSqlEntityFrameworkCoreConnectionProvider.cs @@ -1,4 +1,5 @@ -using System.Threading; +using System.Data; +using System.Threading; using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.EntityFrameworkCore; @@ -20,14 +21,16 @@ public MsSqlEntityFrameworkCoreConnectionProvider(T context) public SqlConnection GetConnection() { + //This line ensure that the connection has been initialised and that any required interceptors have been run before getting the connection + _context.Database.CanConnect(); return (SqlConnection)_context.Database.GetDbConnection(); } - public Task GetConnectionAsync(CancellationToken cancellationToken = default(CancellationToken)) + public async Task GetConnectionAsync(CancellationToken cancellationToken = default(CancellationToken)) { - var tcs = new TaskCompletionSource(); - tcs.SetResult(GetConnection()); - return tcs.Task; + //This line ensure that the connection has been initialised and that any required interceptors have been run before getting the connection + await _context.Database.CanConnectAsync(cancellationToken); + return (SqlConnection)_context.Database.GetDbConnection(); } public SqlTransaction GetTransaction() diff --git a/src/Paramore.Brighter.MySql.EntityFrameworkCore/MySqlEntityFrameworkConnectionProvider.cs b/src/Paramore.Brighter.MySql.EntityFrameworkCore/MySqlEntityFrameworkConnectionProvider.cs index 9bdb1d8435..bb02ef4f33 100644 --- a/src/Paramore.Brighter.MySql.EntityFrameworkCore/MySqlEntityFrameworkConnectionProvider.cs +++ b/src/Paramore.Brighter.MySql.EntityFrameworkCore/MySqlEntityFrameworkConnectionProvider.cs @@ -29,6 +29,8 @@ public MySqlEntityFrameworkConnectionProvider(T context) /// The Sqlite Connection that is in use public MySqlConnection GetConnection() { + //This line ensure that the connection has been initialised and that any required interceptors have been run before getting the connection + _context.Database.CanConnect(); return (MySqlConnection) _context.Database.GetDbConnection(); } @@ -37,11 +39,11 @@ public MySqlConnection GetConnection() /// /// A cancellation token /// - public Task GetConnectionAsync(CancellationToken cancellationToken = default(CancellationToken)) + public async Task GetConnectionAsync(CancellationToken cancellationToken = default(CancellationToken)) { - var tcs = new TaskCompletionSource(); - tcs.SetResult((MySqlConnection)_context.Database.GetDbConnection()); - return tcs.Task; + //This line ensure that the connection has been initialised and that any required interceptors have been run before getting the connection + await _context.Database.CanConnectAsync(cancellationToken); + return (MySqlConnection)_context.Database.GetDbConnection(); } /// diff --git a/src/Paramore.Brighter.Sqlite.EntityFrameworkCore/SqliteEntityFrameworkConnectionProvider.cs b/src/Paramore.Brighter.Sqlite.EntityFrameworkCore/SqliteEntityFrameworkConnectionProvider.cs index 96cb129215..72c9068ef5 100644 --- a/src/Paramore.Brighter.Sqlite.EntityFrameworkCore/SqliteEntityFrameworkConnectionProvider.cs +++ b/src/Paramore.Brighter.Sqlite.EntityFrameworkCore/SqliteEntityFrameworkConnectionProvider.cs @@ -29,6 +29,8 @@ public SqliteEntityFrameworkConnectionProvider(T context) /// The Sqlite Connection that is in use public SqliteConnection GetConnection() { + //This line ensure that the connection has been initialised and that any required interceptors have been run before getting the connection + _context.Database.CanConnect(); return (SqliteConnection) _context.Database.GetDbConnection(); } @@ -37,11 +39,11 @@ public SqliteConnection GetConnection() /// /// A cancellation token /// - public Task GetConnectionAsync(CancellationToken cancellationToken = default(CancellationToken)) + public async Task GetConnectionAsync(CancellationToken cancellationToken = default(CancellationToken)) { - var tcs = new TaskCompletionSource(); - tcs.SetResult((SqliteConnection)_context.Database.GetDbConnection()); - return tcs.Task; + //This line ensure that the connection has been initialised and that any required interceptors have been run before getting the connection + await _context.Database.CanConnectAsync(cancellationToken); + return (SqliteConnection)_context.Database.GetDbConnection(); } /// diff --git a/src/Paramore.Brighter/CommandProcessor.cs b/src/Paramore.Brighter/CommandProcessor.cs index 8b12cc0914..f694934bef 100644 --- a/src/Paramore.Brighter/CommandProcessor.cs +++ b/src/Paramore.Brighter/CommandProcessor.cs @@ -503,7 +503,7 @@ private async Task DepositPostAsync(T request, IAmABoxTransactionConnec var message = messageMapper.MapToMessage(request); - await _bus.AddToOutboxAsync(request, continueOnCapturedContext, cancellationToken, message, _boxTransactionConnectionProvider); + await _bus.AddToOutboxAsync(request, continueOnCapturedContext, cancellationToken, message, connectionProvider); return message.Id; }