From e609d3c2f128a521229626830372381be8f9a6dc Mon Sep 17 00:00:00 2001 From: Lewis Zou <zlzforever@163.com> Date: Wed, 14 Aug 2024 14:45:31 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=AF=94=E8=BE=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Identity/IdentitySeedData.cs | 2 +- .../IdentityServerExtensions.cs | 4 +- .../Stores/PhoneCodeStore.cs | 57 ++++++++----------- .../WebApplicationBuilderExtensions.cs | 8 +-- 4 files changed, 32 insertions(+), 39 deletions(-) diff --git a/src/SecurityTokenService/Identity/IdentitySeedData.cs b/src/SecurityTokenService/Identity/IdentitySeedData.cs index 6284666..1291ee9 100644 --- a/src/SecurityTokenService/Identity/IdentitySeedData.cs +++ b/src/SecurityTokenService/Identity/IdentitySeedData.cs @@ -23,7 +23,7 @@ public static void Load(IApplicationBuilder app) var configuration = scope.ServiceProvider.GetRequiredService<IConfiguration>(); DbContext securityTokenServiceDbContext; - if (configuration.GetDatabaseType() == "MySql") + if ("mysql".Equals(configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { securityTokenServiceDbContext = scope.ServiceProvider.GetRequiredService<MySqlSecurityTokenServiceDbContext>(); diff --git a/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs b/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs index dfde757..fd90232 100644 --- a/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs +++ b/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs @@ -177,7 +177,7 @@ public static void MigrateIdentityServer(this IApplicationBuilder app) { var configuration = app.ApplicationServices.GetRequiredService<IConfiguration>(); using var scope = app.ApplicationServices.CreateScope(); - if (configuration.GetDatabaseType() == "MySql") + if ("mysql".Equals(configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { using var persistedGrantDbContext = scope.ServiceProvider.GetRequiredService<MySqlPersistedGrantDbContext>(); @@ -190,4 +190,4 @@ public static void MigrateIdentityServer(this IApplicationBuilder app) persistedGrantDbContext.Database.Migrate(); } } -} \ No newline at end of file +} diff --git a/src/SecurityTokenService/Stores/PhoneCodeStore.cs b/src/SecurityTokenService/Stores/PhoneCodeStore.cs index ba532b5..ae678c4 100644 --- a/src/SecurityTokenService/Stores/PhoneCodeStore.cs +++ b/src/SecurityTokenService/Stores/PhoneCodeStore.cs @@ -1,3 +1,4 @@ +using System; using System.Data.Common; using System.Threading.Tasks; using Dapper; @@ -43,11 +44,7 @@ public async Task<string> GetAsync(string phoneNumber, int ttl = 300) { var sql = GetSelectSql(); await using var conn = GetConnection(); - var code = await conn.QueryFirstOrDefaultAsync<string>(sql, new - { - phoneNumber, - ttl - }, commandTimeout: 30); + var code = await conn.QueryFirstOrDefaultAsync<string>(sql, new { phoneNumber, ttl }, commandTimeout: 30); return code; } @@ -55,51 +52,47 @@ public async Task UpdateAsync(string phoneNumber, string code) { await using var conn = GetConnection(); - await conn.ExecuteAsync(GetUpdateSql(), new - { - phoneNumber, code - }, commandTimeout: 30); + await conn.ExecuteAsync(GetUpdateSql(), new { phoneNumber, code }, commandTimeout: 30); } private DbConnection GetConnection() { var connectionString = _configuration["ConnectionStrings:Identity"]; - var database = _configuration.GetDatabaseType().ToLower(); - DbConnection conn = database switch + + if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { - "mysql" => new MySqlConnection(connectionString), - _ => new NpgsqlConnection(connectionString) - }; - return conn; + return new MySqlConnection(connectionString); + } + + return new NpgsqlConnection(connectionString); } private string GetUpdateSql() { - var database = _configuration.GetDatabaseType().ToLower(); var tablePrefix = _configuration["Identity:TablePrefix"]; - var conn = database switch + + if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { - "mysql" => + return $@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, UNIX_TIMESTAMP()) -ON DUPLICATE KEY UPDATE code = @code, modification_time = UNIX_TIMESTAMP()", - _ => - $@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, floor(extract(epoch from now()))) -on conflict (phone_number) do update set code = @code, modification_time = floor(extract(epoch from now()))" - }; - return conn; +ON DUPLICATE KEY UPDATE code = @code, modification_time = UNIX_TIMESTAMP()"; + } + + return + $@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, floor(extract(epoch from now()))) +on conflict (phone_number) do update set code = @code, modification_time = floor(extract(epoch from now()))"; } private string GetSelectSql() { - var database = _configuration.GetDatabaseType().ToLower(); var tablePrefix = _configuration["Identity:TablePrefix"]; - var conn = database switch + if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { - "mysql" => - $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= UNIX_TIMESTAMP() - @ttl;", - _ => - $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= floor(extract(epoch from now())) - @ttl;" - }; - return conn; + return + $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= UNIX_TIMESTAMP() - @ttl;"; + } + + return + $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= floor(extract(epoch from now())) - @ttl;"; } } diff --git a/src/SecurityTokenService/WebApplicationBuilderExtensions.cs b/src/SecurityTokenService/WebApplicationBuilderExtensions.cs index 10291f0..e6b2177 100644 --- a/src/SecurityTokenService/WebApplicationBuilderExtensions.cs +++ b/src/SecurityTokenService/WebApplicationBuilderExtensions.cs @@ -144,7 +144,7 @@ public static WebApplicationBuilder AddDataProtection(this WebApplicationBuilder dataProtectionBuilder.ProtectKeysWithCertificate(new X509Certificate2(protectKeysWithCertPath)); } - if (builder.Configuration.GetDatabaseType() == "MySql") + if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { dataProtectionBuilder.PersistKeysToDbContext<MySqlSecurityTokenServiceDbContext>(); } @@ -178,7 +178,7 @@ public static WebApplicationBuilder AddIdentity(this WebApplicationBuilder build identityBuilder.AddDefaultTokenProviders() .AddErrorDescriber<SecurityTokenServiceIdentityErrorDescriber>(); - if (builder.Configuration.GetDatabaseType() == "MySql") + if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { identityBuilder.AddEntityFrameworkStores<MySqlSecurityTokenServiceDbContext>(); } @@ -211,11 +211,11 @@ public static WebApplicationBuilder AddIdentityServer(this WebApplicationBuilder identityServerBuilder.Services.AddScoped<IPhoneCodeStore, PhoneCodeStore>(); - if (builder.Configuration.GetDatabaseType() == "MySql") + if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { identityServerBuilder.AddOperationalStore<MySqlPersistedGrantDbContext>(); } - else if (builder.Configuration.GetDatabaseType() == "Postgre") + else if ("postgre".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase)) { identityServerBuilder.AddOperationalStore<PostgreSqlPersistedGrantDbContext>(); }