Skip to content

Commit 037dec4

Browse files
committed
CSRF: support older token-based CSRF protection handler that want to render token into template
1 parent 096ce41 commit 037dec4

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

middleware/csrf.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ import (
1313
"github.com/labstack/echo/v5"
1414
)
1515

16+
// CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site
17+
// header and the request is deemed safe.
18+
// It is a dummy token value that can be used to render CSRF token for form by handlers.
19+
const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_"
20+
1621
// CSRFConfig defines the config for CSRF middleware.
1722
type CSRFConfig struct {
1823
// Skipper defines a function to skip middleware.
@@ -277,6 +282,11 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error)
277282
}
278283

279284
if isSafe {
285+
// This helps handlers that support older token-based CSRF protection.
286+
// We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in
287+
// the future with this dummy token value it is OK. Although the request is safe, the template rendered by the
288+
// handler may need this value to render CSRF token for form.
289+
c.Set(config.ContextKey, CSRFUsingSecFetchSite)
280290
return true, nil
281291
}
282292
// we are here when request is state-changing and `cross-site` or `same-site`

middleware/csrf_test.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,14 @@ func TestCSRFWithConfig(t *testing.T) {
238238
expectEmptyBody bool
239239
expectMWError string
240240
expectCookieContains string
241+
expectTokenInContext string
241242
expectErr string
242243
}{
243244
{
244245
name: "ok, GET",
245246
whenMethod: http.MethodGet,
246247
expectCookieContains: "_csrf",
248+
expectTokenInContext: "TESTTOKEN",
247249
},
248250
{
249251
name: "ok, POST valid token",
@@ -253,6 +255,7 @@ func TestCSRFWithConfig(t *testing.T) {
253255
},
254256
whenMethod: http.MethodPost,
255257
expectCookieContains: "_csrf",
258+
expectTokenInContext: token,
256259
},
257260
{
258261
name: "nok, POST without token",
@@ -281,13 +284,23 @@ func TestCSRFWithConfig(t *testing.T) {
281284
},
282285
whenMethod: http.MethodGet,
283286
expectCookieContains: "_csrf",
287+
expectTokenInContext: "TESTTOKEN",
284288
},
285289
{
286290
name: "ok, unsafe method + SecFetchSite=same-origin passes",
287291
whenHeaders: map[string]string{
288292
echo.HeaderSecFetchSite: "same-origin",
289293
},
290-
whenMethod: http.MethodPost,
294+
whenMethod: http.MethodPost,
295+
expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
296+
},
297+
{
298+
name: "ok, safe method + SecFetchSite=same-origin passes",
299+
whenHeaders: map[string]string{
300+
echo.HeaderSecFetchSite: "same-origin",
301+
},
302+
whenMethod: http.MethodGet,
303+
expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
291304
},
292305
{
293306
name: "nok, unsafe method + SecFetchSite=same-cross blocked",
@@ -315,6 +328,12 @@ func TestCSRFWithConfig(t *testing.T) {
315328
if tc.givenConfig != nil {
316329
config = *tc.givenConfig
317330
}
331+
if config.Generator == nil {
332+
config.Generator = func() string {
333+
return "TESTTOKEN"
334+
}
335+
}
336+
318337
mw, err := config.ToMiddleware()
319338
if tc.expectMWError != "" {
320339
assert.EqualError(t, err, tc.expectMWError)
@@ -323,6 +342,8 @@ func TestCSRFWithConfig(t *testing.T) {
323342
assert.NoError(t, err)
324343

325344
h := mw(func(c *echo.Context) error {
345+
cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey))
346+
assert.Equal(t, tc.expectTokenInContext, cToken)
326347
return c.String(http.StatusOK, "test")
327348
})
328349

0 commit comments

Comments
 (0)