From a1fc200e5b85a7737a9834ec28fb768fb7bde7bd Mon Sep 17 00:00:00 2001 From: Aman Sanghi Date: Mon, 12 Aug 2024 22:36:30 +0530 Subject: [PATCH] Changes based on PR comments --- node/config.go | 3 +++ node/node.go | 9 +++++++++ node/rpcstack.go | 3 ++- rpc/client_test.go | 6 +++--- rpc/websocket.go | 7 +++++-- rpc/websocket_test.go | 10 +++++----- 6 files changed, 27 insertions(+), 11 deletions(-) diff --git a/node/config.go b/node/config.go index e60db34c51..8ac4c955af 100644 --- a/node/config.go +++ b/node/config.go @@ -224,6 +224,9 @@ type Config struct { // HTTPBodyLimit is the maximum number of bytes allowed in the HTTP request body. HTTPBodyLimit int `toml:",omitempty"` + + // WSReadLimit is the maximum number of bytes allowed in the websocket request body. + WSReadLimit int64 `toml:",omitempty"` } // IPCEndpoint resolves an IPC endpoint based on a configured value, taking into diff --git a/node/node.go b/node/node.go index 1d392af6df..8c72f968a3 100644 --- a/node/node.go +++ b/node/node.go @@ -423,6 +423,12 @@ func (n *Node) startRPC() error { batchResponseSizeLimit: n.config.BatchResponseMaxSize, apiFilter: n.apiFilter, } + if n.config.HTTPBodyLimit != 0 { + rpcConfig.httpBodyLimit = n.config.HTTPBodyLimit + } + if n.config.WSReadLimit != 0 { + rpcConfig.wsReadLimit = n.config.WSReadLimit + } initHttp := func(server *httpServer, port int) error { if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { @@ -473,6 +479,9 @@ func (n *Node) startRPC() error { if n.config.HTTPBodyLimit != 0 { sharedConfig.httpBodyLimit = n.config.HTTPBodyLimit } + if n.config.WSReadLimit != 0 { + sharedConfig.wsReadLimit = n.config.WSReadLimit + } err := server.enableRPC(allAPIs, httpConfig{ CorsAllowedOrigins: DefaultAuthCors, Vhosts: n.config.AuthVirtualHosts, diff --git a/node/rpcstack.go b/node/rpcstack.go index d4e0831b46..329037ea1b 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -58,6 +58,7 @@ type rpcEndpointConfig struct { batchResponseSizeLimit int apiFilter map[string]bool httpBodyLimit int + wsReadLimit int64 } type rpcHandler struct { @@ -362,7 +363,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } h.wsConfig = config h.wsHandler.Store(&rpcHandler{ - Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret), + Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins, config.wsReadLimit), config.jwtSecret), server: srv, }) return nil diff --git a/rpc/client_test.go b/rpc/client_test.go index ac02ad33cf..7b1466dc67 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -584,7 +584,7 @@ func TestClientSubscriptionChannelClose(t *testing.T) { var ( srv = NewServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -745,7 +745,7 @@ func TestClientReconnect(t *testing.T) { if err != nil { t.Fatal("can't listen:", err) } - go http.Serve(l, srv.WebsocketHandler([]string{"*"})) + go http.Serve(l, srv.WebsocketHandler([]string{"*"}, 0)) return srv, l } @@ -811,7 +811,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, var hs *httptest.Server switch transport { case "ws": - hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) + hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, 0)) case "http": hs = httptest.NewUnstartedServer(srv) default: diff --git a/rpc/websocket.go b/rpc/websocket.go index 538e53a31b..5b79b9948a 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -47,7 +47,10 @@ var wsBufferPool = new(sync.Pool) // // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". -func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { +func (s *Server) WebsocketHandler(allowedOrigins []string, wsReadLimit int64) http.Handler { + if wsReadLimit == 0 { + wsReadLimit = wsDefaultReadLimit + } var upgrader = websocket.Upgrader{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, @@ -60,7 +63,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { log.Debug("WebSocket upgrade failed", "err", err) return } - codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) + codec := newWebsocketCodec(conn, r.Host, r.Header, wsReadLimit) s.ServeCodec(codec, 0) }) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 8d2bd9d802..95e1c61cde 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -53,7 +53,7 @@ func TestWebsocketOriginCheck(t *testing.T) { var ( srv = newTestServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -83,7 +83,7 @@ func TestWebsocketLargeCall(t *testing.T) { var ( srv = newTestServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -119,7 +119,7 @@ func TestWebsocketLargeRead(t *testing.T) { var ( srv = newTestServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -176,7 +176,7 @@ func TestWebsocketLargeRead(t *testing.T) { func TestWebsocketPeerInfo(t *testing.T) { var ( s = newTestServer() - ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) + ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}, 0)) tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") ) defer s.Stop() @@ -260,7 +260,7 @@ func TestClientWebsocketPing(t *testing.T) { func TestClientWebsocketLargeMessage(t *testing.T) { var ( srv = NewServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop()