diff --git a/wsutil/reader.go b/wsutil/reader.go index ff2e5b6..f2710af 100644 --- a/wsutil/reader.go +++ b/wsutil/reader.go @@ -1,6 +1,7 @@ package wsutil import ( + "encoding/binary" "errors" "io" "io/ioutil" @@ -56,10 +57,12 @@ type Reader struct { OnContinuation FrameHandlerFunc OnIntermediate FrameHandlerFunc - opCode ws.OpCode // Used to store message op code on fragmentation. - frame io.Reader // Used to as frame reader. - raw io.LimitedReader // Used to discard frames without cipher. - utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true. + opCode ws.OpCode // Used to store message op code on fragmentation. + frame io.Reader // Used to as frame reader. + raw io.LimitedReader // Used to discard frames without cipher. + utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true. + tmp [ws.MaxHeaderSize - 2]byte // Used for reading headers. + cr *CipherReader // Used by NextFrame() to unmask frame payload. } // NewReader creates new frame reader that reads from r keeping given state to @@ -165,7 +168,7 @@ func (r *Reader) Discard() (err error) { // Note that next NextFrame() call must be done after receiving or discarding // all current message bytes. func (r *Reader) NextFrame() (hdr ws.Header, err error) { - hdr, err = ws.ReadHeader(r.Source) + hdr, err = r.readHeader(r.Source) if err == io.EOF && r.fragmented() { // If we are in fragmented state EOF means that is was totally // unexpected. @@ -196,7 +199,12 @@ func (r *Reader) NextFrame() (hdr ws.Header, err error) { frame := io.Reader(&r.raw) if hdr.Masked { - frame = NewCipherReader(frame, hdr.Mask) + if r.cr == nil { + r.cr = NewCipherReader(frame, hdr.Mask) + } else { + r.cr.Reset(frame, hdr.Mask) + } + frame = r.cr } for _, x := range r.Extensions { @@ -261,6 +269,82 @@ func (r *Reader) reset() { r.opCode = 0 } +// readHeader reads a frame header from in. +func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) { + // Make slice of bytes with capacity 12 that could hold any header. + // + // The maximum header size is 14, but due to the 2 hop reads, + // after first hop that reads first 2 constant bytes, we could reuse 2 bytes. + // So 14 - 2 = 12. + bts := r.tmp[:2] + + // Prepare to hold first 2 bytes to choose size of next read. + _, err = io.ReadFull(in, bts) + if err != nil { + return h, err + } + const bit0 = 0x80 + + h.Fin = bts[0]&bit0 != 0 + h.Rsv = (bts[0] & 0x70) >> 4 + h.OpCode = ws.OpCode(bts[0] & 0x0f) + + var extra int + + if bts[1]&bit0 != 0 { + h.Masked = true + extra += 4 + } + + length := bts[1] & 0x7f + switch { + case length < 126: + h.Length = int64(length) + + case length == 126: + extra += 2 + + case length == 127: + extra += 8 + + default: + err = ws.ErrHeaderLengthUnexpected + return h, err + } + + if extra == 0 { + return h, err + } + + // Increase len of bts to extra bytes need to read. + // Overwrite first 2 bytes that was read before. + bts = bts[:extra] + _, err = io.ReadFull(in, bts) + if err != nil { + return h, err + } + + switch { + case length == 126: + h.Length = int64(binary.BigEndian.Uint16(bts[:2])) + bts = bts[2:] + + case length == 127: + if bts[0]&0x80 != 0 { + err = ws.ErrHeaderLengthMSB + return h, err + } + h.Length = int64(binary.BigEndian.Uint64(bts[:8])) + bts = bts[8:] + } + + if h.Masked { + copy(h.Mask[:], bts) + } + + return h, nil +} + // NextReader prepares next message read from r. It returns header that // describes the message and io.Reader to read message's payload. It returns // non-nil error when it is not possible to read message's initial frame.