Skip to content

Commit

Permalink
discard calls to WriteHeader too
Browse files Browse the repository at this point in the history
  • Loading branch information
patrislav committed Jun 28, 2024
1 parent eeac37c commit 91f4808
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
6 changes: 5 additions & 1 deletion middleware/wrap_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ type WrapResponseWriter interface {
Unwrap() http.ResponseWriter
// Discard causes all writes to the original ResponseWriter be discarded,
// instead writing only to the tee'd writer if it's set.
// The caller is responsible for calling WriteHeader and Write on the
// original ResponseWriter once the processing is done.
Discard()
}

Expand All @@ -82,7 +84,9 @@ func (b *basicWriter) WriteHeader(code int) {
if !b.wroteHeader {
b.code = code
b.wroteHeader = true
b.ResponseWriter.WriteHeader(code)
if !b.discard {
b.ResponseWriter.WriteHeader(code)
}
}
}

Expand Down
22 changes: 19 additions & 3 deletions middleware/wrap_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
Expand All @@ -25,7 +26,11 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
}

func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) {
original := httptest.NewRecorder()
// explicitly create the struct instead of NewRecorder to control the value of Code
original := &httptest.ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
Expand All @@ -34,14 +39,19 @@ func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) {
_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 200, original.Code)
assertEqual(t, []byte("hello world"), original.Body.Bytes())
assertEqual(t, []byte("hello world"), buf.Bytes())
assertEqual(t, 11, wrap.BytesWritten())
}

func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) {
t.Run("With Tee", func(t *testing.T) {
original := httptest.NewRecorder()
// explicitly create the struct instead of NewRecorder to control the value of Code
original := &httptest.ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
Expand All @@ -51,19 +61,25 @@ func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) {
_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly
assertEqual(t, 0, original.Body.Len())
assertEqual(t, []byte("hello world"), buf.Bytes())
assertEqual(t, 11, wrap.BytesWritten())
})

t.Run("Without Tee", func(t *testing.T) {
original := httptest.NewRecorder()
// explicitly create the struct instead of NewRecorder to control the value of Code
original := &httptest.ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
wrap := &basicWriter{ResponseWriter: original}
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly
assertEqual(t, 0, original.Body.Len())
assertEqual(t, 11, wrap.BytesWritten())
})
Expand Down

0 comments on commit 91f4808

Please sign in to comment.