From ee0bb08b57d6158fb12b86b078165de284845c87 Mon Sep 17 00:00:00 2001 From: nolwn Date: Sun, 13 Dec 2020 22:57:01 -0800 Subject: [PATCH] Store path params in request context. --- router.go | 28 ++++++++++++++++++++-------- router_test.go | 2 +- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/router.go b/router.go index c59a4ec..6f19c09 100644 --- a/router.go +++ b/router.go @@ -58,7 +58,7 @@ type segment struct { parameter parameter } -var paramKey = contextKey("params") +var paramsKey = contextKey("params") // NotFoundHandler is the default function for handling routes that are not found. If you wish to // provide your own handler for this, simply set it on the router. @@ -115,11 +115,11 @@ func (r *Router) AddRoute(method string, path string, callback http.HandlerFunc) // Handler returns the Handler to use for the given request, consulting r.Method, r.URL.Path. It // always returns a non-nil Handler. // -// Handler also returns a new context which contains any path parameters that are needed. +// Handler also returns the path which it matched. // // If there is no registered Handler that applies to the request, Handler returns a ``page not // found'' Handler and an empty pattern. -func (r *Router) Handler(req *http.Request) (h http.Handler, ctx context.Context) { +func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) { method := req.Method path := req.URL.Path @@ -127,11 +127,11 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, ctx context.Context h = NotFoundHandler } - endpoint, params, err := r.getEndpoint(method, path) - ctx = context.WithValue(context.Background(), paramKey, params) + endpoint, _, err := r.getEndpoint(method, path) if err == nil { h = endpoint.callback + pattern = endpoint.path } return @@ -144,9 +144,21 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, ctx context.Context // In the case of this router, all it needs to do is lookup the Handler that has been saved at a given // path and then call its ServeHTTP. func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - handler, ctx := r.Handler(req) + method := req.Method + path := req.URL.Path + var handler http.Handler - req = req.WithContext(ctx) + endpoint, params, err := r.getEndpoint(method, path) + + if err != nil { + handler = NotFoundHandler + } else { + handler = endpoint.callback + ctx := context.WithValue(context.Background(), paramsKey, params) + req = req.WithContext(ctx) + } + + handler, _ = r.Handler(req) handler.ServeHTTP(w, req) @@ -156,7 +168,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // PathParams takes a path and returns the values for any path parameters // in the path. func PathParams(req *http.Request) (params map[string]string) { - params = req.Context().Value(paramKey).(map[string]string) + params = req.Context().Value(paramsKey).(map[string]string) return } diff --git a/router_test.go b/router_test.go index 8aa252d..24df824 100644 --- a/router_test.go +++ b/router_test.go @@ -246,7 +246,7 @@ func testParamValues(router Router, t *testing.T) { reqPath := fmt.Sprintf("/users/%s/edit/%s", userID, status) router.AddRoute(method, path, func(w http.ResponseWriter, r *http.Request) { - params := router.PathParams(r) + params := PathParams(r) if len(params) != 2 { t.Errorf("Received the wrong number of parameters. Expected 2, recieved %d", len(params))