diff --git a/nbhttp/websocket/conn.go b/nbhttp/websocket/conn.go index ef8ebf5a..1f2a97e5 100644 --- a/nbhttp/websocket/conn.go +++ b/nbhttp/websocket/conn.go @@ -263,14 +263,16 @@ ErrExit: c.Close() } -func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, bool, error) { +func (c *Conn) nextFrame() (int, MessageType, []byte, bool, bool, bool, error) { var ( opcode MessageType body []byte ok, fin, res1, res2, res3 bool err error + data = c.bytesCached l = int64(len(data)) headLen = int64(2) + total int64 ) if l >= 2 { opcode = MessageType(data[0] & 0xF) @@ -297,12 +299,12 @@ func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, } if c.isMessageTooLarge(len(c.message) + int(bodyLen)) { - return data, 0, nil, false, false, false, ErrMessageTooLarge + return 0, 0, nil, false, false, false, ErrMessageTooLarge } if (bodyLen > maxControlFramePayloadSize) && ((opcode == PingMessage) || (opcode == PongMessage) || (opcode == CloseMessage)) { - return data, 0, nil, false, false, false, ErrControlMessageTooBig + return 0, 0, nil, false, false, false, ErrControlMessageTooBig } if bodyLen >= 0 { @@ -310,7 +312,7 @@ func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, if masked { headLen += 4 } - total := headLen + bodyLen + total = headLen + bodyLen if l >= total { body = data[headLen:total] if masked { @@ -318,17 +320,20 @@ func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, } ok = true - data = data[total:l] err = c.validFrame(opcode, fin, res1, res2, res3, c.expectingFragments) } } } - return data, opcode, body, ok, fin, res1, err + return int(total), opcode, body, ok, fin, res1, err } // Read . func (c *Conn) Parse(data []byte) error { + if len(data) == 0 { + return nil + } + c.mux.Lock() if c.closed { c.mux.Unlock() @@ -341,12 +346,12 @@ func (c *Conn) Parse(data []byte) error { return nbhttp.ErrTooLong } - var appended = false var allocator = c.Engine.BodyAllocator - if len(c.bytesCached) > 0 { + if len(c.bytesCached) == 0 { + c.bytesCached = allocator.Malloc(len(data)) + copy(c.bytesCached, data) + } else { c.bytesCached = allocator.Append(c.bytesCached, data...) - data = c.bytesCached - appended = true } c.mux.Unlock() @@ -357,6 +362,7 @@ func (c *Conn) Parse(data []byte) error { var protocolMessage []byte var opcode MessageType var ok, fin, compress bool + var totalFrameSize int releaseBuf := func() { if len(frame) > 0 { @@ -378,7 +384,7 @@ func (c *Conn) Parse(data []byte) error { err = net.ErrClosed return } - data, opcode, body, ok, fin, compress, err = c.nextFrame(data) + totalFrameSize, opcode, body, ok, fin, compress, err = c.nextFrame() if err != nil { return } @@ -386,23 +392,22 @@ func (c *Conn) Parse(data []byte) error { return } + bl := len(body) switch opcode { case FragmentMessage, TextMessage, BinaryMessage: if c.msgType == 0 { c.msgType = opcode c.compress = compress } - bl := len(body) - if c.dataFrameHandler != nil { - if bl > 0 { - frame = allocator.Malloc(bl) - copy(frame, body) - } - if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) { - c.Conn.Close() - err = ErrInvalidUtf8 - return - } + if bl > 0 && c.dataFrameHandler != nil { + frame = allocator.Malloc(bl) + copy(frame, body) + // if compressed, should check utf8 after decompressed the whole message. + // if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) { + // c.Conn.Close() + // err = ErrInvalidUtf8 + // return + // } } if c.messageHandler != nil { if bl > 0 { @@ -416,36 +421,6 @@ func (c *Conn) Parse(data []byte) error { if fin { message = c.message c.message = nil - } - } - case PingMessage, PongMessage, CloseMessage: - if len(body) > 0 { - protocolMessage = allocator.Malloc(len(body)) - copy(protocolMessage, body) - } - default: - err = ErrInvalidFragmentMessage - return - } - }() - - if err != nil { - releaseBuf() - if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) { - c.WriteClose(1009, err.Error()) - } - return err - } - - if ok { - switch opcode { - case FragmentMessage, TextMessage, BinaryMessage: - if c.dataFrameHandler != nil { - c.handleDataFrame(c.msgType, fin, frame) - frame = nil - } - if fin { - if c.messageHandler != nil { if c.compress { var b []byte var rc io.ReadCloser @@ -460,64 +435,63 @@ func (c *Conn) Parse(data []byte) error { rc.Close() if err != nil { releaseBuf() - return err + return } } - c.handleMessage(c.msgType, message) - message = nil + c.compress = false + c.expectingFragments = false + c.msgType = 0 + } else { + c.expectingFragments = true } - c.compress = false - c.expectingFragments = false - c.msgType = 0 - } else { - c.expectingFragments = true } case PingMessage, PongMessage, CloseMessage: - c.handleProtocolMessage(opcode, protocolMessage) - protocolMessage = nil + if bl > 0 { + protocolMessage = allocator.Malloc(len(body)) + copy(protocolMessage, body) + } default: - releaseBuf() - return ErrInvalidFragmentMessage + err = ErrInvalidFragmentMessage + return } - } else { - goto Exit - } - if len(data) == 0 { - goto Exit - } - } + l := len(c.bytesCached) + if l == totalFrameSize { + c.Engine.BodyAllocator.Free(c.bytesCached) + c.bytesCached = nil + } else { + copy(c.bytesCached, c.bytesCached[totalFrameSize:l]) + c.bytesCached = c.bytesCached[:l-totalFrameSize] + } + }() -Exit: - releaseBuf() - c.mux.Lock() - defer c.mux.Unlock() - if c.closed { - return net.ErrClosed - } - // The data bytes were not all consumed, need to recache the current bytes left: - if len(data) > 0 { - // The data bytes were appended to the tail of the previous chaced data: - if appended { - // If data bytes were consumed, move data to the head of the cached bytes, - // else the data is same as the cached bytes, nothing to do. - if len(data) < len(c.bytesCached) { - c.bytesCached = c.bytesCached[:len(data)] - copy(c.bytesCached, data) + if err != nil { + if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) { + c.WriteClose(1009, err.Error()) } - } else { // When using the origin data passed to this `Parse` func: - c.bytesCached = allocator.Malloc(len(data)) - copy(c.bytesCached, data) + return err } - } else { // The data bytes were all consumed: - // If the data bytes were cached, release the bytes and clear the cache. - if len(c.bytesCached) > 0 { - allocator.Free(c.bytesCached) - c.bytesCached = nil + + if message != nil { + c.handleMessage(c.msgType, message) + message = nil + } + if frame != nil { + c.handleDataFrame(c.msgType, fin, frame) + frame = nil + } + if protocolMessage != nil { + c.handleProtocolMessage(opcode, protocolMessage) + protocolMessage = nil + } + + // need more data + if !ok { + break } } - return err + return nil } // OnMessage .