diff --git a/build.gradle.kts b/build.gradle.kts index 51bb9ca..9c71fdf 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -11,7 +11,7 @@ val authorName = "ccbluex" val projectUrl = "https://github.com/ccbluex/netty-httpserver" group = "net.ccbluex" -version = "2.0.0" +version = "2.1.0" repositories { mavenCentral() diff --git a/src/main/kotlin/net/ccbluex/netty/http/HttpConductor.kt b/src/main/kotlin/net/ccbluex/netty/http/HttpConductor.kt index e0c6113..f4a968f 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/HttpConductor.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/HttpConductor.kt @@ -58,9 +58,6 @@ internal class HttpConductor(private val server: HttpServer) { val httpHeaders = response.headers() httpHeaders[HttpHeaderNames.CONTENT_TYPE] = "text/plain" httpHeaders[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes() - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = "*" - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS] = "GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS" - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS] = "Content-Type, Content-Length, Authorization, Accept, X-Requested-With" return@runCatching response } diff --git a/src/main/kotlin/net/ccbluex/netty/http/HttpServer.kt b/src/main/kotlin/net/ccbluex/netty/http/HttpServer.kt index 4ebd4bb..e352c61 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/HttpServer.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/HttpServer.kt @@ -26,12 +26,17 @@ import io.netty.channel.epoll.EpollEventLoopGroup import io.netty.channel.epoll.EpollServerSocketChannel import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.http.FullHttpResponse import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler +import net.ccbluex.netty.http.middleware.Middleware +import net.ccbluex.netty.http.middleware.MiddlewareFunction +import net.ccbluex.netty.http.model.RequestContext import net.ccbluex.netty.http.rest.RouteController import net.ccbluex.netty.http.websocket.WebSocketController import org.apache.logging.log4j.LogManager + /** * NettyRest - A Web Rest-API server with support for WebSocket and File Serving using Netty. * @@ -42,10 +47,20 @@ class HttpServer { val routeController = RouteController() val webSocketController = WebSocketController() + val middlewares = mutableListOf() + companion object { internal val logger = LogManager.getLogger("HttpServer") } + fun middleware(middlewareFunction: MiddlewareFunction) { + middlewares += middlewareFunction + } + + fun middleware(middleware: Middleware) { + middlewares += middleware::middleware + } + /** * Starts the Netty server on the specified port. */ diff --git a/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt b/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt index a785001..297f16d 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt @@ -109,10 +109,11 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun // If this is the last content, process the request if (msg is LastHttpContent) { localRequestContext.remove() - + val httpConductor = HttpConductor(server) val response = httpConductor.processRequestContext(requestContext) - ctx.writeAndFlush(response) + val httpResponse = server.middlewares.fold(response) { acc, f -> f(requestContext, acc) } + ctx.writeAndFlush(httpResponse) } } diff --git a/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt b/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt new file mode 100644 index 0000000..28d2a0a --- /dev/null +++ b/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt @@ -0,0 +1,64 @@ +package net.ccbluex.netty.http.middleware + +import io.netty.handler.codec.http.FullHttpResponse +import io.netty.handler.codec.http.HttpHeaderNames +import net.ccbluex.netty.http.HttpServer.Companion.logger +import net.ccbluex.netty.http.model.RequestContext +import java.net.URI +import java.net.URISyntaxException + +/** + * Middleware to handle Cross-Origin Resource Sharing (CORS) requests. + * + * @param allowedOrigins List of allowed (host) origins (default: localhost, 127.0.0.1) + * - If we want to specify a protocol and port, we should use the full origin (e.g., http://localhost:8080). + * @param allowedMethods List of allowed HTTP methods (default: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS) + * @param allowedHeaders List of allowed HTTP headers (default: Content-Type, Content-Length, Authorization, Accept, X-Requested-With) + * + * @see RequestContext + */ +class CorsMiddleware( + private val allowedOrigins: List = + listOf("localhost", "127.0.0.1"), + private val allowedMethods: List = + listOf("GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"), + private val allowedHeaders: List = + listOf("Content-Type", "Content-Length", "Authorization", "Accept", "X-Requested-With") +): Middleware { + + /** + * Middleware to handle CORS requests. + * Pass to server.middleware() to apply the CORS policy to all requests. + */ + override fun middleware(context: RequestContext, response: FullHttpResponse): FullHttpResponse { + val httpHeaders = response.headers() + val requestOrigin = context.headers["origin"] ?: context.headers["Origin"] + + if (requestOrigin != null) { + try { + // Parse the origin to extract the hostname (ignoring the port) + val uri = URI(requestOrigin) + val host = uri.host + + // Allow requests from localhost or 127.0.0.1 regardless of the port + if (allowedOrigins.contains(host) || allowedOrigins.contains(requestOrigin)) { + httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = requestOrigin + } else { + // Block cross-origin requests by not allowing other origins + httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = "null" + } + } catch (e: URISyntaxException) { + // Handle bad URIs by setting a default CORS policy or logging the error + httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = "null" + logger.error("Invalid Origin header: $requestOrigin", e) + } + + // Allow specific methods and headers for cross-origin requests + httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS] = allowedMethods.joinToString(", ") + httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS] = allowedHeaders.joinToString(", ") + } + + return response + } + +} \ No newline at end of file diff --git a/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt b/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt new file mode 100644 index 0000000..00ea757 --- /dev/null +++ b/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt @@ -0,0 +1,10 @@ +package net.ccbluex.netty.http.middleware + +import io.netty.handler.codec.http.FullHttpResponse +import net.ccbluex.netty.http.model.RequestContext + +typealias MiddlewareFunction = (RequestContext, FullHttpResponse) -> FullHttpResponse + +interface Middleware { + fun middleware(context: RequestContext, response: FullHttpResponse): FullHttpResponse +} \ No newline at end of file diff --git a/src/main/kotlin/net/ccbluex/netty/http/util/HttpResponse.kt b/src/main/kotlin/net/ccbluex/netty/http/util/HttpResponse.kt index e9ae85f..eec1244 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/util/HttpResponse.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/util/HttpResponse.kt @@ -47,9 +47,7 @@ fun httpResponse(status: HttpResponseStatus, contentType: String = "text/plain", val httpHeaders = response.headers() httpHeaders[HttpHeaderNames.CONTENT_TYPE] = contentType httpHeaders[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes() - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = "*" - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS] = "GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS" - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS] = "Content-Type, Content-Length, Authorization, Accept, X-Requested-With" + return response } @@ -140,7 +138,6 @@ fun httpFile(file: File): FullHttpResponse { val httpHeaders = response.headers() httpHeaders[HttpHeaderNames.CONTENT_TYPE] = tika.detect(file) httpHeaders[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes() - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = "*" return response } @@ -162,7 +159,70 @@ fun httpFileStream(stream: InputStream): FullHttpResponse { val httpHeaders = response.headers() httpHeaders[HttpHeaderNames.CONTENT_TYPE] = tika.detect(bytes) httpHeaders[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes() - httpHeaders[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN] = "*" return response -} \ No newline at end of file +} + +/** + * Creates an HTTP 204 No Content response. + * + * @return A FullHttpResponse object. + */ +fun httpNoContent(): FullHttpResponse { + val response = DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.NO_CONTENT + ) + + val httpHeaders = response.headers() + httpHeaders[HttpHeaderNames.CONTENT_LENGTH] = 0 + return response +} + +/** + * Creates an HTTP 405 Method Not Allowed response with the given method. + * + * @param method The method that is not allowed. + * @return A FullHttpResponse object. + */ +fun httpMethodNotAllowed(method: String): FullHttpResponse { + val jsonObject = JsonObject() + jsonObject.addProperty("method", method) + return httpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED, jsonObject) +} + +/** + * Creates an HTTP 401 Unauthorized response with the given reason. + * + * @param reason The reason for the 401 error. + * @return A FullHttpResponse object. + */ +fun httpUnauthorized(reason: String): FullHttpResponse { + val jsonObject = JsonObject() + jsonObject.addProperty("reason", reason) + return httpResponse(HttpResponseStatus.UNAUTHORIZED, jsonObject) +} + +/** + * Creates an HTTP 429 Too Many Requests response with the given reason. + * + * @param reason The reason for the 429 error. + * @return A FullHttpResponse object. + */ +fun httpTooManyRequests(reason: String): FullHttpResponse { + val jsonObject = JsonObject() + jsonObject.addProperty("reason", reason) + return httpResponse(HttpResponseStatus.TOO_MANY_REQUESTS, jsonObject) +} + +/** + * Creates an HTTP 503 Service Unavailable response with the given reason. + * + * @param reason The reason for the 503 error. + * @return A FullHttpResponse object. + */ +fun httpServiceUnavailable(reason: String): FullHttpResponse { + val jsonObject = JsonObject() + jsonObject.addProperty("reason", reason) + return httpResponse(HttpResponseStatus.SERVICE_UNAVAILABLE, jsonObject) +} diff --git a/src/test/kotlin/HttpMiddlewareServerTest.kt b/src/test/kotlin/HttpMiddlewareServerTest.kt new file mode 100644 index 0000000..f21367b --- /dev/null +++ b/src/test/kotlin/HttpMiddlewareServerTest.kt @@ -0,0 +1,132 @@ +import com.google.gson.JsonObject +import io.netty.handler.codec.http.FullHttpResponse +import net.ccbluex.netty.http.HttpServer +import net.ccbluex.netty.http.model.RequestObject +import net.ccbluex.netty.http.util.httpOk +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import org.junit.jupiter.api.* +import java.io.File +import java.nio.file.Files +import kotlin.concurrent.thread +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +/** + * Test class for the HttpServer, focusing on verifying the routing capabilities + * and correctness of responses from different endpoints. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class HttpMiddlewareServerTest { + + private lateinit var serverThread: Thread + private val client = OkHttpClient() + + /** + * This method sets up the necessary environment before any tests are run. + * It creates a temporary directory with dummy files and starts the HTTP server + * in a separate thread. + */ + @BeforeAll + fun initialize() { + // Start the HTTP server in a separate thread + serverThread = thread { + startHttpServer() + } + + // Allow the server some time to start + Thread.sleep(1000) + } + + /** + * This method cleans up resources after all tests have been executed. + * It stops the server and deletes the temporary directory. + */ + @AfterAll + fun cleanup() { + serverThread.interrupt() + } + + /** + * This function starts the HTTP server with routing configured for + * different difficulty levels. + */ + private fun startHttpServer() { + val server = HttpServer() + + server.routeController.apply { + get("/", ::static) + } + + server.middleware { requestContext, fullHttpResponse -> + // Add custom headers to the response + fullHttpResponse.headers().add("X-Custom-Header", "Custom Value") + + // Add a custom header if there is a query parameter + if (requestContext.params.isNotEmpty()) { + fullHttpResponse.headers().add("X-Query-Param", + requestContext.params.entries.joinToString(",")) + } + + fullHttpResponse + } + + server.start(8080) // Start the server on port 8080 + } + + @Suppress("UNUSED_PARAMETER") + fun static(requestObject: RequestObject): FullHttpResponse { + return httpOk(JsonObject().apply { + addProperty("message", "Hello, World!") + }) + } + + /** + * Utility function to make HTTP GET requests to the specified path. + * + * @param path The path for the request. + * @return The HTTP response. + */ + private fun makeRequest(path: String): Response { + val request = Request.Builder() + .url("http://localhost:8080$path") + .build() + return client.newCall(request).execute() + } + + /** + * Test the root endpoint ("/") and verify that it returns the correct number + * of files in the directory. + */ + @Test + fun testRootEndpoint() { + val response = makeRequest("/") + assertEquals(200, response.code(), "Expected status code 200") + + val responseBody = response.body()?.string() + assertNotNull(responseBody, "Response body should not be null") + + assertTrue(responseBody.contains("Hello, World!"), "Response should contain 'Hello, World!'") + } + + /** + * Test the root endpoint ("/") with a query parameter and verify that the + * custom header is added to the response. + */ + @Test + fun testRootEndpointWithQueryParam() { + val response = makeRequest("/?param1=value1¶m2=value2") + assertEquals(200, response.code(), "Expected status code 200") + + val responseBody = response.body()?.string() + assertNotNull(responseBody, "Response body should not be null") + + assertTrue(responseBody.contains("Hello, World!"), "Response should contain 'Hello, World!'") + assertTrue(response.headers("X-Custom-Header").contains("Custom Value"), "Custom header should be present") + assertTrue(response.headers("X-Query-Param").contains("param1=value1,param2=value2"), + "Query parameter should be present in the response") + } + +}