Skip to content

Commit

Permalink
Merge pull request #2653 from area363/override-rate-limit-with-jwt-token
Browse files Browse the repository at this point in the history
Override rate limit with jwt token
  • Loading branch information
area363 authored Dec 16, 2024
2 parents 71d6136 + bdfda08 commit 72a435d
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 34 deletions.
34 changes: 19 additions & 15 deletions NineChronicles.Headless/GraphQLService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,14 @@ public void ConfigureServices(IServiceCollection services)
"Admin"));

// FIXME: Use ConfigurationException after bumping to .NET 8 or later.
options.AddPolicy(
JwtPolicyKey,
p =>
p.RequireClaim("iss",
jwtOptions["Issuer"] ?? throw new ArgumentException("jwtOptions[\"Issuer\"] is null.")));
if (Convert.ToBoolean(Configuration.GetSection("Jwt")["EnableJwtAuthentication"]))
{
options.AddPolicy(
JwtPolicyKey,
p =>
p.RequireClaim("iss",
jwtOptions["Issuer"] ?? throw new ArgumentException("jwtOptions[\"Issuer\"] is null.")));
}
});

services.AddGraphTypes();
Expand All @@ -220,6 +223,17 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
}

// Capture requests
app.UseMiddleware<HttpCaptureMiddleware>();

app.UseRouting();
app.UseAuthorization();
if (Convert.ToBoolean(Configuration.GetSection("IpRateLimiting")["EnableEndpointRateLimiting"]))
{
app.UseMiddleware<CustomRateLimitMiddleware>();
app.UseMiddleware<IpBanMiddleware>();
app.UseMvc();
}

if (Convert.ToBoolean(Configuration.GetSection("MultiAccountManaging")["EnableManaging"]))
{
ConcurrentDictionary<string, HashSet<Address>> ipSignerList = new();
Expand All @@ -229,7 +243,6 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
Publisher);
}

app.UseMiddleware<HttpCaptureMiddleware>();

app.UseMiddleware<LocalAuthenticationMiddleware>();
if (Convert.ToBoolean(Configuration.GetSection("Jwt")["EnableJwtAuthentication"]))
Expand All @@ -246,15 +259,6 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
app.UseCors("AllowAllOrigins");
}

app.UseRouting();
app.UseAuthorization();
if (Convert.ToBoolean(Configuration.GetSection("IpRateLimiting")["EnableEndpointRateLimiting"]))
{
app.UseMiddleware<CustomRateLimitMiddleware>();
app.UseMiddleware<IpBanMiddleware>();
app.UseMvc();
}

