diff --git a/api/auth/secp256k1_test.go b/api/auth/secp256k1_test.go index 8ff48e50..880bb87a 100644 --- a/api/auth/secp256k1_test.go +++ b/api/auth/secp256k1_test.go @@ -151,7 +151,9 @@ func TestProtocolHandshake(t *testing.T) { } clientURI := testServer.URL - conn, err := protocol.NewConn(clientURI, clientAuth) + conn, err := protocol.NewConn(clientURI, &protocol.ConnOptions{ + Authenticator: clientAuth, + }) if err != nil { t.Fatal(err) } diff --git a/api/protocol/protocol.go b/api/protocol/protocol.go index 7bc4763b..660ca16f 100644 --- a/api/protocol/protocol.go +++ b/api/protocol/protocol.go @@ -362,7 +362,7 @@ type Conn struct { sync.RWMutex serverURL string - auth Authenticator + opts ConnOptions msgID uint64 wsc *websocket.Conn @@ -372,8 +372,21 @@ type Conn struct { calls map[string]chan *readResult } +// ConnOptions are options available for a Conn. +type ConnOptions struct { + // ReadLimit is the maximum number of bytes to read from the connection. + // Defaults to defaultConnReadLimit. + ReadLimit int64 + + // Authenticator is the connection authenticator. + Authenticator Authenticator +} + +// defaultConnReadLimit is the default connection read limit. +const defaultConnReadLimit = 512 * (1 << 10) // 512 KiB + // NewConn returns a client side connection object. -func NewConn(urlStr string, authenticator Authenticator) (*Conn, error) { +func NewConn(urlStr string, opts *ConnOptions) (*Conn, error) { log.Tracef("NewConn: %v", urlStr) defer log.Tracef("NewConn exit: %v", urlStr) @@ -382,9 +395,16 @@ func NewConn(urlStr string, authenticator Authenticator) (*Conn, error) { return nil, err } + if opts == nil { + opts = new(ConnOptions) + } + if opts.ReadLimit <= 0 { + opts.ReadLimit = defaultConnReadLimit + } + ac := &Conn{ serverURL: u.String(), - auth: authenticator, + opts: *opts, calls: make(map[string]chan *readResult), msgID: 1, } @@ -414,7 +434,7 @@ func (ac *Conn) Connect(ctx context.Context) error { if err != nil { return fmt.Errorf("dial server: %w", err) } - conn.SetReadLimit(512 * 1024) // XXX - default is 32KB + conn.SetReadLimit(ac.opts.ReadLimit) defer func() { if ac.wsc == nil { conn.Close(websocket.StatusNormalClosure, "") @@ -424,14 +444,14 @@ func (ac *Conn) Connect(ctx context.Context) error { handshakeCtx, cancel := context.WithTimeout(ctx, WSHandshakeTimeout) defer cancel() - if ac.auth != nil { + if auth := ac.opts.Authenticator; auth != nil { log.Tracef("Connect: handshaking with %v", ac.serverURL) - if err := ac.auth.HandshakeClient(handshakeCtx, NewWSConn(conn)); err != nil { + if err := auth.HandshakeClient(handshakeCtx, NewWSConn(conn)); err != nil { return HandshakeError(fmt.Sprintf("failed to handshake with server: %v", err)) } } - // done as an API message and it should be done at the protocol + // done as an API message, and it should be done at the protocol // level instead... var msg Message if err := NewWSConn(conn).ReadJSON(connectCtx, &msg); err != nil { @@ -456,7 +476,7 @@ func (ac *Conn) Connect(ctx context.Context) error { return nil } -// wsConn returns the underlying webscket connection. +// wsConn returns the underlying websocket connection. func (ac *Conn) wsConn() *websocket.Conn { ac.RLock() defer ac.RUnlock() diff --git a/service/bss/bss.go b/service/bss/bss.go index c7bf2e48..99652f30 100644 --- a/service/bss/bss.go +++ b/service/bss/bss.go @@ -629,7 +629,9 @@ func (s *Server) connectBFG(ctx context.Context) error { log.Tracef("connectBFG") defer log.Tracef("connectBFG exit") - conn, err := protocol.NewConn(s.cfg.BFGURL, nil) + conn, err := protocol.NewConn(s.cfg.BFGURL, &protocol.ConnOptions{ + ReadLimit: 2 * (1 << 20), // 2 MiB + }) if err != nil { return err } diff --git a/service/popm/popm.go b/service/popm/popm.go index 6c9db075..8043cd73 100644 --- a/service/popm/popm.go +++ b/service/popm/popm.go @@ -808,7 +808,9 @@ func (m *Miner) connectBFG(pctx context.Context) error { return err } - conn, err = protocol.NewConn(m.cfg.BFGWSURL, authenticator) + conn, err = protocol.NewConn(m.cfg.BFGWSURL, &protocol.ConnOptions{ + Authenticator: authenticator, + }) if err != nil { return err }