Skip to content

Commit

Permalink
调整数据库类型比较
Browse files Browse the repository at this point in the history
  • Loading branch information
zlzforever committed Aug 14, 2024
1 parent 432ccd0 commit e609d3c
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/SecurityTokenService/Identity/IdentitySeedData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static void Load(IApplicationBuilder app)
var configuration = scope.ServiceProvider.GetRequiredService<IConfiguration>();
DbContext securityTokenServiceDbContext;

if (configuration.GetDatabaseType() == "MySql")
if ("mysql".Equals(configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
securityTokenServiceDbContext =
scope.ServiceProvider.GetRequiredService<MySqlSecurityTokenServiceDbContext>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public static void MigrateIdentityServer(this IApplicationBuilder app)
{
var configuration = app.ApplicationServices.GetRequiredService<IConfiguration>();
using var scope = app.ApplicationServices.CreateScope();
if (configuration.GetDatabaseType() == "MySql")
if ("mysql".Equals(configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
using var persistedGrantDbContext =
scope.ServiceProvider.GetRequiredService<MySqlPersistedGrantDbContext>();
Expand All @@ -190,4 +190,4 @@ public static void MigrateIdentityServer(this IApplicationBuilder app)
persistedGrantDbContext.Database.Migrate();
}
}
}
}
57 changes: 25 additions & 32 deletions src/SecurityTokenService/Stores/PhoneCodeStore.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Data.Common;
using System.Threading.Tasks;
using Dapper;
Expand Down Expand Up @@ -43,63 +44,55 @@ public async Task<string> GetAsync(string phoneNumber, int ttl = 300)
{
var sql = GetSelectSql();
await using var conn = GetConnection();
var code = await conn.QueryFirstOrDefaultAsync<string>(sql, new
{
phoneNumber,
ttl
}, commandTimeout: 30);
var code = await conn.QueryFirstOrDefaultAsync<string>(sql, new { phoneNumber, ttl }, commandTimeout: 30);
return code;
}

public async Task UpdateAsync(string phoneNumber, string code)
{
await using var conn = GetConnection();

await conn.ExecuteAsync(GetUpdateSql(), new
{
phoneNumber, code
}, commandTimeout: 30);
await conn.ExecuteAsync(GetUpdateSql(), new { phoneNumber, code }, commandTimeout: 30);
}

private DbConnection GetConnection()
{
var connectionString = _configuration["ConnectionStrings:Identity"];
var database = _configuration.GetDatabaseType().ToLower();
DbConnection conn = database switch

if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
"mysql" => new MySqlConnection(connectionString),
_ => new NpgsqlConnection(connectionString)
};
return conn;
return new MySqlConnection(connectionString);
}

return new NpgsqlConnection(connectionString);
}

private string GetUpdateSql()
{
var database = _configuration.GetDatabaseType().ToLower();
var tablePrefix = _configuration["Identity:TablePrefix"];
var conn = database switch

if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
"mysql" =>
return
$@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, UNIX_TIMESTAMP())
ON DUPLICATE KEY UPDATE code = @code, modification_time = UNIX_TIMESTAMP()",
_ =>
$@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, floor(extract(epoch from now())))
on conflict (phone_number) do update set code = @code, modification_time = floor(extract(epoch from now()))"
};
return conn;
ON DUPLICATE KEY UPDATE code = @code, modification_time = UNIX_TIMESTAMP()";
}

return
$@"INSERT INTO {tablePrefix}sms_code (phone_number, code, modification_time) VALUES (@phoneNumber, @code, floor(extract(epoch from now())))
on conflict (phone_number) do update set code = @code, modification_time = floor(extract(epoch from now()))";
}

private string GetSelectSql()
{
var database = _configuration.GetDatabaseType().ToLower();
var tablePrefix = _configuration["Identity:TablePrefix"];
var conn = database switch
if ("mysql".Equals(_configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
"mysql" =>
$"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= UNIX_TIMESTAMP() - @ttl;",
_ =>
$"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= floor(extract(epoch from now())) - @ttl;"
};
return conn;
return
$"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= UNIX_TIMESTAMP() - @ttl;";
}

return
$"SELECT code FROM {tablePrefix}sms_code WHERE phone_number = @phoneNumber AND modification_time >= floor(extract(epoch from now())) - @ttl;";
}
}
8 changes: 4 additions & 4 deletions src/SecurityTokenService/WebApplicationBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public static WebApplicationBuilder AddDataProtection(this WebApplicationBuilder
dataProtectionBuilder.ProtectKeysWithCertificate(new X509Certificate2(protectKeysWithCertPath));
}

if (builder.Configuration.GetDatabaseType() == "MySql")
if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
dataProtectionBuilder.PersistKeysToDbContext<MySqlSecurityTokenServiceDbContext>();
}
Expand Down Expand Up @@ -178,7 +178,7 @@ public static WebApplicationBuilder AddIdentity(this WebApplicationBuilder build
identityBuilder.AddDefaultTokenProviders()
.AddErrorDescriber<SecurityTokenServiceIdentityErrorDescriber>();

if (builder.Configuration.GetDatabaseType() == "MySql")
if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
identityBuilder.AddEntityFrameworkStores<MySqlSecurityTokenServiceDbContext>();
}
Expand Down Expand Up @@ -211,11 +211,11 @@ public static WebApplicationBuilder AddIdentityServer(this WebApplicationBuilder

identityServerBuilder.Services.AddScoped<IPhoneCodeStore, PhoneCodeStore>();

if (builder.Configuration.GetDatabaseType() == "MySql")
if ("mysql".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
identityServerBuilder.AddOperationalStore<MySqlPersistedGrantDbContext>();
}
else if (builder.Configuration.GetDatabaseType() == "Postgre")
else if ("postgre".Equals(builder.Configuration.GetDatabaseType(), StringComparison.OrdinalIgnoreCase))
{
identityServerBuilder.AddOperationalStore<PostgreSqlPersistedGrantDbContext>();
}
Expand Down

0 comments on commit e609d3c

Please sign in to comment.