From e609d3c2f128a521229626830372381be8f9a6dc Mon Sep 17 00:00:00 2001
From: Lewis Zou <zlzforever@163.com>
Date: Wed, 14 Aug 2024 14:45:31 +0800
Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=95=B0=E6=8D=AE=E5=BA=93?=
 =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=AF=94=E8=BE=83?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../Identity/IdentitySeedData.cs              |  2 +-
 .../IdentityServerExtensions.cs               |  4 +-
 .../Stores/PhoneCodeStore.cs                  | 57 ++++++++-----------
 .../WebApplicationBuilderExtensions.cs        |  8 +--
 4 files changed, 32 insertions(+), 39 deletions(-)

diff --git a/src/SecurityTokenService/Identity/IdentitySeedData.cs b/src/SecurityTokenService/Identity/IdentitySeedData.cs
index 6284666..1291ee9 100644
--- a/src/SecurityTokenService/Identity/IdentitySeedData.cs
+++ b/src/SecurityTokenService/Identity/IdentitySeedData.cs
@@ -23,7 +23,7 @@ public static void Load(IApplicationBuilder app)
         var configuration = scope.ServiceProvider.GetRequiredService<IConfiguration>();
         DbContext securityTokenServiceDbContext;
 
-        if (configuration.GetDatabaseType() == "MySql")
+        if ("mysql".Equals(configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
             securityTokenServiceDbContext =
                 scope.ServiceProvider.GetRequiredService<MySqlSecurityTokenServiceDbContext>();
diff --git a/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs b/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs
index dfde757..fd90232 100644
--- a/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs
+++ b/src/SecurityTokenService/IdentityServer/IdentityServerExtensions.cs
@@ -177,7 +177,7 @@ public static void MigrateIdentityServer(this IApplicationBuilder app)
     {
         var configuration = app.ApplicationServices.GetRequiredService<IConfiguration>();
         using var scope = app.ApplicationServices.CreateScope();
-        if (configuration.GetDatabaseType() == "MySql")
+        if ("mysql".Equals(configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
             using var persistedGrantDbContext =
                 scope.ServiceProvider.GetRequiredService<MySqlPersistedGrantDbContext>();
@@ -190,4 +190,4 @@ public static void MigrateIdentityServer(this IApplicationBuilder app)
             persistedGrantDbContext.Database.Migrate();
         }
     }
-}
\ No newline at end of file
+}
diff --git a/src/SecurityTokenService/Stores/PhoneCodeStore.cs b/src/SecurityTokenService/Stores/PhoneCodeStore.cs
index ba532b5..ae678c4 100644
--- a/src/SecurityTokenService/Stores/PhoneCodeStore.cs
+++ b/src/SecurityTokenService/Stores/PhoneCodeStore.cs
@@ -1,3 +1,4 @@
+using System;
 using System.Data.Common;
 using System.Threading.Tasks;
 using Dapper;
@@ -43,11 +44,7 @@ public async Task<string> GetAsync(string phoneNumber, int ttl = 300)
     {
         var sql = GetSelectSql();
         await using var conn = GetConnection();
-        var code = await conn.QueryFirstOrDefaultAsync<string>(sql, new
-        {
-            phoneNumber,
-            ttl
-        }, commandTimeout: 30);
+        var code = await conn.QueryFirstOrDefaultAsync<string>(sql, new { phoneNumber, ttl }, commandTimeout: 30);
         return code;
     }
 
@@ -55,51 +52,47 @@ public async Task UpdateAsync(string phoneNumber, string code)
     {
         await using var conn = GetConnection();
 
-        await conn.ExecuteAsync(GetUpdateSql(), new
-        {
-            phoneNumber, code
-        }, commandTimeout: 30);
+        await conn.ExecuteAsync(GetUpdateSql(), new { phoneNumber, code }, commandTimeout: 30);
     }
 
     private DbConnection GetConnection()
     {
         var connectionString = _configuration["ConnectionStrings:Identity"];
-        var database = _configuration.GetDatabaseType().ToLower();
-        DbConnection conn = database switch
+
+        if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
-            "mysql" => new MySqlConnection(connectionString),
-            _ => new NpgsqlConnection(connectionString)
-        };
-        return conn;
+            return new MySqlConnection(connectionString);
+        }
+
+        return new NpgsqlConnection(connectionString);
     }
 
     private string GetUpdateSql()
     {
-        var database = _configuration.GetDatabaseType().ToLower();
         var tablePrefix = _configuration["Identity:TablePrefix"];
-        var conn = database switch
+
+        if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
-            "mysql" =>
+            return
                 $@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, UNIX_TIMESTAMP()) 
