Skip to content

Commit

Permalink
Track the amount of websocket connections
Browse files Browse the repository at this point in the history
  • Loading branch information
Exca-DK committed Dec 4, 2023
1 parent e1125b6 commit 9535a4b
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 20 deletions.
21 changes: 20 additions & 1 deletion jsonrpc/event_listener.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package jsonrpc

import "time"
import (
"net"
"time"
)

type NewRequestListener interface {
OnNewRequest(method string)
OnNewConnection(conn net.Conn)
OnDisconnect(conn net.Conn)
}

type EventListener interface {
Expand All @@ -14,6 +19,8 @@ type EventListener interface {

type SelectiveListener struct {
OnNewRequestCb func(method string)
OnNewConnectionCb func(conn net.Conn)
OnDisconnectCb func(conn net.Conn)
OnRequestHandledCb func(method string, took time.Duration)
OnRequestFailedCb func(method string, data any)
}
Expand All @@ -24,6 +31,18 @@ func (l *SelectiveListener) OnNewRequest(method string) {
}
}

func (l *SelectiveListener) OnNewConnection(conn net.Conn) {
if l.OnNewConnectionCb != nil {
l.OnNewConnectionCb(conn)
}
}

func (l *SelectiveListener) OnDisconnect(conn net.Conn) {
if l.OnDisconnectCb != nil {
l.OnDisconnectCb(conn)
}
}

func (l *SelectiveListener) OnRequestHandled(method string, took time.Duration) {
if l.OnRequestHandledCb != nil {
l.OnRequestHandledCb(method, took)
Expand Down
29 changes: 28 additions & 1 deletion jsonrpc/event_listener_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package jsonrpc_test

import "time"
import (
"net"
"time"
)

type CountingEventListener struct {
OnNewRequestLogs []string
Expand All @@ -12,6 +15,22 @@ type CountingEventListener struct {
method string
data any
}
OnConnectionCalls map[net.Conn]int
}

func NewCountingEventListener() *CountingEventListener {
return &CountingEventListener{
OnNewRequestLogs: []string{},
OnRequestHandledCalls: []struct {
method string
took time.Duration
}{},
OnRequestFailedCalls: []struct {
method string
data any
}{},
OnConnectionCalls: map[net.Conn]int{},
}
}

func (l *CountingEventListener) OnNewRequest(method string) {
Expand All @@ -37,3 +56,11 @@ func (l *CountingEventListener) OnRequestFailed(method string, data any) {
data: data,
})
}

func (l *CountingEventListener) OnNewConnection(conn net.Conn) {
l.OnConnectionCalls[conn]++
}

func (l *CountingEventListener) OnDisconnect(conn net.Conn) {
l.OnConnectionCalls[conn]--
}
105 changes: 93 additions & 12 deletions jsonrpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package jsonrpc
import (
"context"
"io"
"math"
"net"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -51,10 +53,11 @@ func (ws *Websocket) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws.log.Errorw("Failed to upgrade connection", "err", err)
return
}

// TODO include connection information, such as the remote address, in the logs.

wsc := newWebsocketConn(r.Context(), conn, ws.connParams)
wsc := newWebsocketConn(r.Context(), conn, ws.connParams, r.RemoteAddr)
ws.listener.OnNewConnection(wsc)
defer ws.listener.OnDisconnect(wsc)

for {
_, wsc.r, err = wsc.conn.Reader(wsc.ctx)
Expand Down Expand Up @@ -106,36 +109,114 @@ func DefaultWebsocketConnParams() *WebsocketConnParams {
}
}

type websocketAddr string

func (addr websocketAddr) Network() string { return "websocket" }

func (addr websocketAddr) String() string { return addr.Network() + "/" + string(addr) }

type websocketConn struct {
r io.Reader
conn *websocket.Conn
ctx context.Context
params *WebsocketConnParams

// write specific fields
writeTimer *time.Timer
writeCancel func()
writeCtx context.Context

// read specific fields
readTimer *time.Timer
readCancel func()
readCtx context.Context

remoteAddr string
}

func newWebsocketConn(ctx context.Context, conn *websocket.Conn, params *WebsocketConnParams) *websocketConn {
func newWebsocketConn(ctx context.Context, conn *websocket.Conn, params *WebsocketConnParams, remote string) *websocketConn {
conn.SetReadLimit(params.ReadLimit)
return &websocketConn{
conn: conn,
ctx: ctx,
params: params,
wc := &websocketConn{
conn: conn,
ctx: ctx,
params: params,
remoteAddr: remote,
}

wc.readCtx, wc.readCancel = context.WithCancel(ctx)
wc.readTimer = time.AfterFunc(math.MaxInt64, wc.readCancel)
if !wc.readTimer.Stop() {
<-wc.readTimer.C
}

wc.writeCtx, wc.writeCancel = context.WithCancel(ctx)
wc.writeTimer = time.AfterFunc(math.MaxInt64, wc.writeCancel)
if !wc.writeTimer.Stop() {
<-wc.writeTimer.C
}

return wc
}

// Close stops both reading and writing, sends a StatusNormalClosure code and releases associated resources.
func (wsc *websocketConn) Close() error {
wsc.readTimer.Stop()
wsc.readCancel()
wsc.writeTimer.Stop()
wsc.writeCancel()
return wsc.conn.Close(websocket.StatusNormalClosure, "")
}

func (wsc *websocketConn) RemoteAddr() net.Addr {
return websocketAddr(wsc.remoteAddr)
}

func (wsc *websocketConn) LocalAddr() net.Addr {
return websocketAddr("unknown")
}

func (wsc *websocketConn) Read(p []byte) (int, error) {
return wsc.r.Read(p)
}

// Write returns the number of bytes of p sent, not including the header.
func (wsc *websocketConn) Write(p []byte) (int, error) {
// TODO write responses concurrently. Unlike gorilla/websocket, nhooyr.io/websocket
// permits concurrent writes.

writeCtx, writeCancel := context.WithTimeout(wsc.ctx, wsc.params.WriteDuration)
defer writeCancel()
if err := wsc.SetWriteDeadline(time.Now().Add(wsc.params.WriteDuration)); err != nil {
return 0, err
}
// Use MessageText since JSON is a text format.
if err := wsc.conn.Write(writeCtx, websocket.MessageText, p); err != nil {
if err := wsc.conn.Write(wsc.writeCtx, websocket.MessageText, p); err != nil {
return 0, err
}
return len(p), nil
}

func (wsc *websocketConn) SetDeadline(t time.Time) error {
wsc.SetReadDeadline(t) //nolint:errcheck
wsc.SetWriteDeadline(t) //nolint:errcheck
return nil
}

func (wsc *websocketConn) SetReadDeadline(t time.Time) error {
return wsc.setDeadline(t, wsc.readTimer)
}

func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
return wsc.setDeadline(t, wsc.writeTimer)
}

func (wsc *websocketConn) setDeadline(t time.Time, timer *time.Timer) error {
// Support net.Conn non-timeout: "A zero value for t means I/O operations will not time out."
if t.IsZero() {
timer.Stop()
} else {
// Don't panic on non-positive timer
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
timer.Reset(dur)
}
return nil
}
18 changes: 13 additions & 5 deletions jsonrpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import (

// The caller is responsible for closing the connection.
func testConnection(t *testing.T, ctx context.Context, method jsonrpc.Method, listener jsonrpc.EventListener) *websocket.Conn {
rpc := jsonrpc.NewServer(1, utils.NewNopZapLogger()).WithListener(listener)
rpc := jsonrpc.NewServer(1, utils.NewNopZapLogger())
require.NoError(t, rpc.RegisterMethods(method))

// Server
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, utils.NewNopZapLogger()))
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, utils.NewNopZapLogger()).WithListener(listener))

// Client
conn, resp, err := websocket.Dial(ctx, srv.URL, nil) //nolint:bodyclose // websocket package closes resp.Body for us.
Expand All @@ -41,8 +41,8 @@ func TestHandler(t *testing.T) {
return msg, nil
},
}
listener := CountingEventListener{}
conn := testConnection(t, ctx, method, &listener)
listener := NewCountingEventListener()
conn := testConnection(t, ctx, method, listener)

msg := `{"jsonrpc" : "2.0", "method" : "test_echo", "params" : [ "abc123" ], "id" : 1}`
err := conn.Write(context.Background(), websocket.MessageText, []byte(msg))
Expand All @@ -53,8 +53,16 @@ func TestHandler(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, want, string(got))
assert.Len(t, listener.OnNewRequestLogs, 1)
assert.Len(t, listener.OnConnectionCalls, 1)
for _, v := range listener.OnConnectionCalls {
assert.Equal(t, 1, v)
}

require.NoError(t, conn.Close(websocket.StatusNormalClosure, ""))
assert.Len(t, listener.OnConnectionCalls, 1)
for _, v := range listener.OnConnectionCalls {
assert.Equal(t, 0, v)
}
}

