From f905dabc1f70757cc7217cb646f2880109dcf108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20Andr=C3=A9s=20Virviescas=20Santana?= Date: Tue, 21 Jan 2020 16:05:09 +0100 Subject: [PATCH] Align features with julienschmidt/httprouter (#18) * Align features with julienschmidt/httprouter * Add extra functionalities * Tests refactor * Add go.1.13 support --- .travis.yml | 1 + examples/auth/auth.go | 6 +- go.mod | 2 +- go.sum | 4 +- path.go | 189 +++---- path_test.go | 82 ++- router.go | 362 +++++++------ router_test.go | 1129 ++++++++++++++++------------------------- tolower.go | 9 - tolower_go112.go | 105 ---- tree.go | 665 ++++++++++++------------ tree_test.go | 96 ++-- 12 files changed, 1208 insertions(+), 1442 deletions(-) delete mode 100644 tolower.go delete mode 100644 tolower_go112.go diff --git a/.travis.yml b/.travis.yml index 66ade2e..122c3b7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ go: - 1.10.x - 1.11.x - 1.12.x + - 1.13.x - tip os: diff --git a/examples/auth/auth.go b/examples/auth/auth.go index 7001c63..e5b12dc 100644 --- a/examples/auth/auth.go +++ b/examples/auth/auth.go @@ -6,7 +6,7 @@ import ( "log" "strings" - "github.com/elithrar/simple-scrypt" + scrypt "github.com/elithrar/simple-scrypt" "github.com/fasthttp/router" "github.com/valyala/fasthttp" ) @@ -43,7 +43,7 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { // BasicAuth is the basic auth handler func BasicAuth(h fasthttp.RequestHandler, requiredUser string, requiredPasswordHash []byte) fasthttp.RequestHandler { - return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { + return func(ctx *fasthttp.RequestCtx) { // Get the Basic Authentication credentials user, password, hasAuth := basicAuth(ctx) @@ -77,7 +77,7 @@ func BasicAuth(h fasthttp.RequestHandler, requiredUser string, requiredPasswordH // Request Basic Authentication otherwise ctx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized) ctx.Response.Header.Set("WWW-Authenticate", "Basic realm=Restricted") - }) + } } // Index is the index handler diff --git a/go.mod b/go.mod index 9d0397e..4c116d6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/fasthttp/router go 1.13 require ( - github.com/savsgio/gotils v0.0.0-20190925070755-524bc4f47500 + github.com/savsgio/gotils v0.0.0-20200117113501-90175b0fbe3f github.com/valyala/bytebufferpool v1.0.0 github.com/valyala/fasthttp v1.8.0 ) diff --git a/go.sum b/go.sum index 873a236..4d9e96c 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/klauspost/compress v1.8.2 h1:Bx0qjetmNjdFXASH02NSAREKpiaDwkO1DRZ3dV2K github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/cpuid v1.2.1 h1:vJi+O/nMdFt0vqm8NZBI6wzALWdA2X+egi0ogNyrC/w= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= -github.com/savsgio/gotils v0.0.0-20190925070755-524bc4f47500 h1:9Pi10H7E8E79/x2HSe1FmMGd7BJ1WAqDKzwjpv+ojFg= -github.com/savsgio/gotils v0.0.0-20190925070755-524bc4f47500/go.mod h1:lHhJedqxCoHN+zMtwGNTXWmF0u9Jt363FYRhV6g0CdY= +github.com/savsgio/gotils v0.0.0-20200117113501-90175b0fbe3f h1:PgA+Olipyj258EIEYnpFFONrrCcAIWNUNoFhUfMqAGY= +github.com/savsgio/gotils v0.0.0-20200117113501-90175b0fbe3f/go.mod h1:lHhJedqxCoHN+zMtwGNTXWmF0u9Jt363FYRhV6g0CdY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.8.0 h1:actnGGBYtGQmxVaZxyZpp57Vcc2NhcO7mMN0IMwCC0w= diff --git a/path.go b/path.go index b9f1e71..fe5a1ec 100644 --- a/path.go +++ b/path.go @@ -7,50 +7,14 @@ package router import ( "strings" - "sync" "github.com/savsgio/gotils" ) -type cleanPathBuffer struct { - n int - r int - w int - trailing bool - buf []byte -} - -var cleanPathBufferPool = sync.Pool{ - New: func() interface{} { - return &cleanPathBuffer{ - n: 0, - r: 0, - w: 1, - trailing: false, - buf: make([]byte, 140), - } - }, -} - -func (cpb *cleanPathBuffer) reset() { - cpb.n = 0 - cpb.r = 0 - cpb.w = 1 - cpb.trailing = false - // cpb.buf = cpb.buf[:0] -} - -func acquireCleanPathBuffer() *cleanPathBuffer { - return cleanPathBufferPool.Get().(*cleanPathBuffer) -} - -func releaseCleanPathBuffer(cpb *cleanPathBuffer) { - cpb.reset() - cleanPathBufferPool.Put(cpb) -} +const stackBufSize = 128 // CleanPath is the URL version of path.Clean, it returns a canonical URL path -// for path, eliminating . and .. elements. +// for p, eliminating . and .. elements. // // The following rules are applied iteratively until no further processing can // be done: @@ -62,86 +26,133 @@ func releaseCleanPathBuffer(cpb *cleanPathBuffer) { // that is, replace "/.." by "/" at the beginning of a path. // // If the result of this process is an empty string, "/" is returned -func CleanPath(path string) string { - cpb := acquireCleanPathBuffer() - cleanPathWithBuffer(cpb, path) +func CleanPath(p string) string { + // Turn empty string into "/" + if p == "" { + return "/" + } - s := string(cpb.buf) - releaseCleanPathBuffer(cpb) + // Reasonably sized buffer on stack to avoid allocations in the common case. + // If a larger buffer is required, it gets allocated dynamically. + buf := make([]byte, 0, stackBufSize) - return s -} + n := len(p) -func cleanPathWithBuffer(cpb *cleanPathBuffer, path string) { - // Turn empty string into "/" - if path == "" { - cpb.buf = append(cpb.buf[:0], '/') - return - } + // Invariants: + // reading from path; r is index of next byte to process. + // writing to buf; w is index of next byte to write. - cpb.n = len(path) - cpb.buf = gotils.ExtendByteSlice(cpb.buf, len(path)+1) - cpb.buf[0] = '/' + // path must start with '/' + r := 1 + w := 1 - cpb.trailing = cpb.n > 2 && path[cpb.n-1] == '/' + if p[0] != '/' { + r = 0 + + if n+1 > stackBufSize { + buf = make([]byte, n+1) + } else { + buf = buf[:n+1] + } + buf[0] = '/' + } + + trailing := n > 1 && p[n-1] == '/' // A bit more clunky without a 'lazybuf' like the path package, but the loop - // gets completely inlined (bufApp). So in contrast to the path package this - // loop has no expensive function calls (except 1x make) + // gets completely inlined (bufApp calls). + // So in contrast to the path package this loop has no expensive function + // calls (except make, if needed). - for cpb.r < cpb.n { - // println(path[:cpb.r], " ####### ", string(path[cpb.r]), " ####### ", string(cpb.buf)) + for r < n { switch { - case path[cpb.r] == '/': + case p[r] == '/': // empty path element, trailing slash is added after the end - cpb.r++ + r++ - case path[cpb.r] == '.' && cpb.r+1 == cpb.n: - cpb.trailing = true - cpb.r++ + case p[r] == '.' && r+1 == n: + trailing = true + r++ - case path[cpb.r] == '.' && path[cpb.r+1] == '/': + case p[r] == '.' && p[r+1] == '/': // . element - cpb.r++ + r += 2 - case path[cpb.r] == '.' && path[cpb.r+1] == '.' && (cpb.r+2 == cpb.n || path[cpb.r+2] == '/'): + case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): // .. element: remove to last / - cpb.r += 2 + r += 3 - if cpb.w > 1 { + if w > 1 { // can backtrack - cpb.w-- - - for cpb.w > 1 && cpb.buf[cpb.w] != '/' { - cpb.w-- + w-- + + if len(buf) == 0 { + for w > 1 && p[w] != '/' { + w-- + } + } else { + for w > 1 && buf[w] != '/' { + w-- + } } - } default: - // real path element. - // add slash if needed - if cpb.w > 1 { - cpb.buf[cpb.w] = '/' - cpb.w++ + // Real path element. + // Add slash if needed + if w > 1 { + bufApp(&buf, p, w, '/') + w++ } - // copy element - for cpb.r < cpb.n && path[cpb.r] != '/' { - cpb.buf[cpb.w] = path[cpb.r] - cpb.w++ - cpb.r++ + // Copy element + for r < n && p[r] != '/' { + bufApp(&buf, p, w, p[r]) + w++ + r++ } } } - // re-append trailing slash - if cpb.trailing && cpb.w > 1 { - cpb.buf[cpb.w] = '/' - cpb.w++ + // Re-append trailing slash + if trailing && w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // If the original string was not modified (or only shortened at the end), + // return the respective substring of the original string. + // Otherwise return a new string from the buffer. + if len(buf) == 0 { + return p[:w] } + return string(buf[:w]) +} + +// Internal helper to lazily create a buffer if necessary. +// Calls to this function get inlined. +func bufApp(buf *[]byte, s string, w int, c byte) { + b := *buf + if len(b) == 0 { + // No modification of the original string so far. + // If the next character is the same as in the original string, we do + // not yet have to allocate a buffer. + if s[w] == c { + return + } - cpb.buf = cpb.buf[:cpb.w] + // Otherwise use either the stack buffer, if it is large enough, or + // allocate a new buffer on the heap, and copy all previous characters. + if l := len(s); l > cap(b) { + *buf = make([]byte, len(s)) + } else { + *buf = (*buf)[:l] + } + b = *buf + + copy(b, s[:w]) + } + b[w] = c } // returns all possible paths when the original path has optional arguments diff --git a/path_test.go b/path_test.go index e8530fc..db605b4 100644 --- a/path_test.go +++ b/path_test.go @@ -6,15 +6,17 @@ package router import ( - "runtime" + "strings" "testing" "github.com/valyala/fasthttp" ) -var cleanTests = []struct { +type cleanPathTest struct { path, result string -}{ +} + +var cleanTests = []cleanPathTest{ // Already clean {"/", "/"}, {"/abc", "/abc"}, @@ -24,6 +26,7 @@ var cleanTests = []struct { // missing root {"", "/"}, + {"a/", "/a/"}, {"abc", "/abc"}, {"abc/def", "/abc/def"}, {"a/b/c", "/a/b/c"}, @@ -68,10 +71,10 @@ var cleanTests = []struct { func TestPathClean(t *testing.T) { for _, test := range cleanTests { if s := CleanPath(test.path); s != test.result { - t.Errorf("CleanPath(%s) = %s, want %s", test.path, s, test.result) + t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result) } if s := CleanPath(test.result); s != test.result { - t.Errorf("CleanPath(%s) = %s, want %s", test.result, s, test.result) + t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result) } } } @@ -80,10 +83,6 @@ func TestPathCleanMallocs(t *testing.T) { if testing.Short() { t.Skip("skipping malloc count in short mode") } - if runtime.GOMAXPROCS(0) > 1 { - t.Log("skipping AllocsPerRun checks; GOMAXPROCS>1") - return - } for _, test := range cleanTests { allocs := testing.AllocsPerRun(100, func() { CleanPath(test.result) }) @@ -93,6 +92,51 @@ func TestPathCleanMallocs(t *testing.T) { } } +func BenchmarkPathClean(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, test := range cleanTests { + CleanPath(test.path) + } + } +} + +func genLongPaths() (testPaths []cleanPathTest) { + for i := 1; i <= 1234; i++ { + ss := strings.Repeat("a", i) + + correctPath := "/" + ss + testPaths = append(testPaths, cleanPathTest{ + path: correctPath, + result: correctPath, + }, cleanPathTest{ + path: ss, + result: correctPath, + }, cleanPathTest{ + path: "//" + ss, + result: correctPath, + }, cleanPathTest{ + path: "/" + ss + "/b/..", + result: correctPath, + }) + } + return +} + +func TestPathCleanLong(t *testing.T) { + cleanTests := genLongPaths() + + for _, test := range cleanTests { + if s := CleanPath(test.path); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result) + } + if s := CleanPath(test.result); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result) + } + } +} + func TestGetOptionalPath(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(fasthttp.StatusOK) @@ -120,20 +164,14 @@ func TestGetOptionalPath(t *testing.T) { } } -func BenchmarkCleanPathWithBuffer(b *testing.B) { - path := "/../bench/" - cpb := acquireCleanPathBuffer() +func BenchmarkPathCleanLong(b *testing.B) { + cleanTests := genLongPaths() + b.ResetTimer() + b.ReportAllocs() for i := 0; i < b.N; i++ { - cleanPathWithBuffer(cpb, path) - cpb.reset() - } -} - -func BenchmarkCleanPath(b *testing.B) { - path := "/../bench/" - - for i := 0; i < b.N; i++ { - CleanPath(path) + for _, test := range cleanTests { + CleanPath(test.path) + } } } diff --git a/router.go b/router.go index bbd6a63..6e7fe6a 100644 --- a/router.go +++ b/router.go @@ -6,38 +6,35 @@ // // A trivial example is: // -// package main +// package main +// +// import ( +// "fmt" +// "log" -// import ( -// "fmt" -// "log" +// "github.com/fasthttp/router" +// ) +// +// func Index(ctx *fasthttp.RequestCtx) { +// fmt.Fprint(w, "Welcome!\n") +// } +// +// func Hello(ctx *fasthttp.RequestCtx) { +// fmt.Fprintf(w, "hello, %s!\n", ctx.UserValue("name")) +// } +// +// func main() { +// r := router.New() +// r.GET("/", Index) +// r.GET("/hello/:name", Hello) // -// "github.com/fasthttp/router" -// "github.com/valyala/fasthttp" -// ) - -// func Index(ctx *fasthttp.RequestCtx) { -// fmt.Fprint(ctx, "Welcome!\n") -// } - -// func Hello(ctx *fasthttp.RequestCtx) { -// fmt.Fprintf(ctx, "hello, %s!\n", ctx.UserValue("name")) -// } - -// func main() { -// r := router.New() -// r.GET("/", Index) -// g := r.Group("/foo", Index) -// g.GET("/bar", Index) -// r.GET("/hello/:name", Hello) - -// log.Fatal(fasthttp.ListenAndServe(":8080", r.Handler)) -// } +// log.Fatal(fasthttp.ListenAndServe(":8080", r.Handler)) +// } // // The router matches incoming requests by the request method and the path. // If a handle is registered for this path and method, the router delegates the // request to that function. -// For the methods GET, POST, PUT, PATCH and DELETE shortcut functions exist to +// For the methods GET, POST, PUT, PATCH, DELETE and OPTIONS shortcut functions exist to // register handles, for all other methods router.Handle can be used. // // The registered path, against which the router matches incoming requests, can @@ -67,15 +64,15 @@ // /files/templates/article.html match: filepath="/templates/article.html" // /files no match, but the router would redirect // -// The value of parameters is inside ctx.UserValue -// To retrieve the value of a parameter: -// // use the name of the parameter -// user := ps.UserValue("user") -// - +// The value of parameters is saved in ctx.UserValue(), consisting +// each of a key and a value. The slice is passed to the Handle func as a third +// parameter. +// To retrieve the value of a parameter,gets by the name of the parameter +// user := ctx.UserValue("user") // defined by :user or *user package router import ( + "fmt" "strings" "github.com/savsgio/gotils" @@ -88,19 +85,30 @@ var ( questionMark = []byte("?") ) -// Router is a http.Handler which can be used to dispatch requests to different +// MatchedRoutePathParam is the param name under which the path of the matched +// route is stored, if Router.SaveMatchedRoutePath is set. +var MatchedRoutePathParam = fmt.Sprintf("__matchedRoutePath::%s__", gotils.RandBytes(make([]byte, 15))) + +// Router is a fasthttp.RequestHandler which can be used to dispatch requests to different // handler functions via configurable routes type Router struct { parent *Router beginPath string - trees map[string]*node registeredPaths map[string][]string + trees map[string]*node + + // If enabled, adds the matched route path onto the ctx.UserValue context + // before invoking the handler. + // The matched route path is only added to handlers of routes that were + // registered when this option was enabled. + SaveMatchedRoutePath bool + // Enables automatic redirection if the current route can't be matched but a // handler for the path with (without) the trailing slash exists. // For example if /foo/ is requested but a route only exists for /foo, the // client is redirected to /foo with http status code 301 for GET requests - // and 307 for all other request methods. + // and 308 for all other request methods. RedirectTrailingSlash bool // If enabled, the router tries to fix the current request path, if no @@ -108,7 +116,7 @@ type Router struct { // First superfluous path elements like ../ or // are removed. // Afterwards the router does a case-insensitive lookup of the cleaned path. // If a handle can be found for this route, the router makes a redirection - // to the corrected path with status code 301 for GET requests and 307 for + // to the corrected path with status code 301 for GET requests and 308 for // all other request methods. // For example /FOO and /..//Foo could be redirected to /foo. // RedirectTrailingSlash is independent of this option. @@ -126,13 +134,22 @@ type Router struct { // Custom OPTIONS handlers take priority over automatic replies. HandleOPTIONS bool - // Configurable http.Handler which is called when no matching route is - // found. If it is not set, http.NotFound is used. + // An optional fasthttp.RequestHandler that is called on automatic OPTIONS requests. + // The handler is only called if HandleOPTIONS is true and no OPTIONS + // handler for the specific path was set. + // The "Allowed" header is set before calling the handler. + GlobalOPTIONS fasthttp.RequestHandler + + // Cached value of global (*) allowed methods + globalAllowed string + + // Configurable fasthttp.RequestHandler which is called when no matching route is + // found. If it is not set, default NotFound is used. NotFound fasthttp.RequestHandler - // Configurable http.Handler which is called when a request + // Configurable fasthttp.RequestHandler which is called when a request // cannot be routed and HandleMethodNotAllowed is true. - // If it is not set, http.Error with http.StatusMethodNotAllowed is used. + // If it is not set, ctx.Error with fasthttp.StatusMethodNotAllowed is used. // The "Allow" header with allowed request methods is set before the handler // is called. MethodNotAllowed fasthttp.RequestHandler @@ -150,7 +167,6 @@ type Router struct { func New() *Router { return &Router{ beginPath: "/", - trees: make(map[string]*node), registeredPaths: make(map[string][]string), RedirectTrailingSlash: true, RedirectFixedPath: true, @@ -169,39 +185,47 @@ func (r *Router) Group(path string) *Router { return g } -// GET is a shortcut for router.Handle("GET", path, handle) +func (r *Router) saveMatchedRoutePath(path string, handle fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(MatchedRoutePathParam, path) + handle(ctx) + + } +} + +// GET is a shortcut for router.Handle(fasthttp.MethodGet, path, handle) func (r *Router) GET(path string, handle fasthttp.RequestHandler) { - r.Handle("GET", path, handle) + r.Handle(fasthttp.MethodGet, path, handle) } -// HEAD is a shortcut for router.Handle("HEAD", path, handle) +// HEAD is a shortcut for router.Handle(fasthttp.MethodHead, path, handle) func (r *Router) HEAD(path string, handle fasthttp.RequestHandler) { - r.Handle("HEAD", path, handle) + r.Handle(fasthttp.MethodHead, path, handle) } -// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle) +// OPTIONS is a shortcut for router.Handle(fasthttp.MethodOptions, path, handle) func (r *Router) OPTIONS(path string, handle fasthttp.RequestHandler) { - r.Handle("OPTIONS", path, handle) + r.Handle(fasthttp.MethodOptions, path, handle) } -// POST is a shortcut for router.Handle("POST", path, handle) +// POST is a shortcut for router.Handle(fasthttp.MethodPost, path, handle) func (r *Router) POST(path string, handle fasthttp.RequestHandler) { - r.Handle("POST", path, handle) + r.Handle(fasthttp.MethodPost, path, handle) } -// PUT is a shortcut for router.Handle("PUT", path, handle) +// PUT is a shortcut for router.Handle(fasthttp.MethodPut, path, handle) func (r *Router) PUT(path string, handle fasthttp.RequestHandler) { - r.Handle("PUT", path, handle) + r.Handle(fasthttp.MethodPut, path, handle) } -// PATCH is a shortcut for router.Handle("PATCH", path, handle) +// PATCH is a shortcut for router.Handle(fasthttp.MethodPatch, path, handle) func (r *Router) PATCH(path string, handle fasthttp.RequestHandler) { - r.Handle("PATCH", path, handle) + r.Handle(fasthttp.MethodPatch, path, handle) } -// DELETE is a shortcut for router.Handle("DELETE", path, handle) +// DELETE is a shortcut for router.Handle(fasthttp.MethodDelete, path, handle) func (r *Router) DELETE(path string, handle fasthttp.RequestHandler) { - r.Handle("DELETE", path, handle) + r.Handle(fasthttp.MethodDelete, path, handle) } // Handle registers a new request handle with the given path and method. @@ -213,9 +237,17 @@ func (r *Router) DELETE(path string, handle fasthttp.RequestHandler) { // frequently used, non-standardized or custom methods (e.g. for internal // communication with a proxy). func (r *Router) Handle(method, path string, handle fasthttp.RequestHandler) { - if path[0] != '/' { + varsCount := uint16(0) + + if method == "" { + panic("method must not be empty") + } + if len(path) < 1 || path[0] != '/' { panic("path must begin with '/' in path '" + path + "'") } + if handle == nil { + panic("handle must not be nil") + } if r.beginPath != "/" { path = r.beginPath + path @@ -229,10 +261,21 @@ func (r *Router) Handle(method, path string, handle fasthttp.RequestHandler) { return } + if r.SaveMatchedRoutePath { + varsCount++ + handle = r.saveMatchedRoutePath(path, handle) + } + + if r.trees == nil { + r.trees = make(map[string]*node) + } + root := r.trees[method] if root == nil { root = new(node) r.trees[method] = root + + r.globalAllowed = r.allowed("*", "") } optionalPaths := getOptionalPaths(path) @@ -252,9 +295,9 @@ func (r *Router) Handle(method, path string, handle fasthttp.RequestHandler) { // path /defined/root/dir/*filepath. // For example if root is "/etc" and *filepath is "passwd", the local file // "/etc/passwd" would be served. -// Internally a http.FileServer is used, therefore http.NotFound is used instead -// of the Router's NotFound handler. -// router.ServeFiles("/src/*filepath", "/var/www") +// Internally a fasthttp.FSHandler is used, therefore http.NotFound is used instead +// Use: +// router.ServeFiles("/src/*filepath", "./") func (r *Router) ServeFiles(path string, rootPath string) { if len(path) < 10 || path[len(path)-10:] != "/*filepath" { panic("path must end with /*filepath in path '" + path + "'") @@ -282,8 +325,9 @@ func (r *Router) ServeFiles(path string, rootPath string) { // path /defined/root/dir/*filepath. // For example if root is "/etc" and *filepath is "passwd", the local file // "/etc/passwd" would be served. -// Internally a http.FileServer is used, therefore http.NotFound is used instead +// Internally a fasthttp.FSHandler is used, therefore http.NotFound is used instead // of the Router's NotFound handler. +// Use: // router.ServeFilesCustom("/src/*filepath", *customFS) func (r *Router) ServeFilesCustom(path string, fs *fasthttp.FS) { if len(path) < 10 || path[len(path)-10:] != "/*filepath" { @@ -312,7 +356,80 @@ func (r *Router) ServeFilesCustom(path string, fs *fasthttp.FS) { }) } -// Handler makes the router implement the fasthttp.ListenAndServe interface. +func (r *Router) recv(ctx *fasthttp.RequestCtx) { + if rcv := recover(); rcv != nil { + r.PanicHandler(ctx, rcv) + } +} + +// Lookup allows the manual lookup of a method + path combo. +// This is e.g. useful to build a framework around this router. +// If the path was found, it returns the handle function and the path parameter +// values. Otherwise the third return value indicates whether a redirection to +// the same path with an extra / without the trailing slash should be performed. +func (r *Router) Lookup(method, path string, ctx *fasthttp.RequestCtx) (fasthttp.RequestHandler, bool) { + if root := r.trees[method]; root != nil { + handle, tsr := root.getValue(path, ctx) + if handle == nil { + return nil, tsr + } + + return handle, tsr + } + return nil, false +} + +func (r *Router) allowed(path, reqMethod string) (allow string) { + allowed := make([]string, 0, 9) + + if path == "*" || path == "/*" { // server-wide{ // server-wide + // empty method is used for internal calls to refresh the cache + if reqMethod == "" { + for method := range r.trees { + if method == fasthttp.MethodOptions { + continue + } + // Add request method to list of allowed methods + allowed = append(allowed, method) + } + } else { + return r.globalAllowed + } + } else { // specific path + for method := range r.trees { + // Skip the requested method - we already tried this one + if method == reqMethod || method == fasthttp.MethodOptions { + continue + } + + handle, _ := r.trees[method].getValue(path, nil) + if handle != nil { + // Add request method to list of allowed methods + allowed = append(allowed, method) + } + } + } + + if len(allowed) > 0 { + // Add request method to list of allowed methods + allowed = append(allowed, fasthttp.MethodOptions) + + // Sort allowed methods. + // sort.Strings(allowed) unfortunately causes unnecessary allocations + // due to allowed being moved to the heap and interface conversion + for i, l := 1, len(allowed); i < l; i++ { + for j := i; j > 0 && allowed[j] < allowed[j-1]; j-- { + allowed[j], allowed[j-1] = allowed[j-1], allowed[j] + } + } + + // return as comma separated list + return strings.Join(allowed, ", ") + } + return +} + +// Handler makes the router implement the http.Handler interface. func (r *Router) Handler(ctx *fasthttp.RequestCtx) { if r.PanicHandler != nil { defer r.recv(ctx) @@ -322,15 +439,15 @@ func (r *Router) Handler(ctx *fasthttp.RequestCtx) { method := gotils.B2S(ctx.Method()) if root := r.trees[method]; root != nil { - if f, tsr := root.getValue(path, ctx); f != nil { - f(ctx) + if handle, tsr := root.getValue(path, ctx); handle != nil { + handle(ctx) return - } else if method != "CONNECT" && path != "/" { - code := 301 // Permanent redirect, request with GET method - if method != "GET" { - // Temporary redirect, request with same method - // As of Go 1.3, Go does not support status code 308. - code = 307 + } else if method != fasthttp.MethodConnect && path != "/" { + // Moved Permanently, request with GET method + code := fasthttp.StatusMovedPermanently + if method != fasthttp.MethodGet { + // Permanent Redirect, request with same method + code = fasthttp.StatusPermanentRedirect } if tsr && r.RedirectTrailingSlash { @@ -343,31 +460,31 @@ func (r *Router) Handler(ctx *fasthttp.RequestCtx) { uri.WriteString("/") } - if len(ctx.URI().QueryString()) > 0 { - uri.WriteString("?") - uri.Write(ctx.QueryArgs().QueryString()) + queryBuf := ctx.URI().QueryString() + if len(queryBuf) > 0 { + uri.Write(questionMark) + uri.Write(queryBuf) } ctx.Redirect(uri.String(), code) bytebufferpool.Put(uri) - return } // Try to fix the request path if r.RedirectFixedPath { - cpb := acquireCleanPathBuffer() - cleanPathWithBuffer(cpb, path) - fixedPath, found := root.findCaseInsensitivePath(gotils.B2S(cpb.buf), r.RedirectTrailingSlash) - releaseCleanPathBuffer(cpb) - + fixedPath, found := root.findCaseInsensitivePath( + CleanPath(path), + r.RedirectTrailingSlash, + ) if found { queryBuf := ctx.URI().QueryString() if len(queryBuf) > 0 { fixedPath = append(fixedPath, questionMark...) fixedPath = append(fixedPath, queryBuf...) } + ctx.RedirectBytes(fixedPath, code) return } @@ -375,28 +492,25 @@ func (r *Router) Handler(ctx *fasthttp.RequestCtx) { } } - if method == "OPTIONS" { + if method == fasthttp.MethodOptions && r.HandleOPTIONS { // Handle OPTIONS requests - if r.HandleOPTIONS { - if allow := r.allowed(path, method); len(allow) > 0 { - ctx.Response.Header.Set("Allow", allow) - return + if allow := r.allowed(path, fasthttp.MethodOptions); allow != "" { + ctx.Response.Header.Set("Allow", allow) + if r.GlobalOPTIONS != nil { + r.GlobalOPTIONS(ctx) } + return } - } else { - // Handle 405 - if r.HandleMethodNotAllowed { - if allow := r.allowed(path, method); len(allow) > 0 { - ctx.Response.Header.Set("Allow", allow) - if r.MethodNotAllowed != nil { - r.MethodNotAllowed(ctx) - } else { - ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed) - ctx.SetContentTypeBytes(defaultContentType) - ctx.SetBodyString(fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed)) - } - return + } else if r.HandleMethodNotAllowed { // Handle 405 + if allow := r.allowed(path, method); allow != "" { + ctx.Response.Header.Set("Allow", allow) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed(ctx) + } else { + ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed) + ctx.SetBodyString(fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed)) } + return } } @@ -408,63 +522,7 @@ func (r *Router) Handler(ctx *fasthttp.RequestCtx) { } } -// Lookup allows the manual lookup of a method + path combo. -// This is e.g. useful to build a framework around this router. -// If the path was found, it returns the handle function and the path parameter -// values. Otherwise the third return value indicates whether a redirection to -// the same path with an extra / without the trailing slash should be performed. -func (r *Router) Lookup(method, path string, ctx *fasthttp.RequestCtx) (fasthttp.RequestHandler, bool) { - if root := r.trees[method]; root != nil { - return root.getValue(path, ctx) - } - return nil, false -} - // List returns all registered routes grouped by method func (r *Router) List() map[string][]string { return r.registeredPaths } - -func (r *Router) allowed(path, reqMethod string) (allow string) { - if path == "*" || path == "/*" { // server-wide - for method := range r.trees { - if method == "OPTIONS" { - continue - } - - // add request method to list of allowed methods - if len(allow) == 0 { - allow = method - } else { - allow += ", " + method - } - } - } else { // specific path - for method := range r.trees { - // Skip the requested method - we already tried this one - if method == reqMethod || method == "OPTIONS" { - continue - } - - handle, _ := r.trees[method].getValue(path, nil) - if handle != nil { - // add request method to list of allowed methods - if len(allow) == 0 { - allow = method - } else { - allow += ", " + method - } - } - } - } - if len(allow) > 0 { - allow += ", OPTIONS" - } - return -} - -func (r *Router) recv(ctx *fasthttp.RequestCtx) { - if rcv := recover(); rcv != nil { - r.PanicHandler(ctx, rcv) - } -} diff --git a/router_test.go b/router_test.go index 5b38e86..3c3fc1c 100644 --- a/router_test.go +++ b/router_test.go @@ -10,7 +10,6 @@ import ( "fmt" "io/ioutil" "net" - "net/http" "os" "reflect" "strings" @@ -20,292 +19,255 @@ import ( "github.com/valyala/fasthttp" ) -func TestRouter(t *testing.T) { - r := New() +type readWriter struct { + net.Conn + r bytes.Buffer + w bytes.Buffer +} - routed := false - r.Handle("GET", "/user/:name", func(ctx *fasthttp.RequestCtx) { - routed = true - want := map[string]string{"name": "gopher"} +var zeroTCPAddr = &net.TCPAddr{ + IP: net.IPv4zero, +} - if ctx.UserValue("name") != want["name"] { - t.Fatalf("wrong wildcard values: want %v, got %v", want["name"], ctx.UserValue("name")) - } - ctx.Success("foo/bar", []byte("success")) - }) +func (rw *readWriter) Close() error { + return nil +} + +func (rw *readWriter) Read(b []byte) (int, error) { + return rw.r.Read(b) +} + +func (rw *readWriter) Write(b []byte) (int, error) { + return rw.w.Write(b) +} + +func (rw *readWriter) RemoteAddr() net.Addr { + return zeroTCPAddr +} + +func (rw *readWriter) LocalAddr() net.Addr { + return zeroTCPAddr +} + +func (rw *readWriter) SetReadDeadline(t time.Time) error { + return nil +} + +func (rw *readWriter) SetWriteDeadline(t time.Time) error { + return nil +} +type assertFn func(rw *readWriter) + +func assertWithTestServer(t *testing.T, uri string, handler fasthttp.RequestHandler, fn assertFn) { s := &fasthttp.Server{ - Handler: r.Handler, + Handler: handler, } rw := &readWriter{} - rw.r.WriteString("GET /user/gopher?baz HTTP/1.1\r\n\r\n") - ch := make(chan error) + + rw.r.WriteString(uri) go func() { ch <- s.ServeConn(rw) }() - select { case err := <-ch: if err != nil { t.Fatalf("return error %s", err) } - case <-time.After(100 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Fatalf("timeout") } - if !routed { - t.Fatal("routing failed") - } + fn(rw) } -type handlerStruct struct { - handeled *bool -} +func TestRouter(t *testing.T) { + router := New() + + routed := false + router.Handle(fasthttp.MethodGet, "/user/:name", func(ctx *fasthttp.RequestCtx) { + routed = true + want := "gopher" + + param, ok := ctx.UserValue("name").(string) + + if !ok { + t.Fatalf("wrong wildcard values: param value is nil") + } + + if param != want { + t.Fatalf("wrong wildcard values: want %s, got %s", want, param) + } + }) + + ctx := new(fasthttp.RequestCtx) + ctx.Request.SetRequestURI("/user/gopher") + + router.Handler(ctx) -func (h handlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { - *h.handeled = true + if !routed { + t.Fatal("routing failed") + } } func TestRouterAPI(t *testing.T) { - var get, head, options, post, put, patch, deleted bool + var handled, get, head, options, post, put, patch, delete bool - r := New() - r.GET("/GET", func(ctx *fasthttp.RequestCtx) { + httpHandler := func(ctx *fasthttp.RequestCtx) { + handled = true + } + + router := New() + router.GET("/GET", func(ctx *fasthttp.RequestCtx) { get = true }) - r.HEAD("/GET", func(ctx *fasthttp.RequestCtx) { + router.HEAD("/GET", func(ctx *fasthttp.RequestCtx) { head = true }) - r.OPTIONS("/GET", func(ctx *fasthttp.RequestCtx) { + router.OPTIONS("/GET", func(ctx *fasthttp.RequestCtx) { options = true }) - r.POST("/POST", func(ctx *fasthttp.RequestCtx) { + router.POST("/POST", func(ctx *fasthttp.RequestCtx) { post = true }) - r.PUT("/PUT", func(ctx *fasthttp.RequestCtx) { + router.PUT("/PUT", func(ctx *fasthttp.RequestCtx) { put = true }) - r.PATCH("/PATCH", func(ctx *fasthttp.RequestCtx) { + router.PATCH("/PATCH", func(ctx *fasthttp.RequestCtx) { patch = true }) - r.DELETE("/DELETE", func(ctx *fasthttp.RequestCtx) { - deleted = true + router.DELETE("/DELETE", func(ctx *fasthttp.RequestCtx) { + delete = true }) + router.Handle(fasthttp.MethodGet, "/Handler", httpHandler) - s := &fasthttp.Server{ - Handler: r.Handler, - } - - rw := &readWriter{} - ch := make(chan error) + ctx := new(fasthttp.RequestCtx) - rw.r.WriteString("GET /GET HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + var request = func(method, path string) { + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI(path) + router.Handler(ctx) } + + request(fasthttp.MethodGet, "/GET") if !get { t.Error("routing GET failed") } - rw.r.WriteString("HEAD /GET HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + request(fasthttp.MethodHead, "/GET") if !head { t.Error("routing HEAD failed") } - rw.r.WriteString("OPTIONS /GET HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + request(fasthttp.MethodOptions, "/GET") if !options { t.Error("routing OPTIONS failed") } - rw.r.WriteString("POST /POST HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + request(fasthttp.MethodPost, "/POST") if !post { t.Error("routing POST failed") } - rw.r.WriteString("PUT /PUT HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + request(fasthttp.MethodPut, "/PUT") if !put { t.Error("routing PUT failed") } - rw.r.WriteString("PATCH /PATCH HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + request(fasthttp.MethodPatch, "/PATCH") if !patch { t.Error("routing PATCH failed") } - rw.r.WriteString("DELETE /DELETE HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if !deleted { + request(fasthttp.MethodDelete, "/DELETE") + if !delete { t.Error("routing DELETE failed") } + + request(fasthttp.MethodGet, "/Handler") + if !handled { + t.Error("routing Handler failed") + } } -func TestRouterRoot(t *testing.T) { - r := New() +func TestRouterInvalidInput(t *testing.T) { + router := New() + + handle := func(_ *fasthttp.RequestCtx) {} + recv := catchPanic(func() { - r.GET("noSlashRoot", nil) + router.Handle("", "/", handle) + }) + if recv == nil { + t.Fatal("registering empty method did not panic") + } + + recv = catchPanic(func() { + router.GET("", handle) + }) + if recv == nil { + t.Fatal("registering empty path did not panic") + } + + recv = catchPanic(func() { + router.GET("noSlashRoot", handle) }) if recv == nil { t.Fatal("registering path not beginning with '/' did not panic") } + + recv = catchPanic(func() { + router.GET("/", nil) + }) + if recv == nil { + t.Fatal("registering nil handler did not panic") + } } func TestRouterChaining(t *testing.T) { - r1 := New() - r2 := New() - r1.NotFound = r2.Handler + router1 := New() + router2 := New() + router1.NotFound = router2.Handler fooHit := false - r1.POST("/foo", func(ctx *fasthttp.RequestCtx) { + router1.POST("/foo", func(ctx *fasthttp.RequestCtx) { fooHit = true ctx.SetStatusCode(fasthttp.StatusOK) }) barHit := false - r2.POST("/bar", func(ctx *fasthttp.RequestCtx) { + router2.POST("/bar", func(ctx *fasthttp.RequestCtx) { barHit = true ctx.SetStatusCode(fasthttp.StatusOK) }) - s := &fasthttp.Server{ - Handler: r1.Handler, - } + ctx := new(fasthttp.RequestCtx) - rw := &readWriter{} - ch := make(chan error) + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.SetRequestURI("/foo") + router1.Handler(ctx) - rw.r.WriteString("POST /foo HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusOK && fooHit) { + if !(ctx.Response.StatusCode() == fasthttp.StatusOK && fooHit) { t.Errorf("Regular routing failed with router chaining.") t.FailNow() } - rw.r.WriteString("POST /bar HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusOK && barHit) { + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.SetRequestURI("/bar") + router1.Handler(ctx) + + if !(ctx.Response.StatusCode() == fasthttp.StatusOK && barHit) { t.Errorf("Chained routing failed with router chaining.") t.FailNow() } - rw.r.WriteString("POST /qax HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusNotFound) { + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.SetRequestURI("/qax") + router1.Handler(ctx) + + if !(ctx.Response.StatusCode() == fasthttp.StatusNotFound) { t.Errorf("NotFound behavior failed with router chaining.") t.FailNow() } @@ -348,14 +310,7 @@ func TestRouterGroup(t *testing.T) { r6.ServeFiles("/static/*filepath", "./") r6.ServeFilesCustom("/custom/static/*filepath", &fasthttp.FS{Root: "./"}) - s := &fasthttp.Server{ - Handler: r1.Handler, - } - - rw := &readWriter{} - ch := make(chan error) - - requests := []string{ + uris := []string{ "POST /foo HTTP/1.1\r\n\r\n", // testing router group - r2 (grouped from r1) "POST /boo/bar HTTP/1.1\r\n\r\n", @@ -373,244 +328,106 @@ func TestRouterGroup(t *testing.T) { "GET /moo/foo/foo/custom/static/router.go HTTP/1.1\r\n\r\n", } - for _, req := range requests { + for _, uri := range uris { hit = false - rw.r.WriteString(req) - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) + assertWithTestServer(t, uri, r1.Handler, func(rw *readWriter) { + br := bufio.NewReader(&rw.w) + var resp fasthttp.Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + if !(resp.Header.StatusCode() == fasthttp.StatusOK) { + t.Fatalf("Status code %d, want %d", resp.Header.StatusCode(), fasthttp.StatusOK) + } + if !strings.Contains(uri, "static") && !hit { + t.Fatalf("Regular routing failed with router chaining. %s", uri) + } + }) + } + + assertWithTestServer(t, "POST /qax HTTP/1.1\r\n\r\n", r1.Handler, func(rw *readWriter) { br := bufio.NewReader(&rw.w) var resp fasthttp.Response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when reading response: %s", err) } - if !(resp.Header.StatusCode() == fasthttp.StatusOK) { - t.Fatalf("Status code %d, want %d", resp.Header.StatusCode(), fasthttp.StatusOK) - } - if !strings.Contains(req, "static") && !hit { - t.Fatalf("Regular routing failed with router chaining. %s", req) - } - } - - // Not found - rw.r.WriteString("POST /qax HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) + if !(resp.Header.StatusCode() == fasthttp.StatusNotFound) { + t.Errorf("NotFound behavior failed with router chaining.") + t.FailNow() } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusNotFound) { - t.Errorf("NotFound behavior failed with router chaining.") - t.FailNow() - } + }) } func TestRouterOPTIONS(t *testing.T) { - // TODO: because fasthttp is not support OPTIONS method now, - // these test cases will be used in the future. handlerFunc := func(_ *fasthttp.RequestCtx) {} - r := New() - r.POST("/path", handlerFunc) + router := New() + router.POST("/path", handlerFunc) - // test not allowed - // * (server) - s := &fasthttp.Server{ - Handler: r.Handler, - } + ctx := new(fasthttp.RequestCtx) - rw := &readWriter{} - ch := make(chan error) + var checkHandling = func(path, expectedAllowed string, expectedStatusCode int) { + ctx.Request.Header.SetMethod(fasthttp.MethodOptions) + ctx.Request.SetRequestURI(path) + router.Handler(ctx) - rw.r.WriteString("OPTIONS * HTTP/1.1\r\nHost:\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) + if !(ctx.Response.StatusCode() == expectedStatusCode) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", ctx.Response.StatusCode(), ctx.Response.Header.String()) + } else if allow := string(ctx.Response.Header.Peek("Allow")); allow != expectedAllowed { + t.Error("unexpected Allow header value: " + allow) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != fasthttp.StatusOK { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) } + // test not allowed + // * (server) + checkHandling("*", "OPTIONS, POST", fasthttp.StatusOK) + // path - rw.r.WriteString("OPTIONS /path HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != fasthttp.StatusOK { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) - } + checkHandling("/path", "OPTIONS, POST", fasthttp.StatusOK) - rw.r.WriteString("OPTIONS /doesnotexist HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusNotFound) { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) + ctx.Request.Header.SetMethod(fasthttp.MethodOptions) + ctx.Request.SetRequestURI("/doesnotexist") + router.Handler(ctx) + if !(ctx.Response.StatusCode() == fasthttp.StatusNotFound) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", ctx.Response.StatusCode(), ctx.Response.Header.String()) } // add another method - r.GET("/path", handlerFunc) + router.GET("/path", handlerFunc) + + // set a global OPTIONS handler + router.GlobalOPTIONS = func(ctx *fasthttp.RequestCtx) { + // Adjust status code to 204 + ctx.SetStatusCode(fasthttp.StatusNoContent) + } // test again // * (server) - rw.r.WriteString("OPTIONS * HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != fasthttp.StatusOK { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) - } + checkHandling("*", "GET, OPTIONS, POST", fasthttp.StatusNoContent) // path - rw.r.WriteString("OPTIONS /path HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != fasthttp.StatusOK { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) - } + checkHandling("/path", "GET, OPTIONS, POST", fasthttp.StatusNoContent) // custom handler var custom bool - r.OPTIONS("/path", func(_ *fasthttp.RequestCtx) { + router.OPTIONS("/path", func(ctx *fasthttp.RequestCtx) { custom = true }) // test again // * (server) - rw.r.WriteString("OPTIONS * HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != fasthttp.StatusOK { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) - } + checkHandling("*", "GET, OPTIONS, POST", fasthttp.StatusNoContent) if custom { t.Error("custom handler called on *") } // path - rw.r.WriteString("OPTIONS /path HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != fasthttp.StatusOK { - t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", - resp.Header.StatusCode(), resp.Header.String()) + ctx.Request.Header.SetMethod(fasthttp.MethodOptions) + ctx.Request.SetRequestURI("/path") + router.Handler(ctx) + if !(ctx.Response.StatusCode() == fasthttp.StatusNoContent) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", ctx.Response.StatusCode(), ctx.Response.Header.String()) } if !custom { t.Error("custom handler not called") @@ -620,259 +437,164 @@ func TestRouterOPTIONS(t *testing.T) { func TestRouterNotAllowed(t *testing.T) { handlerFunc := func(_ *fasthttp.RequestCtx) {} - r := New() - r.POST("/path", handlerFunc) + router := New() + router.POST("/path", handlerFunc) - // Test not allowed - s := &fasthttp.Server{ - Handler: r.Handler, - } + ctx := new(fasthttp.RequestCtx) - rw := &readWriter{} - ch := make(chan error) + var checkHandling = func(path, expectedAllowed string, expectedStatusCode int) { + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI(path) + router.Handler(ctx) - rw.r.WriteString("GET /path HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) + if !(ctx.Response.StatusCode() == expectedStatusCode) { + t.Errorf("NotAllowed handling failed:: Code=%d, Header=%v", ctx.Response.StatusCode(), ctx.Response.Header.String()) + } else if allow := string(ctx.Response.Header.Peek("Allow")); allow != expectedAllowed { + t.Error("unexpected Allow header value: " + allow) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusMethodNotAllowed) { - t.Errorf("NotAllowed handling failed: Code=%d", resp.Header.StatusCode()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) } + // test not allowed + checkHandling("/path", "OPTIONS, POST", fasthttp.StatusMethodNotAllowed) + // add another method - r.DELETE("/path", handlerFunc) - r.OPTIONS("/path", handlerFunc) // must be ignored + router.DELETE("/path", handlerFunc) + router.OPTIONS("/path", handlerFunc) // must be ignored // test again - rw.r.WriteString("GET /path HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == fasthttp.StatusMethodNotAllowed) { - t.Errorf("NotAllowed handling failed: Code=%d", resp.Header.StatusCode()) - } else if allow := string(resp.Header.Peek("Allow")); allow != "POST, DELETE, OPTIONS" && allow != "DELETE, POST, OPTIONS" { - t.Error("unexpected Allow header value: " + allow) - } + checkHandling("/path", "DELETE, OPTIONS, POST", fasthttp.StatusMethodNotAllowed) + // test custom handler responseText := "custom method" - r.MethodNotAllowed = fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { + router.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(fasthttp.StatusTeapot) ctx.Write([]byte(responseText)) - }) - rw.r.WriteString("GET /path HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) } - if !bytes.Equal(resp.Body(), []byte(responseText)) { - t.Errorf("unexpected response got %q want %q", string(resp.Body()), responseText) + + ctx.Response.Reset() + router.Handler(ctx) + + if got := string(ctx.Response.Body()); !(got == responseText) { + t.Errorf("unexpected response got %q want %q", got, responseText) } - if resp.Header.StatusCode() != fasthttp.StatusTeapot { - t.Errorf("unexpected response code %d want %d", resp.Header.StatusCode(), fasthttp.StatusTeapot) + if ctx.Response.StatusCode() != fasthttp.StatusTeapot { + t.Errorf("unexpected response code %d want %d", ctx.Response.StatusCode(), fasthttp.StatusTeapot) } - if allow := string(resp.Header.Peek("Allow")); allow != "POST, DELETE, OPTIONS" && allow != "DELETE, POST, OPTIONS" { + if allow := string(ctx.Response.Header.Peek("Allow")); allow != "DELETE, OPTIONS, POST" { t.Error("unexpected Allow header value: " + allow) } } func TestRouterNotFound(t *testing.T) { handlerFunc := func(_ *fasthttp.RequestCtx) {} + host := "fast" - r := New() - r.GET("/path", handlerFunc) - r.GET("/dir/", handlerFunc) - r.GET("/", handlerFunc) + var buildLocation = func(path string) string { + return fmt.Sprintf("http://%s%s", host, path) + } + + router := New() + router.GET("/path", handlerFunc) + router.GET("/dir/", handlerFunc) + router.GET("/", handlerFunc) testRoutes := []struct { - route string - code int + route string + code int + location string }{ - {"/path/", 301}, // TSR -/ - {"/dir", 301}, // TSR +/ - {"/", 200}, // TSR +/ - {"/PATH", 301}, // Fixed Case - {"/DIR", 301}, // Fixed Case - {"/PATH/", 301}, // Fixed Case -/ - {"/DIR/", 301}, // Fixed Case +/ - {"/paTh/?name=foo", 301}, // Fixed Case With Params +/ - {"/paTh?name=foo", 301}, // Fixed Case With Params +/ - {"/../path", 200}, // CleanPath (Not clean by router, this path is cleaned by fasthttp `ctx.Path()`) - {"/nope", 404}, // NotFound - } - - s := &fasthttp.Server{ - Handler: r.Handler, + {"/path/", fasthttp.StatusMovedPermanently, buildLocation("/path")}, // TSR -/ + {"/dir", fasthttp.StatusMovedPermanently, buildLocation("/dir/")}, // TSR +/ + {"", fasthttp.StatusOK, ""}, // TSR +/ (Not clean by router, this path is cleaned by fasthttp `ctx.Path()`) + {"/PATH", fasthttp.StatusMovedPermanently, buildLocation("/path")}, // Fixed Case + {"/DIR/", fasthttp.StatusMovedPermanently, buildLocation("/dir/")}, // Fixed Case + {"/PATH/", fasthttp.StatusMovedPermanently, buildLocation("/path")}, // Fixed Case -/ + {"/DIR", fasthttp.StatusMovedPermanently, buildLocation("/dir/")}, // Fixed Case +/ + {"/paTh/?name=foo", fasthttp.StatusMovedPermanently, buildLocation("/path?name=foo")}, // Fixed Case With Params +/ + {"/paTh?name=foo", fasthttp.StatusMovedPermanently, buildLocation("/path?name=foo")}, // Fixed Case With Params +/ + {"/../path", fasthttp.StatusOK, ""}, // CleanPath (Not clean by router, this path is cleaned by fasthttp `ctx.Path()`) + {"/nope", fasthttp.StatusNotFound, ""}, // NotFound } - rw := &readWriter{} - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - ch := make(chan error) for _, tr := range testRoutes { - rw.r.WriteString(fmt.Sprintf("GET %s HTTP/1.1\r\n\r\n", tr.route)) - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == tr.code) { - t.Errorf("NotFound handling route %s failed: Code=%d want=%d", - tr.route, resp.Header.StatusCode(), tr.code) + ctx := new(fasthttp.RequestCtx) + + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI(tr.route) + ctx.Request.SetHost(host) + router.Handler(ctx) + + statusCode := ctx.Response.StatusCode() + location := string(ctx.Response.Header.Peek("Location")) + if !(statusCode == tr.code && (statusCode == fasthttp.StatusNotFound || location == tr.location)) { + t.Errorf("NotFound handling route %s failed: Code=%d, Header=%v", tr.route, statusCode, location) } } + ctx := new(fasthttp.RequestCtx) + // Test custom not found handler var notFound bool - r.NotFound = fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { - ctx.SetStatusCode(404) + router.NotFound = func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusNotFound) notFound = true - }) - rw.r.WriteString("GET /nope HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == 404 && notFound == true) { - t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", resp.Header.StatusCode(), string(resp.Header.Peek("Location"))) } - // Test other method than GET (want 307 instead of 301) - r.PATCH("/path", handlerFunc) - rw.r.WriteString("PATCH /path/ HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI("/nope") + router.Handler(ctx) + if !(ctx.Response.StatusCode() == fasthttp.StatusNotFound && notFound == true) { + t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", ctx.Response.StatusCode(), ctx.Response.Header.String()) } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == 307) { - t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", resp.Header.StatusCode(), string(resp.Header.Peek("Location"))) + ctx.Response.Reset() + + // Test other method than GET (want 308 instead of 301) + router.PATCH("/path", handlerFunc) + + ctx.Request.Header.SetMethod(fasthttp.MethodPatch) + ctx.Request.SetRequestURI("/path/?key=val") + ctx.Request.SetHost(host) + router.Handler(ctx) + if !(ctx.Response.StatusCode() == fasthttp.StatusPermanentRedirect && string(ctx.Response.Header.Peek("Location")) == buildLocation("/path?key=val")) { + t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", ctx.Response.StatusCode(), ctx.Response.Header.String()) } + ctx.Response.Reset() // Test special case where no node for the prefix "/" exists - r = New() - r.GET("/a", handlerFunc) - s.Handler = r.Handler - rw.r.WriteString("GET / HTTP/1.1\r\n\r\n") - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if !(resp.Header.StatusCode() == 404) { - t.Errorf("NotFound handling route / failed: Code=%d", resp.Header.StatusCode()) + router = New() + router.GET("/a", handlerFunc) + + ctx.Request.Header.SetMethod(fasthttp.MethodPatch) + ctx.Request.SetRequestURI("/") + router.Handler(ctx) + if !(ctx.Response.StatusCode() == fasthttp.StatusNotFound) { + t.Errorf("NotFound handling route / failed: Code=%d", ctx.Response.StatusCode()) } } func TestRouterPanicHandler(t *testing.T) { - r := New() + router := New() panicHandled := false - r.PanicHandler = func(ctx *fasthttp.RequestCtx, p interface{}) { + router.PanicHandler = func(ctx *fasthttp.RequestCtx, p interface{}) { panicHandled = true } - r.Handle("PUT", "/user/:name", func(_ *fasthttp.RequestCtx) { + router.Handle(fasthttp.MethodPut, "/user/:name", func(ctx *fasthttp.RequestCtx) { panic("oops!") }) + ctx := new(fasthttp.RequestCtx) + ctx.Request.Header.SetMethod(fasthttp.MethodPut) + ctx.Request.SetRequestURI("/user/gopher") + defer func() { if rcv := recover(); rcv != nil { t.Fatal("handling panic failed") } }() - s := &fasthttp.Server{ - Handler: r.Handler, - } - - rw := &readWriter{} - ch := make(chan error) - - rw.r.WriteString(string("PUT /user/gopher HTTP/1.1\r\n\r\n")) - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + router.Handler(ctx) if !panicHandled { t.Fatal("simulating failed") @@ -884,12 +606,13 @@ func TestRouterLookup(t *testing.T) { wantHandle := func(_ *fasthttp.RequestCtx) { routed = true } + wantParams := map[string]string{"name": "gopher"} - r := New() - ctx := &fasthttp.RequestCtx{} + ctx := new(fasthttp.RequestCtx) + router := New() // try empty router first - handle, tsr := r.Lookup("GET", "/nope", ctx) + handle, tsr := router.Lookup(fasthttp.MethodGet, "/nope", ctx) if handle != nil { t.Fatalf("Got handle for unregistered pattern: %v", handle) } @@ -898,9 +621,28 @@ func TestRouterLookup(t *testing.T) { } // insert route and try again - r.GET("/user/:name", wantHandle) + router.GET("/user/:name", wantHandle) + handle, _ = router.Lookup(fasthttp.MethodGet, "/user/gopher", ctx) + if handle == nil { + t.Fatal("Got no handle!") + } else { + handle(nil) + if !routed { + t.Fatal("Routing failed!") + } + } + + for expectedKey, expectedVal := range wantParams { + if ctx.UserValue(expectedKey) != expectedVal { + t.Errorf("The values %s = %s is not save in context", expectedKey, expectedVal) + } + } + + routed = false - handle, _ = r.Lookup("GET", "/user/gopher", ctx) + // route without param + router.GET("/user", wantHandle) + handle, _ = router.Lookup(fasthttp.MethodGet, "/user", ctx) if handle == nil { t.Fatal("Got no handle!") } else { @@ -909,11 +651,14 @@ func TestRouterLookup(t *testing.T) { t.Fatal("Routing failed!") } } - if ctx.UserValue("name") != "gopher" { - t.Error("Param not set!") + + for expectedKey, expectedVal := range wantParams { + if ctx.UserValue(expectedKey) != expectedVal { + t.Errorf("The values %s = %s is not save in context", expectedKey, expectedVal) + } } - handle, tsr = r.Lookup("GET", "/user/gopher/", ctx) + handle, tsr = router.Lookup(fasthttp.MethodGet, "/user/gopher/", ctx) if handle != nil { t.Fatalf("Got handle for unregistered pattern: %v", handle) } @@ -921,7 +666,7 @@ func TestRouterLookup(t *testing.T) { t.Error("Got no TSR recommendation!") } - handle, tsr = r.Lookup("GET", "/nope", ctx) + handle, tsr = router.Lookup(fasthttp.MethodGet, "/nope", ctx) if handle != nil { t.Fatalf("Got handle for unregistered pattern: %v", handle) } @@ -930,6 +675,67 @@ func TestRouterLookup(t *testing.T) { } } +func TestRouterMatchedRoutePath(t *testing.T) { + route1 := "/user/:name" + routed1 := false + handle1 := func(ctx *fasthttp.RequestCtx) { + route := ctx.UserValue(MatchedRoutePathParam) + if route != route1 { + t.Fatalf("Wrong matched route: want %s, got %s", route1, route) + } + routed1 = true + } + + route2 := "/user/:name/details" + routed2 := false + handle2 := func(ctx *fasthttp.RequestCtx) { + route := ctx.UserValue(MatchedRoutePathParam) + if route != route2 { + t.Fatalf("Wrong matched route: want %s, got %s", route2, route) + } + routed2 = true + } + + route3 := "/" + routed3 := false + handle3 := func(ctx *fasthttp.RequestCtx) { + route := ctx.UserValue(MatchedRoutePathParam) + if route != route3 { + t.Fatalf("Wrong matched route: want %s, got %s", route3, route) + } + routed3 = true + } + + router := New() + router.SaveMatchedRoutePath = true + router.Handle(fasthttp.MethodGet, route1, handle1) + router.Handle(fasthttp.MethodGet, route2, handle2) + router.Handle(fasthttp.MethodGet, route3, handle3) + + ctx := new(fasthttp.RequestCtx) + + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI("/user/gopher") + router.Handler(ctx) + if !routed1 || routed2 || routed3 { + t.Fatal("Routing failed!") + } + + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI("/user/gopher/details") + router.Handler(ctx) + if !routed2 || routed3 { + t.Fatal("Routing failed!") + } + + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI("/") + router.Handler(ctx) + if !routed3 { + t.Fatal("Routing failed!") + } +} + func TestRouterServeFiles(t *testing.T) { r := New() @@ -944,37 +750,19 @@ func TestRouterServeFiles(t *testing.T) { r.ServeFiles("/*filepath", os.TempDir()) - s := &fasthttp.Server{ - Handler: r.Handler, - } - - rw := &readWriter{} - ch := make(chan error) - - rw.r.WriteString(string("GET /favicon.ico HTTP/1.1\r\n\r\n")) - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) + assertWithTestServer(t, "GET /favicon.ico HTTP/1.1\r\n\r\n", r.Handler, func(rw *readWriter) { + br := bufio.NewReader(&rw.w) + var resp fasthttp.Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %s", err) } - case <-time.After(500 * time.Millisecond): - t.Fatalf("timeout") - } - - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != 200 { - t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423) - } - if !bytes.Equal(resp.Body(), body) { - t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), string(body)) - } + if resp.Header.StatusCode() != 200 { + t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423) + } + if !bytes.Equal(resp.Body(), body) { + t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), string(body)) + } + }) } func TestRouterServeFilesCustom(t *testing.T) { @@ -997,37 +785,19 @@ func TestRouterServeFilesCustom(t *testing.T) { r.ServeFilesCustom("/*filepath", fs) - s := &fasthttp.Server{ - Handler: r.Handler, - } - - rw := &readWriter{} - ch := make(chan error) - - rw.r.WriteString(string("GET /favicon.ico HTTP/1.1\r\n\r\n")) - go func() { - ch <- s.ServeConn(rw) - }() - select { - case err := <-ch: - if err != nil { - t.Fatalf("return error %s", err) + assertWithTestServer(t, "GET /favicon.ico HTTP/1.1\r\n\r\n", r.Handler, func(rw *readWriter) { + br := bufio.NewReader(&rw.w) + var resp fasthttp.Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %s", err) } - case <-time.After(500 * time.Millisecond): - t.Fatalf("timeout") - } - - br := bufio.NewReader(&rw.w) - var resp fasthttp.Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != 200 { - t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 200) - } - if !bytes.Equal(resp.Body(), body) { - t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), string(body)) - } + if resp.Header.StatusCode() != 200 { + t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 200) + } + if !bytes.Equal(resp.Body(), body) { + t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), string(body)) + } + }) } func TestRouterList(t *testing.T) { @@ -1054,42 +824,25 @@ func TestRouterList(t *testing.T) { } -type readWriter struct { - net.Conn - r bytes.Buffer - w bytes.Buffer -} - -var zeroTCPAddr = &net.TCPAddr{ - IP: net.IPv4zero, -} - -func (rw *readWriter) Close() error { - return nil -} - -func (rw *readWriter) Read(b []byte) (int, error) { - return rw.r.Read(b) -} - -func (rw *readWriter) Write(b []byte) (int, error) { - return rw.w.Write(b) -} - -func (rw *readWriter) RemoteAddr() net.Addr { - return zeroTCPAddr -} - -func (rw *readWriter) LocalAddr() net.Addr { - return zeroTCPAddr -} +func BenchmarkAllowed(b *testing.B) { + handlerFunc := func(_ *fasthttp.RequestCtx) {} -func (rw *readWriter) SetReadDeadline(t time.Time) error { - return nil -} + router := New() + router.POST("/path", handlerFunc) + router.GET("/path", handlerFunc) -func (rw *readWriter) SetWriteDeadline(t time.Time) error { - return nil + b.Run("Global", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = router.allowed("*", fasthttp.MethodOptions) + } + }) + b.Run("Path", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = router.allowed("/path", fasthttp.MethodOptions) + } + }) } func BenchmarkRouterGet(b *testing.B) { diff --git a/tolower.go b/tolower.go deleted file mode 100644 index 1fe05df..0000000 --- a/tolower.go +++ /dev/null @@ -1,9 +0,0 @@ -//+build !go1.12 - -package router - -import "strings" - -func toLower(s string) string { - return strings.ToLower(s) -} diff --git a/tolower_go112.go b/tolower_go112.go deleted file mode 100644 index cf1aaa1..0000000 --- a/tolower_go112.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the "third_party_licenses/Go (golang)" file. - -//+build go1.12 - -package router - -import ( - "strings" - "unicode" - "unicode/utf8" -) - -// toLower returns a copy of the string s with all Unicode letters mapped to their lower case. -func toLower(s string) string { - isASCII, hasUpper := true, false - for i := 0; i < len(s); i++ { - c := s[i] - if c >= utf8.RuneSelf { - isASCII = false - break - } - hasUpper = hasUpper || (c >= 'A' && c <= 'Z') - } - - if isASCII { // optimize for ASCII-only strings. - if !hasUpper { - return s - } - var b strings.Builder - b.Grow(len(s)) - for i := 0; i < len(s); i++ { - c := s[i] - if c >= 'A' && c <= 'Z' { - c += 'a' - 'A' - } - b.WriteByte(c) - } - return b.String() - } - return stringsMap(unicode.ToLower, s) -} - -// Map returns a copy of the string s with all its characters modified -// according to the mapping function. If mapping returns a negative value, the character is -// dropped from the string with no replacement. -func stringsMap(mapping func(rune) rune, s string) string { - // In the worst case, the string can grow when mapped, making - // things unpleasant. But it's so rare we barge in assuming it's - // fine. It could also shrink but that falls out naturally. - - // The output buffer b is initialized on demand, the first - // time a character differs. - var b strings.Builder - - for i, c := range s { - r := mapping(c) - if r == c { - continue - } - - b.Grow(len(s) + utf8.UTFMax) - b.WriteString(s[:i]) - if r >= 0 { - b.WriteRune(r) - } - - if c == utf8.RuneError { - // RuneError is the result of either decoding - // an invalid sequence or '\uFFFD'. Determine - // the correct number of bytes we need to advance. - _, w := utf8.DecodeRuneInString(s[i:]) - i += w - } else { - i += utf8.RuneLen(c) - } - - s = s[i:] - break - } - - // Fast path for unchanged input - if b.Cap() == 0 { // didn't call b.Grow above - return s - } - - for _, c := range s { - r := mapping(c) - - if r >= 0 { - // common case - // Due to inlining, it is more performant to determine if WriteByte should be - // invoked rather than always call WriteRune - if r < utf8.RuneSelf { - b.WriteByte(byte(r)) - } else { - // r is not a ASCII rune. - b.WriteRune(r) - } - } - } - - return b.String() -} diff --git a/tree.go b/tree.go index 9f3e4f4..651555f 100644 --- a/tree.go +++ b/tree.go @@ -6,112 +6,102 @@ package router import ( "strings" - "sync" "unicode" "unicode/utf8" - "github.com/savsgio/gotils" "github.com/valyala/fasthttp" ) -const ( - static nodeType = iota // default - root - param - catchAll -) - -type nodeType uint8 - -type buffer struct { - b []byte +func min(a, b int) int { + if a <= b { + return a + } + return b } -type node struct { - path string - wildChild bool - nType nodeType - maxParams uint8 - indices string - children []*node - handle fasthttp.RequestHandler - priority uint32 +func longestCommonPrefix(a, b string) int { + i := 0 + max := min(len(a), len(b)) + for i < max && a[i] == b[i] { + i++ + } + return i } -var bufferPool = sync.Pool{ - New: func() interface{} { - return &buffer{ - b: make([]byte, 0, 0), +// Search for a wildcard segment and check the name for invalid characters. +// Returns -1 as index, if no wildcard was found. +func findWildcard(path string) (wilcard string, i int, valid bool) { + // Find start + for start, c := range []byte(path) { + // A wildcard starts with ':' (param) or '*' (catch-all) + if c != ':' && c != '*' { + continue } - }, -} - -func acquireBuffer() *buffer { - return bufferPool.Get().(*buffer) -} - -func releaseBuffer(b *buffer) { - bufferPool.Put(b) -} -func min(a, b int) int { - if a <= b { - return a + // Find end and check for invalid characters + valid = true + for end, c := range []byte(path[start+1:]) { + switch c { + case '/': + return path[start : start+1+end], start, valid + case ':', '*': + valid = false + } + } + return path[start:], start, valid } - return b + return "", -1, false } -func countParams(path string) uint8 { +func countParams(path string) uint16 { var n uint - for i := 0; i < len(path); i++ { - if path[i] != ':' && path[i] != '*' { - continue + for i := range []byte(path) { + switch path[i] { + case ':', '*': + n++ } - n++ - } - if n >= 255 { - return 255 } - return uint8(n) + return uint16(n) } -// shift bytes in array by n bytes left -func shiftNRuneBytes(rb [4]byte, n int) [4]byte { - switch n { - case 0: - return rb - case 1: - return [4]byte{rb[1], rb[2], rb[3], 0} - case 2: - return [4]byte{rb[2], rb[3]} - case 3: - return [4]byte{rb[3]} - default: - return [4]byte{} - } +type nodeType uint8 + +const ( + static nodeType = iota // default + root + param + catchAll +) + +type node struct { + path string + indices string + wildChild bool + nType nodeType + priority uint32 + children []*node + handle fasthttp.RequestHandler } -// increments priority of the given child and reorders if necessary +// Increments priority of the given child and reorders if necessary func (n *node) incrementChildPrio(pos int) int { - n.children[pos].priority++ - prio := n.children[pos].priority + cs := n.children + cs[pos].priority++ + prio := cs[pos].priority - // adjust position (move to front) + // Adjust position (move to front) newPos := pos - for newPos > 0 && n.children[newPos-1].priority < prio { - // swap node positions - tmpN := n.children[newPos-1] - n.children[newPos-1] = n.children[newPos] - n.children[newPos] = tmpN + for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- { + // Swap node positions + cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1] - newPos-- } - // build new index char string + // Build new index char string if newPos != pos { - n.indices = n.indices[:newPos] + // unchanged prefix, might be empty - n.indices[pos:pos+1] + // the index char we move - n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos' + n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty + n.indices[pos:pos+1] + // The index char we move + n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos' } return newPos @@ -122,198 +112,171 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handle fasthttp.RequestHandler) { fullPath := path n.priority++ - numParams := countParams(path) - - // non-empty tree - if len(n.path) > 0 || len(n.children) > 0 { - walk: - for { - // Update maxParams of the current node - if numParams > n.maxParams { - n.maxParams = numParams - } - // Find the longest common prefix. - // This also implies that the common prefix contains no ':' or '*' - // since the existing key can't contain those chars. - i := 0 - max := min(len(path), len(n.path)) - for i < max && path[i] == n.path[i] { - i++ - } - - // Split edge - if i < len(n.path) { - child := node{ - path: n.path[i:], - wildChild: n.wildChild, - nType: static, - indices: n.indices, - children: n.children, - handle: n.handle, - priority: n.priority - 1, - } - - // Update maxParams (max of all children) - for i := range child.children { - if child.children[i].maxParams > child.maxParams { - child.maxParams = child.children[i].maxParams - } - } + // Empty tree + if len(n.path) == 0 && len(n.indices) == 0 { + n.insertChild(path, fullPath, handle) + n.nType = root + return + } - n.children = []*node{&child} - // []byte for proper unicode char conversion, see #65 - n.indices = string([]byte{n.path[i]}) - n.path = path[:i] - n.handle = nil - n.wildChild = false +walk: + for { + // Find the longest common prefix. + // This also implies that the common prefix contains no ':' or '*' + // since the existing key can't contain those chars. + i := longestCommonPrefix(path, n.path) + + // Split edge + if i < len(n.path) { + child := node{ + path: n.path[i:], + wildChild: n.wildChild, + nType: static, + indices: n.indices, + children: n.children, + handle: n.handle, + priority: n.priority - 1, } - // Make new node a child of this node - if i < len(path) { - path = path[i:] + n.children = []*node{&child} + // []byte for proper unicode char conversion, see #65 + n.indices = string([]byte{n.path[i]}) + n.path = path[:i] + n.handle = nil + n.wildChild = false + } - if n.wildChild { - n = n.children[0] - n.priority++ + // Make new node a child of this node + if i < len(path) { + path = path[i:] - // Update maxParams of the child node - if numParams > n.maxParams { - n.maxParams = numParams - } - numParams-- - - // Check if the wildcard matches - if len(path) >= len(n.path) && n.path == path[:len(n.path)] && - // Check for longer wildcard, e.g. :name and :names - (len(n.path) >= len(path) || path[len(n.path)] == '/') { - continue walk - } else { - // Wildcard conflict - pathSeg := strings.SplitN(path, "/", 2)[0] - prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path - panic("'" + pathSeg + - "' in new path '" + fullPath + - "' conflicts with existing wildcard '" + n.path + - "' in existing prefix '" + prefix + - "'") + if n.wildChild { + n = n.children[0] + n.priority++ + + // Check if the wildcard matches + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // Adding a child to a catchAll is not possible + n.nType != catchAll && + // Check for longer wildcard, e.g. :name and :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } else { + // Wildcard conflict + pathSeg := path + if n.nType != catchAll { + pathSeg = strings.SplitN(pathSeg, "/", 2)[0] } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") } + } - c := path[0] + idxc := path[0] - // slash after param - if n.nType == param && c == '/' && len(n.children) == 1 { - n = n.children[0] - n.priority++ - continue walk - } - - // Check if a child with the next path byte exists - for i := 0; i < len(n.indices); i++ { - if c == n.indices[i] { - i = n.incrementChildPrio(i) - n = n.children[i] - continue walk - } - } + // '/' after param + if n.nType == param && idxc == '/' && len(n.children) == 1 { + n = n.children[0] + n.priority++ + continue walk + } - // Otherwise insert it - if c != ':' && c != '*' { - // []byte for proper unicode char conversion, see #65 - n.indices += string([]byte{c}) - child := &node{ - maxParams: numParams, - } - n.children = append(n.children, child) - n.incrementChildPrio(len(n.indices) - 1) - n = child + // Check if a child with the next path byte exists + for i, c := range []byte(n.indices) { + if c == idxc { + i = n.incrementChildPrio(i) + n = n.children[i] + continue walk } - n.insertChild(numParams, path, fullPath, handle) - return + } - } else if i == len(path) { // Make node a (in-path) leaf - if n.handle != nil { - panic("a handle is already registered for path '" + fullPath + "'") - } - n.handle = handle + // Otherwise insert it + if idxc != ':' && idxc != '*' { + // []byte for proper unicode char conversion, see #65 + n.indices += string([]byte{idxc}) + child := &node{} + n.children = append(n.children, child) + n.incrementChildPrio(len(n.indices) - 1) + n = child } + n.insertChild(path, fullPath, handle) return } - } else { // Empty tree - n.insertChild(numParams, path, fullPath, handle) - n.nType = root + + // Otherwise add handle to current node + if n.handle != nil { + panic("a handle is already registered for path '" + fullPath + "'") + } + n.handle = handle + return } } -func (n *node) insertChild(numParams uint8, path, fullPath string, handle fasthttp.RequestHandler) { - var offset int // already handled bytes of the path +func (n *node) insertChild(path, fullPath string, handle fasthttp.RequestHandler) { + for { + // Find prefix until first wildcard + wildcard, i, valid := findWildcard(path) + if i < 0 { // No wilcard found + break + } - // find prefix until first wildcard (beginning with ':'' or '*'') - for i, max := 0, len(path); numParams > 0; i++ { - c := path[i] - if c != ':' && c != '*' { - continue + // The wildcard name must not contain ':' and '*' + if !valid { + panic("only one wildcard per path segment is allowed, has: '" + + wildcard + "' in path '" + fullPath + "'") } - // find wildcard end (either '/' or path end) - end := i + 1 - for end < max && path[end] != '/' { - switch path[end] { - // the wildcard name must not contain ':' and '*' - case ':', '*': - panic("only one wildcard per path segment is allowed, has: '" + - path[i:] + "' in path '" + fullPath + "'") - default: - end++ - } + // Check if the wildcard has a name + if len(wildcard) < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") } - // check if this Node existing children which would be + // Check if this node has existing children which would be // unreachable if we insert the wildcard here if len(n.children) > 0 { - panic("wildcard route '" + path[i:end] + + panic("wildcard segment '" + wildcard + "' conflicts with existing children in path '" + fullPath + "'") } - // check if the wildcard has a name - if end-i < 2 { - panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") - } - - if c == ':' { // param - // split path at the beginning of the wildcard + if wildcard[0] == ':' { // param if i > 0 { - n.path = path[offset:i] - offset = i + // Insert prefix before the current wildcard + n.path = path[:i] + path = path[i:] } + n.wildChild = true child := &node{ - nType: param, - maxParams: numParams, + nType: param, + path: wildcard, } n.children = []*node{child} - n.wildChild = true n = child n.priority++ - numParams-- - // if the path doesn't end with the wildcard, then there + // If the path doesn't end with the wildcard, then there // will be another non-wildcard subpath starting with '/' - if end < max { - n.path = path[offset:end] - offset = end - + if len(wildcard) < len(path) { + path = path[len(wildcard):] child := &node{ - maxParams: numParams, - priority: 1, + priority: 1, } n.children = []*node{child} n = child + continue } + // Otherwise we're done. Insert the handle in the new leaf + n.handle = handle + return + } else { // catchAll - if end != max || numParams > 1 { + if i+len(wildcard) != len(path) { panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") } @@ -321,32 +284,30 @@ func (n *node) insertChild(numParams uint8, path, fullPath string, handle fastht panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'") } - // currently fixed width 1 for '/' + // Currently fixed width 1 for '/' i-- if path[i] != '/' { panic("no / before catch-all in path '" + fullPath + "'") } - n.path = path[offset:i] + n.path = path[:i] - // first node: catchAll node with empty path + // First node: catchAll node with empty path child := &node{ wildChild: true, nType: catchAll, - maxParams: 1, } n.children = []*node{child} - n.indices = string(path[i]) + n.indices = string('/') n = child n.priority++ - // second node: node holding the variable + // Second node: node holding the variable child = &node{ - path: path[i:], - nType: catchAll, - maxParams: 1, - handle: handle, - priority: 1, + path: path[i:], + nType: catchAll, + handle: handle, + priority: 1, } n.children = []*node{child} @@ -354,8 +315,8 @@ func (n *node) insertChild(numParams uint8, path, fullPath string, handle fastht } } - // insert remaining path part and handle to the leaf - n.path = path[offset:] + // If no wildcard was found, simply insert the path and handle + n.path = path n.handle = handle } @@ -365,18 +326,20 @@ func (n *node) insertChild(numParams uint8, path, fullPath string, handle fastht // made if a handle exists with an extra (without the) trailing slash for the // given path. func (n *node) getValue(path string, ctx *fasthttp.RequestCtx) (handle fasthttp.RequestHandler, tsr bool) { -walk: // outer loop for walking the tree +walk: // Outer loop for walking the tree for { - if len(path) > len(n.path) { - if path[:len(n.path)] == n.path { - path = path[len(n.path):] + prefix := n.path + if len(path) > len(prefix) { + if path[:len(prefix)] == prefix { + path = path[len(prefix):] + // If this node does not have a wildcard (param or catchAll) - // child, we can just look up the next child node and continue + // child, we can just look up the next child node and continue // to walk down the tree if !n.wildChild { - c := path[0] - for i := 0; i < len(n.indices); i++ { - if c == n.indices[i] { + idxc := path[0] + for i, c := range []byte(n.indices) { + if c == idxc { n = n.children[i] continue walk } @@ -390,22 +353,22 @@ walk: // outer loop for walking the tree } - // handle wildcard child + // Handle wildcard child n = n.children[0] switch n.nType { case param: - // find param end (either '/' or path end) + // Find param end (either '/' or path end) end := 0 for end < len(path) && path[end] != '/' { end++ } - // handle calls to Router.allowed method with nil context + // Save param value if ctx != nil { ctx.SetUserValue(n.path[1:], path[:end]) } - // we need to go deeper! + // We need to go deeper! if end < len(path) { if len(n.children) > 0 { path = path[end:] @@ -430,10 +393,11 @@ walk: // outer loop for walking the tree return case catchAll: + // Save param value if ctx != nil { - // save param value ctx.SetUserValue(n.path[2:], path) } + handle = n.handle return @@ -441,13 +405,16 @@ walk: // outer loop for walking the tree panic("invalid node type") } } - } else if path == n.path { + } else if path == prefix { // We should have reached the node containing the handle. // Check if this node has a handle registered. if handle = n.handle; handle != nil { return } + // If there is no handle for this route, but this route has a + // wildcard child, there must be a handle for this path with an + // additional trailing slash if path == "/" && n.wildChild && n.nType != root { tsr = true return @@ -455,23 +422,22 @@ walk: // outer loop for walking the tree // No handle found. Check if a handle for this path + a // trailing slash exists for trailing slash recommendation - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == '/' { + for i, c := range []byte(n.indices) { + if c == '/' { n = n.children[i] tsr = (len(n.path) == 1 && n.handle != nil) || (n.nType == catchAll && n.children[0].handle != nil) return } } - return } // Nothing found. We can recommend to redirect to the same URL with an // extra trailing slash if a leaf exists for that path tsr = (path == "/") || - (len(n.path) == len(path)+1 && n.path[len(path)] == '/' && - path == n.path[:len(n.path)-1] && n.handle != nil) + (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && + path == prefix[:len(prefix)-1] && n.handle != nil) return } } @@ -480,101 +446,124 @@ walk: // outer loop for walking the tree // It can optionally also fix trailing slashes. // It returns the case-corrected path and a bool indicating whether the lookup // was successful. -func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { - buff := acquireBuffer() - - buff.b = gotils.ExtendByteSlice(buff.b, len(path)+1) // preallocate enough memory for new path +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (fixedPath []byte, found bool) { + const stackBufSize = 128 + + // Use a static sized buffer on the stack in the common case. + // If the path is too long, allocate a buffer on the heap instead. + buf := make([]byte, 0, stackBufSize) + if l := len(path) + 1; l > stackBufSize { + buf = make([]byte, 0, l) + } - fixedPath, found := n.findCaseInsensitivePathRec( + ciPath := n.findCaseInsensitivePathRec( path, - toLower(path), - buff.b[:0], - [4]byte{}, // empty rune buffer + buf, // Preallocate enough memory for new path + [4]byte{}, // Empty rune buffer fixTrailingSlash, ) - releaseBuffer(buff) + return ciPath, ciPath != nil +} - return fixedPath, found +// Shift bytes in array by n bytes left +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} + case 2: + return [4]byte{rb[2], rb[3]} + case 3: + return [4]byte{rb[3]} + default: + return [4]byte{} + } } -// recursive case-insensitive lookup function used by n.findCaseInsensitivePath -func (n *node) findCaseInsensitivePathRec(path, loPath string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) ([]byte, bool) { - loNPath := toLower(n.path) +// Recursive case-insensitive lookup function used by n.findCaseInsensitivePath +func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte { + npLen := len(n.path) -walk: // outer loop for walking the tree - for len(loPath) >= len(loNPath) && (len(loNPath) == 0 || loPath[1:len(loNPath)] == loNPath[1:]) { - // add common path to result +walk: // Outer loop for walking the tree + for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) { + // Add common prefix to result + oldPath := path + path = path[npLen:] ciPath = append(ciPath, n.path...) - if path = path[len(n.path):]; len(path) > 0 { - loOld := loPath - loPath = loPath[len(loNPath):] - + if len(path) > 0 { // If this node does not have a wildcard (param or catchAll) child, // we can just look up the next child node and continue to walk down // the tree if !n.wildChild { - // skip rune bytes already processed - rb = shiftNRuneBytes(rb, len(loNPath)) + // Skip rune bytes already processed + rb = shiftNRuneBytes(rb, npLen) if rb[0] != 0 { - // old rune not finished - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == rb[0] { + // Old rune not finished + idxc := rb[0] + for i, c := range []byte(n.indices) { + if c == idxc { // continue with child node n = n.children[i] - loNPath = toLower(n.path) + npLen = len(n.path) continue walk } } } else { - // process a new rune + // Process a new rune var rv rune - // find rune start - // runes are up to 4 byte long, - // -4 would definitely be another rune + // Find rune start. + // Runes are up to 4 byte long, + // -4 would definitely be another rune. var off int - for max := min(len(loNPath), 3); off < max; off++ { - if i := len(loNPath) - off; utf8.RuneStart(loOld[i]) { - // read rune from cached lowercase path - rv, _ = utf8.DecodeRuneInString(loOld[i:]) + for max := min(npLen, 3); off < max; off++ { + if i := npLen - off; utf8.RuneStart(oldPath[i]) { + // read rune from cached path + rv, _ = utf8.DecodeRuneInString(oldPath[i:]) break } } - // calculate lowercase bytes of current rune - utf8.EncodeRune(rb[:], rv) - // skipp already processed bytes + // Calculate lowercase bytes of current rune + lo := unicode.ToLower(rv) + utf8.EncodeRune(rb[:], lo) + + // Skip already processed bytes rb = shiftNRuneBytes(rb, off) - for i := 0; i < len(n.indices); i++ { - // lowercase matches - if n.indices[i] == rb[0] { + idxc := rb[0] + for i, c := range []byte(n.indices) { + // Lowercase matches + if c == idxc { // must use a recursive approach since both the // uppercase byte and the lowercase byte might exist // as an index - if out, found := n.children[i].findCaseInsensitivePathRec( - path, loPath, ciPath, rb, fixTrailingSlash, - ); found { - return out, true + if out := n.children[i].findCaseInsensitivePathRec( + path, ciPath, rb, fixTrailingSlash, + ); out != nil { + return out } break } } - // same for uppercase rune, if it differs - if up := unicode.ToUpper(rv); up != rv { + // If we found no match, the same for the uppercase rune, + // if it differs + if up := unicode.ToUpper(rv); up != lo { utf8.EncodeRune(rb[:], up) rb = shiftNRuneBytes(rb, off) - for i := 0; i < len(n.indices); i++ { - // uppercase matches - if n.indices[i] == rb[0] { - // continue with child node + idxc := rb[0] + for i, c := range []byte(n.indices) { + // Uppercase matches + if c == idxc { + // Continue with child node n = n.children[i] - loNPath = toLower(n.path) + npLen = len(n.path) continue walk } } @@ -583,53 +572,55 @@ walk: // outer loop for walking the tree // Nothing found. We can recommend to redirect to the same URL // without a trailing slash if a leaf exists for that path - return ciPath, (fixTrailingSlash && path == "/" && n.handle != nil) + if fixTrailingSlash && path == "/" && n.handle != nil { + return ciPath + } + return nil } n = n.children[0] switch n.nType { case param: - // find param end (either '/' or path end) - k := 0 - for k < len(path) && path[k] != '/' { - k++ + // Find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ } - // add param value to case insensitive path - ciPath = append(ciPath, path[:k]...) + // Add param value to case insensitive path + ciPath = append(ciPath, path[:end]...) - // we need to go deeper! - if k < len(path) { + // We need to go deeper! + if end < len(path) { if len(n.children) > 0 { - // continue with child node + // Continue with child node n = n.children[0] - loNPath = toLower(n.path) - loPath = loPath[k:] - path = path[k:] + npLen = len(n.path) + path = path[end:] continue } // ... but we can't - if fixTrailingSlash && len(path) == k+1 { - return ciPath, true + if fixTrailingSlash && len(path) == end+1 { + return ciPath } - return ciPath, false + return nil } if n.handle != nil { - return ciPath, true + return ciPath } else if fixTrailingSlash && len(n.children) == 1 { // No handle found. Check if a handle for this path + a // trailing slash exists n = n.children[0] if n.path == "/" && n.handle != nil { - return append(ciPath, '/'), true + return append(ciPath, '/') } } - return ciPath, false + return nil case catchAll: - return append(ciPath, path...), true + return append(ciPath, path...) default: panic("invalid node type") @@ -638,24 +629,24 @@ walk: // outer loop for walking the tree // We should have reached the node containing the handle. // Check if this node has a handle registered. if n.handle != nil { - return ciPath, true + return ciPath } // No handle found. // Try to fix the path by adding a trailing slash if fixTrailingSlash { - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == '/' { + for i, c := range []byte(n.indices) { + if c == '/' { n = n.children[i] if (len(n.path) == 1 && n.handle != nil) || (n.nType == catchAll && n.children[0].handle != nil) { - return append(ciPath, '/'), true + return append(ciPath, '/') } - return ciPath, false + return nil } } } - return ciPath, false + return nil } } @@ -663,12 +654,12 @@ walk: // outer loop for walking the tree // Try to fix the path by adding / removing a trailing slash if fixTrailingSlash { if path == "/" { - return ciPath, true + return ciPath } - if len(loPath)+1 == len(loNPath) && loNPath[len(loPath)] == '/' && - loPath[1:] == loNPath[1:len(loPath)] && n.handle != nil { - return append(ciPath, n.path...), true + if len(path)+1 == npLen && n.path[len(path)] == '/' && + strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handle != nil { + return append(ciPath, n.path...) } } - return ciPath, false + return nil } diff --git a/tree_test.go b/tree_test.go index 508854f..6728847 100644 --- a/tree_test.go +++ b/tree_test.go @@ -6,6 +6,7 @@ package router import ( "fmt" + "regexp" "strings" "testing" @@ -13,7 +14,7 @@ import ( ) func printChildren(n *node, prefix string) { - fmt.Printf(" %02d:%02d %s%s[%d] %v %t %d \r\n", n.priority, n.maxParams, prefix, n.path, len(n.children), n.handle, n.wildChild, n.nType) + fmt.Printf(" %02d %s%s[%d] %v %t %d \r\n", n.priority, prefix, n.path, len(n.children), n.handle, n.wildChild, n.nType) for l := len(n.path); l > 0; l-- { prefix += " " } @@ -26,7 +27,7 @@ func printChildren(n *node, prefix string) { var fakeHandlerValue string func fakeHandler(val string) fasthttp.RequestHandler { - return func(*fasthttp.RequestCtx) { + return func(ctx *fasthttp.RequestCtx) { fakeHandlerValue = val } } @@ -38,7 +39,7 @@ type testRequests []struct { ps map[string]string } -func acquarieReqeustCtx(path string) *fasthttp.RequestCtx { +func acquireRequestCtx(path string) *fasthttp.RequestCtx { var requestCtx fasthttp.RequestCtx var fastRequest fasthttp.Request fastRequest.SetRequestURI(path) @@ -48,8 +49,8 @@ func acquarieReqeustCtx(path string) *fasthttp.RequestCtx { func checkRequests(t *testing.T, tree *node, requests testRequests) { for _, request := range requests { - requestCtx := acquarieReqeustCtx(request.path) - handler, _ := tree.getValue(request.path, requestCtx) + ctx := acquireRequestCtx(request.path) + handler, _ := tree.getValue(request.path, ctx) if handler == nil { if !request.nilHandler { @@ -65,7 +66,7 @@ func checkRequests(t *testing.T, tree *node, requests testRequests) { } for expectedKey, expectedVal := range request.ps { - if requestCtx.UserValue(expectedKey) != expectedVal { + if ctx.UserValue(expectedKey) != expectedVal { t.Errorf(" mismatch for route '%s'", request.path) } } @@ -92,33 +93,11 @@ func checkPriorities(t *testing.T, n *node) uint32 { return prio } -func checkMaxParams(t *testing.T, n *node) uint8 { - var maxParams uint8 - for i := range n.children { - params := checkMaxParams(t, n.children[i]) - if params > maxParams { - maxParams = params - } - } - if n.nType > root && !n.wildChild { - maxParams++ - } - - if n.maxParams != maxParams { - t.Errorf( - "maxParams mismatch for node '%s': is %d, should be %d", - n.path, n.maxParams, maxParams, - ) - } - - return maxParams -} - func TestCountParams(t *testing.T) { if countParams("/path/:param1/static/*catch-all") != 2 { t.Fail() } - if countParams(strings.Repeat("/:param", 256)) != 255 { + if countParams(strings.Repeat("/:param", 256)) != 256 { t.Fail() } } @@ -160,7 +139,6 @@ func TestTreeAddAndGet(t *testing.T) { }) checkPriorities(t, tree) - checkMaxParams(t, tree) } func TestTreeWildcard(t *testing.T) { @@ -206,7 +184,6 @@ func TestTreeWildcard(t *testing.T) { }) checkPriorities(t, tree) - checkMaxParams(t, tree) } func catchPanic(testFunc func()) (recv interface{}) { @@ -342,6 +319,8 @@ func TestTreeCatchAllConflict(t *testing.T) { {"/src/*filepath/x", true}, {"/src2/", false}, {"/src2/*filepath/x", true}, + {"/src3/*filepath", false}, + {"/src3/*filepath/x", true}, } testRoutes(t, routes) } @@ -354,6 +333,12 @@ func TestTreeCatchAllConflictRoot(t *testing.T) { testRoutes(t, routes) } +func TestTreeCatchMaxParams(t *testing.T) { + tree := &node{} + var route = "/cmd/*filepath" + tree.addRoute(route, fakeHandler(route)) +} + func TestTreeDoubleWildcard(t *testing.T) { const panicMsg = "only one wildcard per path segment is allowed" @@ -388,7 +373,6 @@ func TestTreeDoubleWildcard(t *testing.T) { func TestTreeTrailingSlashRedirect(t *testing.T) { tree := &node{} - ctx := &fasthttp.RequestCtx{} routes := [...]string{ "/hi", @@ -444,7 +428,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/doc/", } for _, route := range tsrRoutes { - handler, tsr := tree.getValue(route, ctx) + handler, tsr := tree.getValue(route, nil) if handler != nil { t.Fatalf("non-nil handler for TSR route '%s", route) } else if !tsr { @@ -461,7 +445,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/api/world/abc", } for _, route := range noTsrRoutes { - handler, tsr := tree.getValue(route, ctx) + handler, tsr := tree.getValue(route, nil) if handler != nil { t.Fatalf("non-nil handler for No-TSR route '%s", route) } else if tsr { @@ -491,6 +475,9 @@ func TestTreeRootTrailingSlashRedirect(t *testing.T) { func TestTreeFindCaseInsensitivePath(t *testing.T) { tree := &node{} + longPath := "/l" + strings.Repeat("o", 128) + "ng" + lOngPath := "/l" + strings.Repeat("O", 128) + "ng/" + routes := [...]string{ "/hi", "/b/", @@ -524,6 +511,7 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) { "/w/♭/", // 3 byte, last byte differs "/w/𠜎", // 4 byte "/w/𠜏/", // 4 byte + longPath, } for _, route := range routes { @@ -616,6 +604,7 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) { {"/w/♭", "/w/♭/", true, true}, {"/w/𠜎/", "/w/𠜎", true, true}, {"/w/𠜏", "/w/𠜏/", true, true}, + {lOngPath, longPath, true, true}, } // With fixTrailingSlash = true for _, test := range tests { @@ -669,3 +658,42 @@ func TestTreeInvalidNodeType(t *testing.T) { t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv) } } + +func TestTreeWildcardConflictEx(t *testing.T) { + conflicts := [...]struct { + route string + segPath string + existPath string + existSegPath string + }{ + {"/who/are/foo", "/foo", `/who/are/\*you`, `/\*you`}, + {"/who/are/foo/", "/foo/", `/who/are/\*you`, `/\*you`}, + {"/who/are/foo/bar", "/foo/bar", `/who/are/\*you`, `/\*you`}, + {"/conxxx", "xxx", `/con:tact`, `:tact`}, + {"/conooo/xxx", "ooo", `/con:tact`, `:tact`}, + } + + for _, conflict := range conflicts { + // I have to re-create a 'tree', because the 'tree' will be + // in an inconsistent state when the loop recovers from the + // panic which threw by 'addRoute' function. + tree := &node{} + routes := [...]string{ + "/con:tact", + "/who/are/*you", + "/who/foo/hello", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + recv := catchPanic(func() { + tree.addRoute(conflict.route, fakeHandler(conflict.route)) + }) + + if !regexp.MustCompile(fmt.Sprintf("'%s' in new path .* conflicts with existing wildcard '%s' in existing prefix '%s'", conflict.segPath, conflict.existSegPath, conflict.existPath)).MatchString(fmt.Sprint(recv)) { + t.Fatalf("invalid wildcard conflict error (%v)", recv) + } + } +}