Skip to content

Commit

Permalink
Merge pull request #359 from lesismal/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
lesismal authored Oct 26, 2023
2 parents 21f3f04 + d00806b commit 6cf29cb
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 28 deletions.
3 changes: 1 addition & 2 deletions conn_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,7 @@ func (c *Conn) resetRead() {
if !c.closed && c.isWAdded {
c.isWAdded = false
p := c.p
p.deleteEvent(c.fd)
p.addRead(c.fd)
p.resetRead(c.fd)
}
}

Expand Down
27 changes: 27 additions & 0 deletions nbhttp/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ func (c *ClientConn) closeWithErrorWithoutLock(err error) {
}
c.handlers = nil
if c.conn != nil {
nbc, ok := c.conn.(*nbio.Conn)
if !ok {
if tlsConn, ok2 := c.conn.(*tls.Conn); ok2 {
nbc, ok = tlsConn.Conn().(*nbio.Conn)
}
}
if ok {
key, _ := conn2Array(nbc)
c.Engine.mux.Lock()
delete(c.Engine.dialerConns, key)
c.Engine.mux.Unlock()
}
c.conn.Close()
c.conn = nil
}
Expand Down Expand Up @@ -247,6 +259,11 @@ func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn
return
}

key, _ := conn2Array(nbc)
engine.mux.Lock()
engine.dialerConns[key] = struct{}{}
engine.mux.Unlock()

c.conn = nbc
processor := NewClientProcessor(c, c.onResponse)
parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
Expand Down Expand Up @@ -288,6 +305,16 @@ func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn
return
}

key, err := conn2Array(nbc)
if err != nil {
logging.Error("add dialer conn failed: %v", err)
c.closeWithErrorWithoutLock(err)
return
}
engine.mux.Lock()
engine.dialerConns[key] = struct{}{}
engine.mux.Unlock()

isNonblock := true
tlsConn.ResetConn(nbc, isNonblock)

Expand Down
24 changes: 21 additions & 3 deletions nbhttp/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,9 @@ type Engine struct {
_onClose func(c net.Conn, err error)
_onStop func()

mux sync.Mutex
conns map[connValue]struct{}
mux sync.Mutex
conns map[connValue]struct{}
dialerConns map[connValue]struct{}

// tlsBuffers [][]byte
// getTLSBuffer func(c *nbio.Conn) []byte
Expand Down Expand Up @@ -258,6 +259,11 @@ func (e *Engine) Online() int {
return len(e.conns)
}

// DialerOnline .
func (e *Engine) DialerOnline() int {
return len(e.dialerConns)
}

func (e *Engine) closeAllConns() {
e.mux.Lock()
defer e.mux.Unlock()
Expand All @@ -266,6 +272,11 @@ func (e *Engine) closeAllConns() {
c.Close()
}
}
for key := range e.dialerConns {
if c, err := array2Conn(key); err == nil {
c.Close()
}
}
}

func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net.Conn, *tls.Config, func()), decrease func()) {
Expand Down Expand Up @@ -459,7 +470,7 @@ func (e *Engine) Shutdown(ctx context.Context) error {
logging.Info("NBIO[%v] shutdown timeout", e.Engine.Name)
return ctx.Err()
case <-ticker.C:
if len(e.conns) == 0 {
if len(e.conns)+len(e.dialerConns) == 0 {
goto Exit
}
}
Expand Down Expand Up @@ -770,6 +781,9 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f
return
}
parser.Read(buf[:n])
if parser.hijacked {
return
}
}
}

