Skip to content

Commit

Permalink
middleware: add Discard method to WrapResponseWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
patrislav committed Jun 27, 2024
1 parent 7957c0d commit 5a95c0a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 7 deletions.
26 changes: 19 additions & 7 deletions middleware/wrap_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ type WrapResponseWriter interface {
Tee(io.Writer)
// Unwrap returns the original proxied target.
Unwrap() http.ResponseWriter
// Discard causes all writes to the original ResponseWriter be discarded,
// instead writing only to the tee'd writer if it's set.
Discard()
}

// basicWriter wraps a http.ResponseWriter that implements the minimal
Expand All @@ -71,6 +74,7 @@ type basicWriter struct {
code int
bytes int
tee io.Writer
discard bool
}

func (b *basicWriter) WriteHeader(code int) {
Expand All @@ -81,15 +85,19 @@ func (b *basicWriter) WriteHeader(code int) {
}
}

func (b *basicWriter) Write(buf []byte) (int, error) {
func (b *basicWriter) Write(buf []byte) (n int, err error) {
b.maybeWriteHeader()
n, err := b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
if !b.discard {
n, err = b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
}
}
} else if b.tee != nil {
n, err = b.tee.Write(buf)
}
b.bytes += n
return n, err
Expand Down Expand Up @@ -117,6 +125,10 @@ func (b *basicWriter) Unwrap() http.ResponseWriter {
return b.ResponseWriter
}

func (b *basicWriter) Discard() {
b.discard = true
}

// flushWriter ...
type flushWriter struct {
basicWriter
Expand Down
43 changes: 43 additions & 0 deletions middleware/wrap_writer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"bytes"
"net/http/httptest"
"testing"
)
Expand All @@ -22,3 +23,45 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
t.Fatal("want Flush to have set wroteHeader=true")
}
}

func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) {
original := httptest.NewRecorder()
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
wrap.Tee(&buf)

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, []byte("hello world"), original.Body.Bytes())
assertEqual(t, []byte("hello world"), buf.Bytes())
}

func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) {
t.Run("With Tee", func(t *testing.T) {
original := httptest.NewRecorder()
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
wrap.Tee(&buf)
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Body.Len())
assertEqual(t, []byte("hello world"), buf.Bytes())
})

t.Run("Without Tee", func(t *testing.T) {
original := httptest.NewRecorder()
wrap := &basicWriter{ResponseWriter: original}
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Body.Len())
})
}

0 comments on commit 5a95c0a

Please sign in to comment.