From c3cc56019c88bc95eb750501ecaa764fbb933eb4 Mon Sep 17 00:00:00 2001 From: area363 Date: Tue, 3 Dec 2024 17:21:51 +0900 Subject: [PATCH] override rate-limiter and multiaccount manager when using jwt token --- NineChronicles.Headless.Executable/Program.cs | 1 + NineChronicles.Headless/GraphQLService.cs | 21 +++++++++--------- .../Middleware/CustomRateLimitMiddleware.cs | 20 ++++++++++++++--- .../Middleware/HttpCaptureMiddleware.cs | 1 + .../HttpMultiAccountManagementMiddleware.cs | 22 ++++++++++++++++--- 5 files changed, 49 insertions(+), 16 deletions(-) diff --git a/NineChronicles.Headless.Executable/Program.cs b/NineChronicles.Headless.Executable/Program.cs index fcf00694f..8caff2e51 100644 --- a/NineChronicles.Headless.Executable/Program.cs +++ b/NineChronicles.Headless.Executable/Program.cs @@ -331,6 +331,7 @@ public async Task Run( try { IHostBuilder hostBuilder = Host.CreateDefaultBuilder(); + hostBuilder.ConfigureAppConfiguration(builder => builder.AddConfiguration(configuration)); var standaloneContext = new StandaloneContext { diff --git a/NineChronicles.Headless/GraphQLService.cs b/NineChronicles.Headless/GraphQLService.cs index c68027eb6..4c8d448c5 100644 --- a/NineChronicles.Headless/GraphQLService.cs +++ b/NineChronicles.Headless/GraphQLService.cs @@ -220,6 +220,17 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) } // Capture requests + app.UseMiddleware(); + + app.UseRouting(); + app.UseAuthorization(); + if (Convert.ToBoolean(Configuration.GetSection("IpRateLimiting")["EnableEndpointRateLimiting"])) + { + app.UseMiddleware(); + app.UseMiddleware(); + app.UseMvc(); + } + if (Convert.ToBoolean(Configuration.GetSection("MultiAccountManaging")["EnableManaging"])) { ConcurrentDictionary> ipSignerList = new(); @@ -229,7 +240,6 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) Publisher); } - app.UseMiddleware(); app.UseMiddleware(); if (Convert.ToBoolean(Configuration.GetSection("Jwt")["EnableJwtAuthentication"])) @@ -246,15 +256,6 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) app.UseCors("AllowAllOrigins"); } - app.UseRouting(); - app.UseAuthorization(); - if (Convert.ToBoolean(Configuration.GetSection("IpRateLimiting")["EnableEndpointRateLimiting"])) - { - app.UseMiddleware(); - app.UseMiddleware(); - app.UseMvc(); - } - app.UseEndpoints(endpoints => { endpoints.MapControllers(); diff --git a/NineChronicles.Headless/Middleware/CustomRateLimitMiddleware.cs b/NineChronicles.Headless/Middleware/CustomRateLimitMiddleware.cs index a1b3df76b..e00348218 100644 --- a/NineChronicles.Headless/Middleware/CustomRateLimitMiddleware.cs +++ b/NineChronicles.Headless/Middleware/CustomRateLimitMiddleware.cs @@ -6,25 +6,33 @@ using NineChronicles.Headless.Properties; using Serilog; using ILogger = Serilog.ILogger; +using System.Linq; +using Microsoft.Extensions.Configuration; namespace NineChronicles.Headless.Middleware { + public class CustomRateLimitMiddleware : RateLimitMiddleware { private readonly ILogger _logger; private readonly IRateLimitConfiguration _config; private readonly IOptions _options; + private readonly string _whitelistedIp; + private readonly string _jwtKey; public CustomRateLimitMiddleware(RequestDelegate next, IProcessingStrategy processingStrategy, IOptions options, IIpPolicyStore policyStore, - IRateLimitConfiguration config) + IRateLimitConfiguration config, + Microsoft.Extensions.Configuration.IConfiguration configuration) : base(next, options?.Value, new CustomIpRateLimitProcessor(options?.Value!, policyStore, processingStrategy), config) { _config = config; _options = options!; _logger = Log.Logger.ForContext(); + _jwtKey = configuration["Jwt:Key"] ?? string.Empty; + _whitelistedIp = configuration.GetSection("IpRateLimiting:IpWhitelist")?.Get()?.FirstOrDefault() ?? "127.0.0.1"; } protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) @@ -45,14 +53,20 @@ public override async Task ResolveIdentityAsync(HttpConte if (httpContext.Request.Protocol == "HTTP/1.1") { - var body = await new StreamReader(httpContext.Request.Body).ReadToEndAsync(); + var body = httpContext.Items["RequestBody"]!.ToString()!; httpContext.Request.Body.Seek(0, SeekOrigin.Begin); if (body.Contains("stageTransaction")) { identity.Path = "/graphql/stagetransaction"; } + } - return identity; + // Check for JWT secret key in headers + if (httpContext.Request.Headers.TryGetValue("Authorization", out var authHeaderValue) && + !string.IsNullOrEmpty(_jwtKey) && + authHeaderValue.ToString().Equals($"Bearer {_jwtKey}", System.StringComparison.OrdinalIgnoreCase)) + { + identity.ClientIp = _whitelistedIp; } return identity; diff --git a/NineChronicles.Headless/Middleware/HttpCaptureMiddleware.cs b/NineChronicles.Headless/Middleware/HttpCaptureMiddleware.cs index f6d9c0c95..54a9b89d4 100644 --- a/NineChronicles.Headless/Middleware/HttpCaptureMiddleware.cs +++ b/NineChronicles.Headless/Middleware/HttpCaptureMiddleware.cs @@ -25,6 +25,7 @@ public async Task InvokeAsync(HttpContext context) context.Request.EnableBuffering(); var remoteIp = context.Connection.RemoteIpAddress; var body = await new StreamReader(context.Request.Body).ReadToEndAsync(); + context.Items["RequestBody"] = body; _logger.Information("[GRAPHQL-REQUEST-CAPTURE] IP: {IP} Method: {Method} Endpoint: {Path} {Body}", remoteIp, context.Request.Method, context.Request.Path, body); context.Request.Body.Seek(0, SeekOrigin.Begin); diff --git a/NineChronicles.Headless/Middleware/HttpMultiAccountManagementMiddleware.cs b/NineChronicles.Headless/Middleware/HttpMultiAccountManagementMiddleware.cs index 4fee61167..1cc21634c 100644 --- a/NineChronicles.Headless/Middleware/HttpMultiAccountManagementMiddleware.cs +++ b/NineChronicles.Headless/Middleware/HttpMultiAccountManagementMiddleware.cs @@ -17,6 +17,8 @@ namespace NineChronicles.Headless.Middleware { + using Microsoft.Extensions.Configuration; + public class HttpMultiAccountManagementMiddleware { private static readonly ConcurrentDictionary MultiAccountTxIntervalTracker = new(); @@ -27,13 +29,16 @@ public class HttpMultiAccountManagementMiddleware private readonly ConcurrentDictionary> _ipSignerList; private readonly IOptions _options; private ActionEvaluationPublisher _publisher; + private readonly string _whitelistedIp; + private readonly string _jwtKey; public HttpMultiAccountManagementMiddleware( RequestDelegate next, StandaloneContext standaloneContext, ConcurrentDictionary> ipSignerList, IOptions options, - ActionEvaluationPublisher publisher) + ActionEvaluationPublisher publisher, + Microsoft.Extensions.Configuration.IConfiguration configuration) { _next = next; _logger = Log.Logger.ForContext(); @@ -41,6 +46,8 @@ public HttpMultiAccountManagementMiddleware( _ipSignerList = ipSignerList; _options = options; _publisher = publisher; + _jwtKey = configuration["Jwt:Key"] ?? string.Empty; + _whitelistedIp = configuration.GetSection("IpRateLimiting:IpWhitelist")?.Get()?.FirstOrDefault() ?? "127.0.0.1"; } private static void ManageMultiAccount(Address agent) @@ -58,9 +65,18 @@ public async Task InvokeAsync(HttpContext context) // Prevent to harm HTTP/2 communication. if (context.Request.Protocol == "HTTP/1.1") { - context.Request.EnableBuffering(); var remoteIp = context.Connection.RemoteIpAddress!.ToString(); - var body = await new StreamReader(context.Request.Body).ReadToEndAsync(); + + // Skip if the remote IP is whitelisted + if (context.Request.Headers.TryGetValue("Authorization", out var authHeaderValue) && + !string.IsNullOrEmpty(_jwtKey) && + authHeaderValue.ToString().Equals($"Bearer {_jwtKey}", System.StringComparison.OrdinalIgnoreCase)) + { + await _next(context); + return; + } + + var body = context.Items["RequestBody"]!.ToString()!; context.Request.Body.Seek(0, SeekOrigin.Begin); if (_options.Value.EnableManaging && body.Contains("stageTransaction")) {