Expand Down Expand Up @@ -819,6 +833,9 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
logging.Debug("parser.Read failed: %v", err)
return
}
// if parser.hijacked {
// return
// }
}
if nread == 0 {
break
Expand Down Expand Up @@ -965,6 +982,7 @@ func NewEngine(conf Config) *Engine {
_onStop: func() {},
CheckUtf8: utf8.Valid,
conns: map[connValue]struct{}{},
dialerConns: map[connValue]struct{}{},
ExecuteClient: clientExecutor,

emptyRequest: (&http.Request{}).WithContext(baseCtx),
Expand Down
1 change: 1 addition & 0 deletions nbhttp/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Parser struct {

state int8
isClient bool
hijacked bool

readLimit int

Expand Down
4 changes: 3 additions & 1 deletion nbhttp/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,11 @@ func (p *ServerProcessor) OnComplete(parser *Parser) {
}

func (p *ServerProcessor) flushResponse(res *Response) {
hijacked := res.hijacked
p.parser.hijacked = hijacked
if p.conn != nil {
req := res.request
if !res.hijacked {
if !hijacked {
res.eoncodeHead()
if err := res.flushTrailer(p.conn); err != nil {
p.conn.Close()
Expand Down
74 changes: 65 additions & 9 deletions nbhttp/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package websocket

import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -54,7 +55,8 @@ type Conn struct {

mux sync.Mutex

session interface{}
chSessionInited chan struct{}
session interface{}

sendQueue [][]byte
sendQueueSize int
Expand All @@ -67,6 +69,8 @@ type Conn struct {
remoteCompressionEnabled bool
enableWriteCompression bool
isBlockingMod bool
isReadingByParser bool
isInReadingLoop bool
expectingFragments bool
compress bool
opcode MessageType
Expand Down Expand Up @@ -99,13 +103,17 @@ func (c *Conn) Close() error {
if c.Conn == nil {
return nil
}
if c.IsAsyncWrite() {
c.Engine.AfterFunc(c.BlockingModAsyncCloseDelay, func() { c.Conn.Close() })
return nil
}
return c.Conn.Close()
}

// CloseWithError .
func (c *Conn) CloseWithError(err error) error {
c.SetCloseError(err)
return c.Conn.Close()
return c.Close()
}

// SetCloseError .
Expand Down Expand Up @@ -237,11 +245,7 @@ func (c *Conn) handleWsMessage(opcode MessageType, data []byte) {
}

ErrExit:
if c.IsAsyncWrite() {
c.Engine.AfterFunc(time.Second, func() { c.Conn.Close() })
} else {
c.Conn.Close()
}
c.Close()
}

func (c *Conn) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2, res3 bool, err error) {
Expand Down Expand Up @@ -566,8 +570,41 @@ func (c *Conn) Session() interface{} {
return c.session
}

// SessionWithLock returns user session with lock, returns as soon as the session has been seted.
func (c *Conn) SessionWithLock() interface{} {
c.mux.Lock()
ch := c.chSessionInited
c.mux.Unlock()
if ch != nil {
<-ch
}
return c.session
}

// SessionWithContext returns user session, returns as soon as the session has been seted or
// waits until the context is done.
func (c *Conn) SessionWithContext(ctx context.Context) interface{} {
c.mux.Lock()
ch := c.chSessionInited
c.mux.Unlock()
if ch != nil {
select {
case <-ch:
case <-ctx.Done():
}

}
return c.session
}

// SetSession sets user session.
func (c *Conn) SetSession(session interface{}) {
c.mux.Lock()
if c.chSessionInited != nil {
close(c.chSessionInited)
c.chSessionInited = nil
}
c.mux.Unlock()
c.session = session
}

Expand All @@ -584,6 +621,11 @@ func (w *writeBuffer) Close() error {
// CloseAndClean .
func (c *Conn) CloseAndClean(err error) {
c.mux.Lock()
if c.chSessionInited != nil {
close(c.chSessionInited)
c.chSessionInited = nil
}

closed := c.closed
c.closed = true
if closed {
Expand Down Expand Up @@ -780,12 +822,26 @@ func NewConn(u *Upgrader, c net.Conn, subprotocol string, remoteCompressionEnabl
if asyncWrite {
wsc.sendQueue = make([][]byte, u.BlockingModSendQueueInitSize)[:0]
wsc.sendQueueSize = u.BlockingModSendQueueMaxSize
if wsc.BlockingModAsyncCloseDelay <= 0 {
wsc.BlockingModAsyncCloseDelay = DefaultBlockingModAsyncCloseDelay
}
}
return wsc
}

// BlockingModReadLoop .
func (c *Conn) BlockingModReadLoop(bufSize int) {
// HandleRead .
func (c *Conn) HandleRead(bufSize int) {
if !c.isReadingByParser {
return
}
c.mux.Lock()
reading := c.isInReadingLoop
c.isInReadingLoop = true
c.mux.Unlock()
if reading {
return
}

var (
n int
err error
Expand Down
36 changes: 28 additions & 8 deletions nbhttp/websocket/upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ var (

DefaultBlockingModSendQueueMaxSize = 0

DefaultBlockingModAsyncCloseDelay = time.Second / 10

// DefaultEngine will be set to a Upgrader.Engine to handle details such as buffers.
DefaultEngine = nbhttp.NewEngine(nbhttp.Config{
ReleaseWebsocketPayload: true,
})
)

type commonFields struct {
Engine *nbhttp.Engine
KeepaliveTime time.Duration
MessageLengthLimit int
Engine *nbhttp.Engine
KeepaliveTime time.Duration
MessageLengthLimit int
BlockingModAsyncCloseDelay time.Duration

enableCompression bool
compressionLevel int
Expand Down Expand Up @@ -92,8 +95,9 @@ type Upgrader struct {
func NewUpgrader() *Upgrader {
u := &Upgrader{
commonFields: commonFields{
Engine: DefaultEngine,
compressionLevel: defaultCompressionLevel,
Engine: DefaultEngine,
compressionLevel: defaultCompressionLevel,
BlockingModAsyncCloseDelay: DefaultBlockingModAsyncCloseDelay,
},
BlockingModReadBufferSize: DefaultBlockingReadBufferSize,
BlockingModAsyncWrite: DefaultBlockingModAsyncWrite,
Expand Down Expand Up @@ -370,13 +374,22 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return nil, err
}

wsc.isReadingByParser = (parser == nil)

if wsc.openHandler != nil {
wsc.openHandler(wsc)
}

if wsc.isBlockingMod {
if parser == nil {
go wsc.BlockingModReadLoop(u.BlockingModReadBufferSize)
if wsc.isBlockingMod && wsc.isReadingByParser {
var handleRead = true
if len(args) > 1 {
var b bool
b, ok = args[1].(bool)
handleRead = ok && b
}
if handleRead {
wsc.chSessionInited = make(chan struct{})
go wsc.HandleRead(u.BlockingModReadBufferSize)
}
}

Expand All @@ -388,6 +401,13 @@ func (u *Upgrader) UpgradeAndTransferConnToPoller(w http.ResponseWriter, r *http
return u.Upgrade(w, r, responseHeader, trasferConn)
}

func (u *Upgrader) UpgradeWithoutHandlingReadForConnFromSTDServer(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
// handle std server's conn, no need transfer conn to nbio Engine
const trasferConn = false
const handleRead = false
return u.Upgrade(w, r, responseHeader, trasferConn, handleRead)
}

func (u *Upgrader) commCheck(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (string, string, bool, error) {
if !headerContains(r.Header, "Connection", "upgrade") {
return "", "", false, u.returnError(w, r, http.StatusBadRequest, ErrUpgradeTokenNotFound)
Expand Down
10 changes: 5 additions & 5 deletions poller_epoll.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ func (p *poller) readWriteLoop() {
default:
c := p.getConn(fd)
if c != nil {
if ev.Events&epollEventsError != 0 {
c.closeWithError(io.EOF)
continue
}

if ev.Events&epollEventsWrite != 0 {
c.flush()
}
Expand Down Expand Up @@ -227,6 +222,11 @@ func (p *poller) readWriteLoop() {
p.g.onRead(c)
}
}

if ev.Events&epollEventsError != 0 {
c.closeWithError(io.EOF)
continue
}
} else {
syscall.Close(fd)
// p.deleteEvent(fd)
Expand Down
Loading

0 comments on commit 6cf29cb

Please sign in to comment.