Store path params in request context.

This commit is contained in:
2020-12-13 22:57:01 -08:00
parent ac8a5ff56e
commit ee0bb08b57
2 changed files with 21 additions and 9 deletions

View File

@ -58,7 +58,7 @@ type segment struct {
parameter parameter parameter parameter
} }
var paramKey = contextKey("params") var paramsKey = contextKey("params")
// NotFoundHandler is the default function for handling routes that are not found. If you wish to // NotFoundHandler is the default function for handling routes that are not found. If you wish to
// provide your own handler for this, simply set it on the router. // provide your own handler for this, simply set it on the router.
@ -115,11 +115,11 @@ func (r *Router) AddRoute(method string, path string, callback http.HandlerFunc)
// Handler returns the Handler to use for the given request, consulting r.Method, r.URL.Path. It // Handler returns the Handler to use for the given request, consulting r.Method, r.URL.Path. It
// always returns a non-nil Handler. // always returns a non-nil Handler.
// //
// Handler also returns a new context which contains any path parameters that are needed. // Handler also returns the path which it matched.
// //
// If there is no registered Handler that applies to the request, Handler returns a ``page not // If there is no registered Handler that applies to the request, Handler returns a ``page not
// found'' Handler and an empty pattern. // found'' Handler and an empty pattern.
func (r *Router) Handler(req *http.Request) (h http.Handler, ctx context.Context) { func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) {
method := req.Method method := req.Method
path := req.URL.Path path := req.URL.Path
@ -127,11 +127,11 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, ctx context.Context
h = NotFoundHandler h = NotFoundHandler
} }
endpoint, params, err := r.getEndpoint(method, path) endpoint, _, err := r.getEndpoint(method, path)
ctx = context.WithValue(context.Background(), paramKey, params)
if err == nil { if err == nil {
h = endpoint.callback h = endpoint.callback
pattern = endpoint.path
} }
return return
@ -144,9 +144,21 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, ctx context.Context
// In the case of this router, all it needs to do is lookup the Handler that has been saved at a given // In the case of this router, all it needs to do is lookup the Handler that has been saved at a given
// path and then call its ServeHTTP. // path and then call its ServeHTTP.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handler, ctx := r.Handler(req) method := req.Method
path := req.URL.Path
var handler http.Handler
req = req.WithContext(ctx) endpoint, params, err := r.getEndpoint(method, path)
if err != nil {
handler = NotFoundHandler
} else {
handler = endpoint.callback
ctx := context.WithValue(context.Background(), paramsKey, params)
req = req.WithContext(ctx)
}
handler, _ = r.Handler(req)
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
@ -156,7 +168,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// PathParams takes a path and returns the values for any path parameters // PathParams takes a path and returns the values for any path parameters
// in the path. // in the path.
func PathParams(req *http.Request) (params map[string]string) { func PathParams(req *http.Request) (params map[string]string) {
params = req.Context().Value(paramKey).(map[string]string) params = req.Context().Value(paramsKey).(map[string]string)
return return
} }

View File

@ -246,7 +246,7 @@ func testParamValues(router Router, t *testing.T) {
reqPath := fmt.Sprintf("/users/%s/edit/%s", userID, status) reqPath := fmt.Sprintf("/users/%s/edit/%s", userID, status)
router.AddRoute(method, path, func(w http.ResponseWriter, r *http.Request) { router.AddRoute(method, path, func(w http.ResponseWriter, r *http.Request) {
params := router.PathParams(r) params := PathParams(r)
if len(params) != 2 { if len(params) != 2 {
t.Errorf("Received the wrong number of parameters. Expected 2, recieved %d", len(params)) t.Errorf("Received the wrong number of parameters. Expected 2, recieved %d", len(params))