diff --git a/middleware/wrap_writer.go b/middleware/wrap_writer.go index 8040d3ae..bf270881 100644 --- a/middleware/wrap_writer.go +++ b/middleware/wrap_writer.go @@ -64,6 +64,8 @@ type WrapResponseWriter interface { Unwrap() http.ResponseWriter // Discard causes all writes to the original ResponseWriter be discarded, // instead writing only to the tee'd writer if it's set. + // The caller is responsible for calling WriteHeader and Write on the + // original ResponseWriter once the processing is done. Discard() } @@ -82,7 +84,9 @@ func (b *basicWriter) WriteHeader(code int) { if !b.wroteHeader { b.code = code b.wroteHeader = true - b.ResponseWriter.WriteHeader(code) + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } } } diff --git a/middleware/wrap_writer_test.go b/middleware/wrap_writer_test.go index 068225db..7e8f6ab2 100644 --- a/middleware/wrap_writer_test.go +++ b/middleware/wrap_writer_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "net/http" "net/http/httptest" "testing" ) @@ -25,7 +26,11 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { } func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) { - original := httptest.NewRecorder() + // explicitly create the struct instead of NewRecorder to control the value of Code + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } wrap := &basicWriter{ResponseWriter: original} var buf bytes.Buffer @@ -34,6 +39,7 @@ func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) { _, err := wrap.Write([]byte("hello world")) assertNoError(t, err) + assertEqual(t, 200, original.Code) assertEqual(t, []byte("hello world"), original.Body.Bytes()) assertEqual(t, []byte("hello world"), buf.Bytes()) assertEqual(t, 11, wrap.BytesWritten()) @@ -41,7 +47,11 @@ func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) { func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) { t.Run("With Tee", func(t *testing.T) { - original := httptest.NewRecorder() + // explicitly create the struct instead of NewRecorder to control the value of Code + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } wrap := &basicWriter{ResponseWriter: original} var buf bytes.Buffer @@ -51,19 +61,25 @@ func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) { _, err := wrap.Write([]byte("hello world")) assertNoError(t, err) + assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly assertEqual(t, 0, original.Body.Len()) assertEqual(t, []byte("hello world"), buf.Bytes()) assertEqual(t, 11, wrap.BytesWritten()) }) t.Run("Without Tee", func(t *testing.T) { - original := httptest.NewRecorder() + // explicitly create the struct instead of NewRecorder to control the value of Code + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } wrap := &basicWriter{ResponseWriter: original} wrap.Discard() _, err := wrap.Write([]byte("hello world")) assertNoError(t, err) + assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly assertEqual(t, 0, original.Body.Len()) assertEqual(t, 11, wrap.BytesWritten()) })