diff --git a/middleware/wrap_writer.go b/middleware/wrap_writer.go index cf5c44de..0727a0e1 100644 --- a/middleware/wrap_writer.go +++ b/middleware/wrap_writer.go @@ -61,6 +61,9 @@ type WrapResponseWriter interface { Tee(io.Writer) // Unwrap returns the original proxied target. Unwrap() http.ResponseWriter + // Discard causes all writes to the original ResponseWriter be discarded, + // instead writing only to the tee'd writer if it's set. + Discard() } // basicWriter wraps a http.ResponseWriter that implements the minimal @@ -71,6 +74,7 @@ type basicWriter struct { code int bytes int tee io.Writer + discard bool } func (b *basicWriter) WriteHeader(code int) { @@ -81,15 +85,19 @@ func (b *basicWriter) WriteHeader(code int) { } } -func (b *basicWriter) Write(buf []byte) (int, error) { +func (b *basicWriter) Write(buf []byte) (n int, err error) { b.maybeWriteHeader() - n, err := b.ResponseWriter.Write(buf) - if b.tee != nil { - _, err2 := b.tee.Write(buf[:n]) - // Prefer errors generated by the proxied writer. - if err == nil { - err = err2 + if !b.discard { + n, err = b.ResponseWriter.Write(buf) + if b.tee != nil { + _, err2 := b.tee.Write(buf[:n]) + // Prefer errors generated by the proxied writer. + if err == nil { + err = err2 + } } + } else if b.tee != nil { + n, err = b.tee.Write(buf) } b.bytes += n return n, err @@ -117,6 +125,10 @@ func (b *basicWriter) Unwrap() http.ResponseWriter { return b.ResponseWriter } +func (b *basicWriter) Discard() { + b.discard = true +} + // flushWriter ... type flushWriter struct { basicWriter diff --git a/middleware/wrap_writer_test.go b/middleware/wrap_writer_test.go index 2c442ada..824ccf42 100644 --- a/middleware/wrap_writer_test.go +++ b/middleware/wrap_writer_test.go @@ -1,6 +1,7 @@ package middleware import ( + "bytes" "net/http/httptest" "testing" ) @@ -22,3 +23,45 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { t.Fatal("want Flush to have set wroteHeader=true") } } + +func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) { + original := httptest.NewRecorder() + wrap := &basicWriter{ResponseWriter: original} + + var buf bytes.Buffer + wrap.Tee(&buf) + + _, err := wrap.Write([]byte("hello world")) + assertNoError(t, err) + + assertEqual(t, []byte("hello world"), original.Body.Bytes()) + assertEqual(t, []byte("hello world"), buf.Bytes()) +} + +func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) { + t.Run("With Tee", func(t *testing.T) { + original := httptest.NewRecorder() + wrap := &basicWriter{ResponseWriter: original} + + var buf bytes.Buffer + wrap.Tee(&buf) + wrap.Discard() + + _, err := wrap.Write([]byte("hello world")) + assertNoError(t, err) + + assertEqual(t, 0, original.Body.Len()) + assertEqual(t, []byte("hello world"), buf.Bytes()) + }) + + t.Run("Without Tee", func(t *testing.T) { + original := httptest.NewRecorder() + wrap := &basicWriter{ResponseWriter: original} + wrap.Discard() + + _, err := wrap.Write([]byte("hello world")) + assertNoError(t, err) + + assertEqual(t, 0, original.Body.Len()) + }) +}