@@ -238,12 +238,14 @@ func TestCSRFWithConfig(t *testing.T) {
238238 expectEmptyBody bool
239239 expectMWError string
240240 expectCookieContains string
241+ expectTokenInContext string
241242 expectErr string
242243 }{
243244 {
244245 name : "ok, GET" ,
245246 whenMethod : http .MethodGet ,
246247 expectCookieContains : "_csrf" ,
248+ expectTokenInContext : "TESTTOKEN" ,
247249 },
248250 {
249251 name : "ok, POST valid token" ,
@@ -253,6 +255,7 @@ func TestCSRFWithConfig(t *testing.T) {
253255 },
254256 whenMethod : http .MethodPost ,
255257 expectCookieContains : "_csrf" ,
258+ expectTokenInContext : token ,
256259 },
257260 {
258261 name : "nok, POST without token" ,
@@ -281,13 +284,23 @@ func TestCSRFWithConfig(t *testing.T) {
281284 },
282285 whenMethod : http .MethodGet ,
283286 expectCookieContains : "_csrf" ,
287+ expectTokenInContext : "TESTTOKEN" ,
284288 },
285289 {
286290 name : "ok, unsafe method + SecFetchSite=same-origin passes" ,
287291 whenHeaders : map [string ]string {
288292 echo .HeaderSecFetchSite : "same-origin" ,
289293 },
290- whenMethod : http .MethodPost ,
294+ whenMethod : http .MethodPost ,
295+ expectTokenInContext : "_echo_csrf_using_sec_fetch_site_" ,
296+ },
297+ {
298+ name : "ok, safe method + SecFetchSite=same-origin passes" ,
299+ whenHeaders : map [string ]string {
300+ echo .HeaderSecFetchSite : "same-origin" ,
301+ },
302+ whenMethod : http .MethodGet ,
303+ expectTokenInContext : "_echo_csrf_using_sec_fetch_site_" ,
291304 },
292305 {
293306 name : "nok, unsafe method + SecFetchSite=same-cross blocked" ,
@@ -315,6 +328,12 @@ func TestCSRFWithConfig(t *testing.T) {
315328 if tc .givenConfig != nil {
316329 config = * tc .givenConfig
317330 }
331+ if config .Generator == nil {
332+ config .Generator = func () string {
333+ return "TESTTOKEN"
334+ }
335+ }
336+
318337 mw , err := config .ToMiddleware ()
319338 if tc .expectMWError != "" {
320339 assert .EqualError (t , err , tc .expectMWError )
@@ -323,6 +342,8 @@ func TestCSRFWithConfig(t *testing.T) {
323342 assert .NoError (t , err )
324343
325344 h := mw (func (c * echo.Context ) error {
345+ cToken := c .Get (cmp .Or (config .ContextKey , DefaultCSRFConfig .ContextKey ))
346+ assert .Equal (t , tc .expectTokenInContext , cToken )
326347 return c .String (http .StatusOK , "test" )
327348 })
328349
0 commit comments