Skip to content

Commit 1f3ef29

Browse files
authored
Context: json should not send status code before serialization is complete (#2877)
Context: json should not send status code before serialization is complete
1 parent 489646e commit 1f3ef29

File tree

4 files changed

+119
-8
lines changed

4 files changed

+119
-8
lines changed

context.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ func (c *Context) Response() http.ResponseWriter {
139139
return c.response
140140
}
141141

142-
// SetResponse sets `*http.ResponseWriter`. Some middleware require that given ResponseWriter implements following
143-
// method `Unwrap() http.ResponseWriter` which eventually should return echo.Response instance.
142+
// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following
143+
// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance.
144144
func (c *Context) SetResponse(r http.ResponseWriter) {
145145
c.response = r
146146
}
@@ -415,6 +415,15 @@ func (c *Context) Render(code int, name string, data any) (err error) {
415415
if c.echo.Renderer == nil {
416416
return ErrRendererNotRegistered
417417
}
418+
// as Renderer.Render can fail, and in that case we need to delay sending status code to the client until
419+
// (global) error handler decides the correct status code for the error to be sent to the client, so we need to write
420+
// the rendered template to the buffer first.
421+
//
422+
// html.Template.ExecuteTemplate() documentations writes:
423+
// > If an error occurs executing the template or writing its output,
424+
// > execution stops, but partial results may already have been written to
425+
// > the output writer.
426+
418427
buf := new(bytes.Buffer)
419428
if err = c.echo.Renderer.Render(c, buf, name, data); err != nil {
420429
return
@@ -454,7 +463,18 @@ func (c *Context) jsonPBlob(code int, callback string, i any) (err error) {
454463

455464
func (c *Context) json(code int, i any, indent string) error {
456465
c.writeContentType(MIMEApplicationJSON)
457-
c.response.WriteHeader(code)
466+
467+
// as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until
468+
// (global) error handler decides correct status code for the error to be sent to the client.
469+
// For that we need to use writer that can store the proposed status code until the first Write is called.
470+
if r, err := UnwrapResponse(c.response); err == nil {
471+
r.Status = code
472+
} else {
473+
resp := c.Response()
474+
c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
475+
defer c.SetResponse(resp)
476+
}
477+
458478
return c.echo.JSONSerializer.Serialize(c, i, indent)
459479
}
460480

context_test.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"io"
1313
"io/fs"
1414
"log/slog"
15+
"math"
1516
"mime/multipart"
1617
"net/http"
1718
"net/http/httptest"
@@ -138,6 +139,24 @@ func TestContextRenderTemplate(t *testing.T) {
138139
}
139140
}
140141

142+
func TestContextRenderTemplateError(t *testing.T) {
143+
// we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do
144+
e := New()
145+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
146+
rec := httptest.NewRecorder()
147+
c := e.NewContext(req, rec)
148+
149+
tmpl := &Template{
150+
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
151+
}
152+
c.Echo().Renderer = tmpl
153+
err := c.Render(http.StatusOK, "not_existing", "Jon Snow")
154+
155+
assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`)
156+
assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
157+
assert.Empty(t, rec.Body.String()) // body must not be sent to the client
158+
}
159+
141160
func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
142161
e := New()
143162
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
@@ -173,10 +192,9 @@ func TestContextStream(t *testing.T) {
173192
}
174193

175194
func TestContextHTML(t *testing.T) {
176-
e := New()
177195
rec := httptest.NewRecorder()
178196
req := httptest.NewRequest(http.MethodGet, "/", nil)
179-
c := e.NewContext(req, rec)
197+
c := NewContext(req, rec)
180198

181199
err := c.HTML(http.StatusOK, "Hi, Jon Snow")
182200
if assert.NoError(t, err) {
@@ -187,10 +205,9 @@ func TestContextHTML(t *testing.T) {
187205
}
188206

189207
func TestContextHTMLBlob(t *testing.T) {
190-
e := New()
191208
rec := httptest.NewRecorder()
192209
req := httptest.NewRequest(http.MethodGet, "/", nil)
193-
c := e.NewContext(req, rec)
210+
c := NewContext(req, rec)
194211

195212
err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow"))
196213
if assert.NoError(t, err) {
@@ -222,6 +239,24 @@ func TestContextJSONErrorsOut(t *testing.T) {
222239

223240
err := c.JSON(http.StatusOK, make(chan bool))
224241
assert.EqualError(t, err, "json: unsupported type: chan bool")
242+
243+
assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
244+
assert.Empty(t, rec.Body.String()) // body must not be sent to the client
245+
}
246+
247+
func TestContextJSONWithNotEchoResponse(t *testing.T) {
248+
e := New()
249+
rec := httptest.NewRecorder()
250+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
251+
c := e.NewContext(req, rec)
252+
253+
c.SetResponse(rec)
254+
255+
err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()})
256+
assert.EqualError(t, err, "json: unsupported value: NaN")
257+
258+
assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
259+
assert.Empty(t, rec.Body.String()) // body must not be sent to the client
225260
}
226261

227262
func TestContextJSONPretty(t *testing.T) {

response.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,47 @@ func UnwrapResponse(rw http.ResponseWriter) (*Response, error) {
126126
rw = t.Unwrap()
127127
continue
128128
default:
129-
return nil, errors.New("ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface")
129+
return nil, errors.New("ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface or unwrap to *echo.Response")
130130
}
131131
}
132132
}
133+
134+
// delayedStatusWriter is a wrapper around http.ResponseWriter that delays writing the status code until first Write is called.
135+
// This allows (global) error handler to decide correct status code to be sent to the client.
136+
type delayedStatusWriter struct {
137+
http.ResponseWriter
138+
commited bool
139+
status int
140+
}
141+
142+
func (w *delayedStatusWriter) WriteHeader(statusCode int) {
143+
// in case something else writes status code explicitly before us we need mark response commited
144+
w.commited = true
145+
w.ResponseWriter.WriteHeader(statusCode)
146+
}
147+
148+
func (w *delayedStatusWriter) Write(data []byte) (int, error) {
149+
if !w.commited {
150+
w.commited = true
151+
if w.status == 0 {
152+
w.status = http.StatusOK
153+
}
154+
w.ResponseWriter.WriteHeader(w.status)
155+
}
156+
return w.ResponseWriter.Write(data)
157+
}
158+
159+
func (w *delayedStatusWriter) Flush() {
160+
err := http.NewResponseController(w.ResponseWriter).Flush()
161+
if err != nil && errors.Is(err, http.ErrNotSupported) {
162+
panic(errors.New("response writer flushing is not supported"))
163+
}
164+
}
165+
166+
func (w *delayedStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
167+
return http.NewResponseController(w.ResponseWriter).Hijack()
168+
}
169+
170+
func (w *delayedStatusWriter) Unwrap() http.ResponseWriter {
171+
return w.ResponseWriter
172+
}

response_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,19 @@ func TestResponse_FlushPanics(t *testing.T) {
115115
res.Flush()
116116
})
117117
}
118+
119+
func TestResponse_UnwrapResponse(t *testing.T) {
120+
orgRes := NewResponse(httptest.NewRecorder(), nil)
121+
res, err := UnwrapResponse(orgRes)
122+
123+
assert.NotNil(t, res)
124+
assert.NoError(t, err)
125+
}
126+
127+
func TestResponse_UnwrapResponse_error(t *testing.T) {
128+
rw := new(testResponseWriter)
129+
res, err := UnwrapResponse(rw)
130+
131+
assert.Nil(t, res)
132+
assert.EqualError(t, err, "ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface or unwrap to *echo.Response")
133+
}

0 commit comments

Comments
 (0)