Skip to content

Commit

Permalink
fix(middleware): Close created writer in the compressor middleware (#919
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Neurostep authored Jun 7, 2024
1 parent ef31c0b commit f10dc4a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) er
}

func (cw *compressResponseWriter) Close() error {
if c, ok := cw.writer().(io.WriteCloser); ok {
if c, ok := cw.w.(io.WriteCloser); ok {
return c.Close()
}
return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
Expand Down
46 changes: 44 additions & 2 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ func TestCompressor(t *testing.T) {
return w
})

if len(compressor.encoders) != 1 {
t.Errorf("nop encoder should be stored in the encoders map")
var sideEffect int
compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer {
return newSideEffectWriter(w, &sideEffect)
})

if len(compressor.encoders) != 2 {
t.Errorf("nop and test encoders should be stored in the encoders map")
}

r.Use(compressor.Handler)
Expand All @@ -47,6 +52,11 @@ func TestCompressor(t *testing.T) {
w.Write([]byte("textstring"))
})

r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.Write([]byte("textstring"))
})

ts := httptest.NewServer(r)
defer ts.Close()

Expand Down Expand Up @@ -93,6 +103,12 @@ func TestCompressor(t *testing.T) {
acceptedEncodings: []string{"nop, gzip, deflate"},
expectedEncoding: "nop",
},
{
name: "test is used and side effect is cleared after close",
path: "/getimage",
acceptedEncodings: []string{"test"},
expectedEncoding: "",
},
}

for _, tc := range tests {
Expand All @@ -107,7 +123,10 @@ func TestCompressor(t *testing.T) {
}

})
}

if sideEffect > 1 {
t.Errorf("side effect should be cleared after close")
}
}

Expand Down Expand Up @@ -217,3 +236,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string {

return string(respBody)
}

type (
sideEffectWriter struct {
w io.Writer
s *int
}
)

func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer {
*sideEffect = *sideEffect + 1

return &sideEffectWriter{w: w, s: sideEffect}
}

func (w *sideEffectWriter) Write(p []byte) (n int, err error) {
return w.w.Write(p)
}

func (w *sideEffectWriter) Close() error {
*w.s = *w.s - 1

return nil
}

0 comments on commit f10dc4a

Please sign in to comment.