app.UseEndpoints(endpoints =>
{
endpoints.MapControllers();
Expand Down
71 changes: 62 additions & 9 deletions NineChronicles.Headless/Middleware/CustomRateLimitMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,109 @@
using NineChronicles.Headless.Properties;
using Serilog;
using ILogger = Serilog.ILogger;
using System.Linq;
using Microsoft.Extensions.Configuration;

namespace NineChronicles.Headless.Middleware
{

public class CustomRateLimitMiddleware : RateLimitMiddleware<CustomIpRateLimitProcessor>
{
private readonly ILogger _logger;
private readonly IRateLimitConfiguration _config;
private readonly IOptions<CustomIpRateLimitOptions> _options;
private readonly string _whitelistedIp;
private readonly System.IdentityModel.Tokens.Jwt.JwtSecurityTokenHandler _tokenHandler = new();
private readonly Microsoft.IdentityModel.Tokens.TokenValidationParameters _validationParams;

public CustomRateLimitMiddleware(RequestDelegate next,
IProcessingStrategy processingStrategy,
IOptions<CustomIpRateLimitOptions> options,
IIpPolicyStore policyStore,
IRateLimitConfiguration config)
IRateLimitConfiguration config,
Microsoft.Extensions.Configuration.IConfiguration configuration)
: base(next, options?.Value, new CustomIpRateLimitProcessor(options?.Value!, policyStore, processingStrategy), config)
{
_config = config;
_options = options!;
_logger = Log.Logger.ForContext<CustomRateLimitMiddleware>();
var jwtConfig = configuration.GetSection("Jwt");
var issuer = jwtConfig["Issuer"] ?? "";
var key = jwtConfig["Key"] ?? "";
_whitelistedIp = configuration.GetSection("IpRateLimiting:IpWhitelist")?.Get<string[]>()?.FirstOrDefault() ?? "127.0.0.1";
_validationParams = new Microsoft.IdentityModel.Tokens.TokenValidationParameters
{
ValidateIssuer = true,
ValidateAudience = false,
ValidateLifetime = true,
ValidateIssuerSigningKey = true,
ValidIssuer = issuer,
IssuerSigningKey = new Microsoft.IdentityModel.Tokens.SymmetricSecurityKey(System.Text.Encoding.ASCII.GetBytes(key.PadRight(512 / 8, '\0')))
};
}

protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule)
protected override void LogBlockedRequest(HttpContext context, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule)
{
_logger.Information($"[IP-RATE-LIMITER] Request {identity.HttpVerb}:{identity.Path} from IP {identity.ClientIp} has been blocked, " +
$"quota {rule.Limit}/{rule.Period} exceeded by {counter.Count - rule.Limit}. Blocked by rule {rule.Endpoint}, " +
$"TraceIdentifier {httpContext.TraceIdentifier}. MonitorMode: {rule.MonitorMode}");
$"TraceIdentifier {context.TraceIdentifier}. MonitorMode: {rule.MonitorMode}");
if (counter.Count - rule.Limit >= _options.Value.IpBanThresholdCount)
{
_logger.Information($"[IP-RATE-LIMITER] Banning IP {identity.ClientIp}.");
IpBanMiddleware.BanIp(identity.ClientIp);
}
}

public override async Task<ClientRequestIdentity> ResolveIdentityAsync(HttpContext httpContext)
public override async Task<ClientRequestIdentity> ResolveIdentityAsync(HttpContext context)
{
var identity = await base.ResolveIdentityAsync(httpContext);
var identity = await base.ResolveIdentityAsync(context);

if (httpContext.Request.Protocol == "HTTP/1.1")
if (context.Request.Protocol == "HTTP/1.1")
{
var body = await new StreamReader(httpContext.Request.Body).ReadToEndAsync();
httpContext.Request.Body.Seek(0, SeekOrigin.Begin);
var body = context.Items["RequestBody"]!.ToString()!;
context.Request.Body.Seek(0, SeekOrigin.Begin);
if (body.Contains("stageTransaction"))
{
identity.Path = "/graphql/stagetransaction";
}
}

return identity;
// Check for JWT secret key in headers
if (context.Request.Headers.TryGetValue("Authorization", out var authHeaderValue) &&
authHeaderValue.Count > 0)
{
try
{
var (scheme, token) = ExtractSchemeAndToken(authHeaderValue);
if (scheme.Equals("Bearer", System.StringComparison.OrdinalIgnoreCase))
{
_tokenHandler.ValidateToken(token, _validationParams, out _);
identity.ClientIp = _whitelistedIp;
}
}
catch (System.Exception ex)
{
_logger.Warning("[IP-RATE-LIMITER] JWT validation failed: {Message}", ex.Message);
}
}

return identity;
}

private (string scheme, string token) ExtractSchemeAndToken(Microsoft.Extensions.Primitives.StringValues authorizationHeader)
{
if (authorizationHeader.Count == 0 || string.IsNullOrWhiteSpace(authorizationHeader[0]))
{
throw new System.ArgumentException("Authorization header is missing or empty.");
}

var headerValues = authorizationHeader[0]!.Split(" ");
if (headerValues.Length != 2)
{
throw new System.ArgumentException("Invalid Authorization header format. Expected 'Scheme Token'.");
}

return (headerValues[0], headerValues[1]);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public async Task InvokeAsync(HttpContext context)
context.Request.EnableBuffering();
var remoteIp = context.Connection.RemoteIpAddress;
var body = await new StreamReader(context.Request.Body).ReadToEndAsync();
context.Items["RequestBody"] = body;
_logger.Information("[GRAPHQL-REQUEST-CAPTURE] IP: {IP} Method: {Method} Endpoint: {Path} {Body}",
remoteIp, context.Request.Method, context.Request.Path, body);
context.Request.Body.Seek(0, SeekOrigin.Begin);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,35 @@ public class HttpMultiAccountManagementMiddleware
private readonly ConcurrentDictionary<string, HashSet<Address>> _ipSignerList;
private readonly IOptions<MultiAccountManagerProperties> _options;
private ActionEvaluationPublisher _publisher;
private readonly System.IdentityModel.Tokens.Jwt.JwtSecurityTokenHandler _tokenHandler = new();
private readonly Microsoft.IdentityModel.Tokens.TokenValidationParameters _validationParams;

public HttpMultiAccountManagementMiddleware(
RequestDelegate next,
StandaloneContext standaloneContext,
ConcurrentDictionary<string, HashSet<Address>> ipSignerList,
IOptions<MultiAccountManagerProperties> options,
ActionEvaluationPublisher publisher)
ActionEvaluationPublisher publisher,
Microsoft.Extensions.Configuration.IConfiguration configuration)
{
_next = next;
_logger = Log.Logger.ForContext<HttpMultiAccountManagementMiddleware>();
_standaloneContext = standaloneContext;
_ipSignerList = ipSignerList;
_options = options;
_publisher = publisher;
var jwtConfig = configuration.GetSection("Jwt");
var issuer = jwtConfig["Issuer"] ?? "";
var key = jwtConfig["Key"] ?? "";
_validationParams = new Microsoft.IdentityModel.Tokens.TokenValidationParameters
{
ValidateIssuer = true,
ValidateAudience = false,
ValidateLifetime = true,
ValidateIssuerSigningKey = true,
ValidIssuer = issuer,
IssuerSigningKey = new Microsoft.IdentityModel.Tokens.SymmetricSecurityKey(System.Text.Encoding.ASCII.GetBytes(key.PadRight(512 / 8, '\0')))
};
}

private static void ManageMultiAccount(Address agent)
Expand All @@ -58,9 +73,29 @@ public async Task InvokeAsync(HttpContext context)
// Prevent to harm HTTP/2 communication.
if (context.Request.Protocol == "HTTP/1.1")
{
context.Request.EnableBuffering();
var remoteIp = context.Connection.RemoteIpAddress!.ToString();
var body = await new StreamReader(context.Request.Body).ReadToEndAsync();

// Check for JWT secret key in headers
if (context.Request.Headers.TryGetValue("Authorization", out var authHeaderValue) &&
authHeaderValue.Count > 0)
{
try
{
var (scheme, token) = ExtractSchemeAndToken(authHeaderValue);
if (scheme.Equals("Bearer", System.StringComparison.OrdinalIgnoreCase))
{
_tokenHandler.ValidateToken(token, _validationParams, out _);
await _next(context);
return;
}
}
catch (System.Exception ex)
{
_logger.Warning("[GRAPHQL-MULTI-ACCOUNT-MANAGER] JWT validation failed: {Message}", ex.Message);
}
}

var body = context.Items["RequestBody"]!.ToString()!;
context.Request.Body.Seek(0, SeekOrigin.Begin);
if (_options.Value.EnableManaging && body.Contains("stageTransaction"))
{
Expand Down Expand Up @@ -150,6 +185,22 @@ and not ClaimStakeReward
await _next(context);
}

private (string scheme, string token) ExtractSchemeAndToken(Microsoft.Extensions.Primitives.StringValues authorizationHeader)
{
if (authorizationHeader.Count == 0 || string.IsNullOrWhiteSpace(authorizationHeader[0]))
{
throw new System.ArgumentException("Authorization header is missing or empty.");
}

var headerValues = authorizationHeader[0]!.Split(" ");
if (headerValues.Length != 2)
{
throw new System.ArgumentException("Invalid Authorization header format. Expected 'Scheme Token'.");
}

return (headerValues[0], headerValues[1]);
}

private void UpdateIpSignerList(string ip, Address agent)
{
if (!_ipSignerList.ContainsKey(ip))
Expand All @@ -159,13 +210,6 @@ private void UpdateIpSignerList(string ip, Address agent)
ip);
_ipSignerList[ip] = new HashSet<Address>();
}
else
{
_logger.Information(
"[GRAPHQL-MULTI-ACCOUNT-MANAGER] List already created for IP: {IP} Count: {Count}",
ip,
_ipSignerList[ip].Count);
}

_ipSignerList[ip].Add(agent);
AddClientIpInfo(agent, ip);
Expand Down

0 comments on commit 72a435d

Please sign in to comment.