diff --git a/config.go b/config.go new file mode 100644 index 0000000..3e837cd --- /dev/null +++ b/config.go @@ -0,0 +1,53 @@ +package compress + +import ( + "compress/gzip" + "regexp" + "sort" +) + +var DefaultAllowedTypes = []*regexp.Regexp{ + regexp.MustCompile(`^text/`), + regexp.MustCompile(`^application/json`), + regexp.MustCompile(`^application/javascript`), + regexp.MustCompile(`\+(xml|json)$`), + regexp.MustCompile(`^image/svg`), +} + +type config struct { + supportedEncoding []string + + encoders map[string]*encoder + allowedType []*regexp.Regexp + minSize uint64 + silent bool +} + +func newConfig(options ...Option) *config { + c := &config{minSize: 4 * 1024, allowedType: DefaultAllowedTypes, encoders: map[string]*encoder{}} + WithGzip(100, gzip.DefaultCompression)(c) + for _, o := range options { + o(c) + } + c.populateSupportedEncoding() + return c +} + +func (c *config) populateSupportedEncoding() { + type encoderWithName struct { + *encoder + name string + } + + encList := make([]encoderWithName, 0, len(c.encoders)) + for name, enc := range c.encoders { + encList = append(encList, encoderWithName{enc, name}) + } + sort.Slice(encList, func(i, j int) bool { + return encList[i].priority < encList[j].priority + }) + c.supportedEncoding = make([]string, 0, len(encList)) + for _, enc := range encList { + c.supportedEncoding = append(c.supportedEncoding, enc.name) + } +} diff --git a/middleware_test.go b/config_test.go similarity index 63% rename from middleware_test.go rename to config_test.go index 77fdbe6..c6da12b 100644 --- a/middleware_test.go +++ b/config_test.go @@ -3,7 +3,6 @@ package compress import ( "context" "io" - "net/http" "testing" "gotest.tools/assert" @@ -11,7 +10,6 @@ import ( func Test_populateSupportedEncoding(t *testing.T) { var dummyFactory EncoderFactory = func(ctx context.Context, w io.Writer) (io.WriteCloser, error) { return nil, nil } - var h http.Handler - m := newMiddleware(h, WithEncoder("a", 1, dummyFactory), WithEncoder("b", 2, dummyFactory), WithEncoder("c", 3, dummyFactory), WihtoutEncoder("gzip")) + m := newConfig(WithEncoder("a", 1, dummyFactory), WithEncoder("b", 2, dummyFactory), WithEncoder("c", 3, dummyFactory), WihtoutEncoder("gzip")) assert.DeepEqual(t, m.supportedEncoding, []string{"a", "b", "c"}) } diff --git a/go.mod b/go.mod index 0e6266a..3509db0 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,11 @@ go 1.17 require ( github.com/kevinpollet/nego v0.0.0-20211010160919-a65cd48cee43 + +) + +// test +require ( gotest.tools v2.2.0+incompatible ) diff --git a/handler.go b/handler.go deleted file mode 100644 index a827554..0000000 --- a/handler.go +++ /dev/null @@ -1,9 +0,0 @@ -package compress - -import ( - "net/http" -) - -func Handler(h http.Handler, options ...Option) http.Handler { - return newMiddleware(h, options...) -} diff --git a/middleware.go b/middleware.go index d942c42..9fb14bc 100644 --- a/middleware.go +++ b/middleware.go @@ -1,12 +1,9 @@ package compress import ( - "compress/gzip" "context" "io" "net/http" - "regexp" - "sort" "github.com/kevinpollet/nego" ) @@ -18,72 +15,33 @@ type encoder struct { factory EncoderFactory } -type middleware struct { - http.Handler - - supportedEncoding []string - - encoders map[string]*encoder - allowedType []*regexp.Regexp - minSize uint64 - silent bool -} - -var DefaultAllowedTypes = []*regexp.Regexp{ - regexp.MustCompile(`^text/`), - regexp.MustCompile(`^application/json`), - regexp.MustCompile(`^application/javascript`), - regexp.MustCompile(`\+(xml|json)$`), - regexp.MustCompile(`^image/svg`), -} - -func newMiddleware(h http.Handler, options ...Option) *middleware { - m := &middleware{Handler: h, minSize: 4 * 1024, allowedType: DefaultAllowedTypes, encoders: map[string]*encoder{}} - WithGzip(100, gzip.DefaultCompression)(m) - for _, o := range options { - o(m) +type Middleware func(http.Handler) http.Handler + +func New(options ...Option) Middleware { + c := newConfig(options...) + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Vary", "Accept-Encoding") + + encoding := nego.NegotiateContentEncoding(r, c.supportedEncoding...) + enc, ok := c.encoders[encoding] + if ok { + mw := &responseWriter{ + ResponseWriter: w, + ctx: r.Context(), + factory: enc.factory, + encoding: encoding, + c: c, + status: http.StatusOK, + } + defer mw.flush() + w = mw + } + h.ServeHTTP(w, r) + }) } - m.populateSupportedEncoding() - return m } -func (m *middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // for cache validation - w.Header().Add("Vary", "Accept-Encoding") - - encoding := nego.NegotiateContentEncoding(r, m.supportedEncoding...) - enc, ok := m.encoders[encoding] - if ok { - mw := &middlewareWriter{ - ResponseWriter: w, - ctx: r.Context(), - factory: enc.factory, - encoding: encoding, - m: m, - status: http.StatusOK, - } - defer mw.flush() - w = mw - } - - m.Handler.ServeHTTP(w, r) -} - -func (m *middleware) populateSupportedEncoding() { - type encoderWithName struct { - *encoder - name string - } - - encList := make([]encoderWithName, 0, len(m.encoders)) - for name, enc := range m.encoders { - encList = append(encList, encoderWithName{enc, name}) - } - sort.Slice(encList, func(i, j int) bool { - return encList[i].priority < encList[j].priority - }) - m.supportedEncoding = make([]string, 0, len(encList)) - for _, enc := range encList { - m.supportedEncoding = append(m.supportedEncoding, enc.name) - } +func Handler(h http.Handler, options ...Option) http.Handler { + return New(options...)(h) } diff --git a/option.go b/option.go index bda4865..2d8f5ef 100644 --- a/option.go +++ b/option.go @@ -7,17 +7,17 @@ import ( "regexp" ) -type Option func(m *middleware) +type Option func(*config) func WithEncoder(encoding string, priotity int, factory EncoderFactory) Option { - return func(m *middleware) { + return func(m *config) { m.encoders[encoding] = &encoder{priority: priotity, factory: factory} } } func WihtoutEncoder(encoding string) Option { - return func(m *middleware) { - delete(m.encoders, encoding) + return func(c *config) { + delete(c.encoders, encoding) } } @@ -28,19 +28,19 @@ func WithGzip(priority, level int) Option { } func WithAllowedTypes(list []*regexp.Regexp) Option { - return func(m *middleware) { - m.allowedType = list + return func(c *config) { + c.allowedType = list } } func WithMinSize(minSize uint64) Option { - return func(m *middleware) { - m.minSize = minSize + return func(c *config) { + c.minSize = minSize } } func WithSilent() Option { - return func(m *middleware) { - m.silent = true + return func(c *config) { + c.silent = true } } diff --git a/writer.go b/writer.go index d2f1254..b622931 100644 --- a/writer.go +++ b/writer.go @@ -10,9 +10,9 @@ import ( "sync" ) -type middlewareWriter struct { +type responseWriter struct { http.ResponseWriter - m *middleware + c *config ctx context.Context factory EncoderFactory enc io.WriteCloser @@ -39,100 +39,100 @@ func matchRegexes(str string, res []*regexp.Regexp) bool { return false } -func (mw *middlewareWriter) WriteHeader(status int) { - mw.mutex.Lock() - defer mw.mutex.Unlock() - mw.status = status - mw.headerSent = true +func (rw *responseWriter) WriteHeader(status int) { + rw.mutex.Lock() + defer rw.mutex.Unlock() + rw.status = status + rw.headerSent = true - cenc := mw.Header().Get("content-encoding") - ctype := mw.Header().Get("content-type") + cenc := rw.Header().Get("content-encoding") + ctype := rw.Header().Get("content-type") // if content encoding already defined // or content type is not defined // or content type is not in allowed list // => just forward the body - if cenc != "" || ctype == "" || !matchRegexes(ctype, mw.m.allowedType) { - mw.dontEncode = true - mw.ResponseWriter.WriteHeader(status) + if cenc != "" || ctype == "" || !matchRegexes(ctype, rw.c.allowedType) { + rw.dontEncode = true + rw.ResponseWriter.WriteHeader(status) return } // if content length is big enough, start encoding now - if clen := mw.Header().Get("content-length"); clen != "" { + if clen := rw.Header().Get("content-length"); clen != "" { len, err := strconv.ParseUint(clen, 10, 64) if err == nil { - if len >= mw.m.minSize { - mw.startEncoding() + if len >= rw.c.minSize { + rw.startEncoding() } else { - mw.dontEncode = true - mw.ResponseWriter.WriteHeader(status) + rw.dontEncode = true + rw.ResponseWriter.WriteHeader(status) } return } } // the content length is unknown, buffer the response until it exceeds minSize - mw.buff = make([]byte, mw.m.minSize) + rw.buff = make([]byte, rw.c.minSize) } -func (mw *middlewareWriter) startEncoding() bool { +func (rw *responseWriter) startEncoding() bool { var err error - mw.enc, err = mw.factory(mw.ctx, mw.ResponseWriter) + rw.enc, err = rw.factory(rw.ctx, rw.ResponseWriter) if err != nil { - if !mw.m.silent { - fmt.Printf("Can not create encoder %s: %v\n", mw.encoding, err) + if !rw.c.silent { + fmt.Printf("Can not create encoder %s: %v\n", rw.encoding, err) } - mw.flush() - mw.dontEncode = true + rw.flush() + rw.dontEncode = true return false } - mw.shouldEncode = true - mw.Header().Del("content-length") - mw.Header().Add("content-encoding", mw.encoding) - mw.ResponseWriter.WriteHeader(mw.status) + rw.shouldEncode = true + rw.Header().Del("content-length") + rw.Header().Add("content-encoding", rw.encoding) + rw.ResponseWriter.WriteHeader(rw.status) return true } -func (mw *middlewareWriter) Write(chunk []byte) (int, error) { - if !mw.headerSent { - mw.WriteHeader(mw.status) +func (rw *responseWriter) Write(chunk []byte) (int, error) { + if !rw.headerSent { + rw.WriteHeader(rw.status) } - mw.mutex.Lock() - defer mw.mutex.Unlock() + rw.mutex.Lock() + defer rw.mutex.Unlock() - if mw.shouldEncode { - return mw.enc.Write(chunk) + if rw.shouldEncode { + return rw.enc.Write(chunk) } - if mw.dontEncode { - return mw.ResponseWriter.Write(chunk) + if rw.dontEncode { + return rw.ResponseWriter.Write(chunk) } - newBufLen := mw.buffLen + uint64(len(chunk)) - if newBufLen > mw.m.minSize { - if mw.startEncoding() { - if mw.buffLen > 0 { - _, err := mw.enc.Write(mw.buff[0:mw.buffLen]) + newBufLen := rw.buffLen + uint64(len(chunk)) + if newBufLen > rw.c.minSize { + if rw.startEncoding() { + if rw.buffLen > 0 { + _, err := rw.enc.Write(rw.buff[0:rw.buffLen]) if err != nil { return 0, err } } - mw.buff = nil - mw.buffLen = 0 - return mw.enc.Write(chunk) + rw.buff = nil + rw.buffLen = 0 + return rw.enc.Write(chunk) } - return mw.ResponseWriter.Write(chunk) + return rw.ResponseWriter.Write(chunk) } - n := copy(mw.buff[mw.buffLen:], chunk) - mw.buffLen = newBufLen + n := copy(rw.buff[rw.buffLen:], chunk) + rw.buffLen = newBufLen return n, nil } -func (mw *middlewareWriter) flush() { - if mw.enc != nil { - mw.enc.Close() +func (rw *responseWriter) flush() { + if rw.enc != nil { + rw.enc.Close() } - if mw.buff != nil && mw.buffLen > 0 { - mw.ResponseWriter.Write(mw.buff[0:mw.buffLen]) + if rw.buff != nil && rw.buffLen > 0 { + rw.ResponseWriter.Write(rw.buff[0:rw.buffLen]) } }