-ON DUPLICATE KEY UPDATE code = @code, modification_time = UNIX_TIMESTAMP()",
-            _ =>
-                $@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, floor(extract(epoch from now()))) 
-on conflict (phone_number) do update set code = @code, modification_time = floor(extract(epoch from now()))"
-        };
-        return conn;
+ON DUPLICATE KEY UPDATE code = @code, modification_time = UNIX_TIMESTAMP()";
+        }
+
+        return
+            $@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, floor(extract(epoch from now()))) 
+on conflict (phone_number) do update set code = @code, modification_time = floor(extract(epoch from now()))";
     }
 
     private string GetSelectSql()
     {
-        var database = _configuration.GetDatabaseType().ToLower();
         var tablePrefix = _configuration["Identity:TablePrefix"];
-        var conn = database switch
+        if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
-            "mysql" =>
-                $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= UNIX_TIMESTAMP() - @ttl;",
-            _ =>
-                $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= floor(extract(epoch from now())) - @ttl;"
-        };
-        return conn;
+            return
+                $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= UNIX_TIMESTAMP() - @ttl;";
+        }
+
+        return
+            $"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= floor(extract(epoch from now())) - @ttl;";
     }
 }
diff --git a/src/SecurityTokenService/WebApplicationBuilderExtensions.cs b/src/SecurityTokenService/WebApplicationBuilderExtensions.cs
index 10291f0..e6b2177 100644
--- a/src/SecurityTokenService/WebApplicationBuilderExtensions.cs
+++ b/src/SecurityTokenService/WebApplicationBuilderExtensions.cs
@@ -144,7 +144,7 @@ public static WebApplicationBuilder AddDataProtection(this WebApplicationBuilder
             dataProtectionBuilder.ProtectKeysWithCertificate(new X509Certificate2(protectKeysWithCertPath));
         }
 
-        if (builder.Configuration.GetDatabaseType() == "MySql")
+        if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
             dataProtectionBuilder.PersistKeysToDbContext<MySqlSecurityTokenServiceDbContext>();
         }
@@ -178,7 +178,7 @@ public static WebApplicationBuilder AddIdentity(this WebApplicationBuilder build
         identityBuilder.AddDefaultTokenProviders()
             .AddErrorDescriber<SecurityTokenServiceIdentityErrorDescriber>();
 
-        if (builder.Configuration.GetDatabaseType() == "MySql")
+        if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
             identityBuilder.AddEntityFrameworkStores<MySqlSecurityTokenServiceDbContext>();
         }
@@ -211,11 +211,11 @@ public static WebApplicationBuilder AddIdentityServer(this WebApplicationBuilder
 
         identityServerBuilder.Services.AddScoped<IPhoneCodeStore, PhoneCodeStore>();
 
-        if (builder.Configuration.GetDatabaseType() == "MySql")
+        if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
             identityServerBuilder.AddOperationalStore<MySqlPersistedGrantDbContext>();
         }
-        else if (builder.Configuration.GetDatabaseType() == "Postgre")
+        else if ("postgre".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
         {
             identityServerBuilder.AddOperationalStore<PostgreSqlPersistedGrantDbContext>();
         }