func TestSendFromHandler(t *testing.T) {
Expand All @@ -76,7 +84,7 @@ func TestSendFromHandler(t *testing.T) {
return 0, nil
},
}
conn := testConnection(t, ctx, method, &CountingEventListener{})
conn := testConnection(t, ctx, method, NewCountingEventListener())

req := `{"jsonrpc" : "2.0", "method" : "test", "params":[], "id" : 1}`
err := conn.Write(context.Background(), websocket.MessageText, []byte(req))
Expand Down
17 changes: 16 additions & 1 deletion node/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package node

import (
"math"
"net"
"time"

"github.com/NethermindEth/juno/blockchain"
Expand Down Expand Up @@ -96,12 +97,26 @@ func makeWSMetrics() jsonrpc.NewRequestListener {
Subsystem: "ws",
Name: "requests",
})
prometheus.MustRegister(reqCounter)
connGauge := prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "rpc",
Subsystem: "ws",
Name: "connections",
})
prometheus.MustRegister(
reqCounter,
connGauge,
)

return &jsonrpc.SelectiveListener{
OnNewRequestCb: func(method string) {
reqCounter.Inc()
},
OnNewConnectionCb: func(conn net.Conn) {
connGauge.Inc()
},
OnDisconnectCb: func(conn net.Conn) {
connGauge.Dec()
},
}
}

Expand Down

0 comments on commit 9535a4b

Please sign in to comment.