diff --git a/router.go b/router.go index e7a1829..469ef85 100644 --- a/router.go +++ b/router.go @@ -1,11 +1,14 @@ package router import ( + "context" "errors" "net/http" "strings" ) +type contextKey string + // Router is a replacement for the net/http DefaultServerMux. This version includes the // ability to add path parameter in the given path. // @@ -31,6 +34,12 @@ type endpoint struct { pathParams []string } +// parameter contains a pointer to a parameter segment and the name of the parameter. +type parameter struct { + name string + segment *segment +} + // route is not part of the tree, but is saved on the router to represent all the available // routes in the tree. type route struct { @@ -46,9 +55,11 @@ type route struct { type segment struct { children map[string]*segment endpoints map[string]*endpoint - parameter *segment + parameter parameter } +var paramKey = contextKey("params") + // NotFoundHandler is the default function for handling routes that are not found. If you wish to // provide your own handler for this, simply set it on the router. var NotFoundHandler http.Handler = http.HandlerFunc( @@ -103,42 +114,20 @@ func (r *Router) AddRoute(method string, path string, callback http.HandlerFunc) // PathParams takes a path and returns the values for any path parameters // in the path. -func (r *Router) PathParams(method string, reqPath string) (params map[string]string, err error) { - end, err := r.getEndpoint(method, reqPath) - path := end.path - - reqParts := strings.Split(reqPath, "/") - reqKeys := setupKeys(reqParts) - - pathParts := strings.Split(path, "/") - pathKeys := setupKeys(pathParts) - - params = map[string]string{} - - for i, pathKey := range pathKeys { - if isParameter(pathKey) { - name := pathKey[2:] - value := reqKeys[i][1:] - params[name] = value - } - } +func (r *Router) PathParams(req *http.Request) (params map[string]string) { + params = req.Context().Value(paramKey).(map[string]string) return } -// Get is a convinience method which calls Router.AddRoute with the "GET" method. -func (r *Router) Get(path string, callback http.HandlerFunc) { - r.AddRoute(http.MethodGet, path, callback) -} - -// Handler returns the handler to use for the given request, consulting r.Method, r.URL.Path. It +// handler returns the handler to use for the given request, consulting r.Method, r.URL.Path. It // always returns a non-nil handler. // -// Handler also returns the registered pattern that matches the request. +// handler also returns a new context which contains any path parameters that are needed. // -// If there is no registered handler that applies to the request, Handler returns a ``page not +// If there is no registered handler that applies to the request, handler returns a ``page not // found'' handler and an empty pattern. -func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) { +func (r *Router) handler(req *http.Request) (h http.Handler, ctx context.Context) { method := req.Method path := req.URL.Path @@ -146,11 +135,11 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) { h = NotFoundHandler } - endpoint, err := r.getEndpoint(method, path) + endpoint, params, err := r.getEndpoint(method, path) + ctx = context.WithValue(context.Background(), paramKey, params) if err == nil { h = endpoint.callback - pattern = endpoint.path } return @@ -163,7 +152,10 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) { // In the case of this router, all it needs to do is lookup the Handler that has been saved at a given // path and then call its ServeHTTP. func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - handler, _ := r.Handler(req) + handler, ctx := r.handler(req) + + req = req.WithContext(ctx) + handler.ServeHTTP(w, req) return @@ -172,8 +164,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // addSegment create a new segment either as a child or as a parameter depending on whether the key // qualifies as a parameter. A pointer to the created segment is then returned. func addSegment(curr *segment, key string) (seg *segment) { - if curr.parameter != nil { - seg = curr.parameter + if curr.parameter.segment != nil { + seg = curr.parameter.segment } else if child, ok := curr.children[key]; !ok { // child does not match... var isParam bool @@ -181,7 +173,8 @@ func addSegment(curr *segment, key string) (seg *segment) { seg, isParam = newSegment(key) if isParam { - curr.parameter = seg + curr.parameter.segment = seg + curr.parameter.name = key[2:] } else { curr.children[key] = seg @@ -200,13 +193,13 @@ func addSegment(curr *segment, key string) (seg *segment) { // child on the segment, then that child segment is returned. If it is not a match, then the parameter child // is returned. If there is no parameter child, nil is returned. isParam is true if the parameter child is // being returned. -func getChild(key string, curr *segment) (child *segment, isParam bool) { +func getChild(key string, curr *segment) (child *segment, param string) { if seg, ok := curr.children[key]; ok { // is there an exact match? child = seg - } else if curr.parameter != nil { // could this be a parameter? - child = curr.parameter - isParam = true + } else if curr.parameter.segment != nil { // could this be a parameter? + child = curr.parameter.segment + param = curr.parameter.name } return @@ -214,22 +207,27 @@ func getChild(key string, curr *segment) (child *segment, isParam bool) { // getEndpoint takes a path and traverses the tree until it finds the endpoint associated with that path. // If no endpoint if found, an error is returned. -func (r *Router) getEndpoint(method string, path string) (end *endpoint, err error) { +func (r *Router) getEndpoint(method string, path string) (end *endpoint, params map[string]string, err error) { curr := r.root segments := strings.Split(path, "/") + params = map[string]string{} keys := setupKeys(segments) - for _, v := range keys { - if v == "/" { + for _, key := range keys { + if key == "/" { continue } - seg, _ := getChild(v, curr) + seg, paramName := getChild(key, curr) if seg == nil { return } + if paramName != "" { + params[paramName] = key[1:] + } + curr = seg } diff --git a/router_test.go b/router_test.go index 2faa5ee..8aa252d 100644 --- a/router_test.go +++ b/router_test.go @@ -85,8 +85,8 @@ func checkLookup(curr *segment) { checkLookup(v) } - if curr.parameter != nil { - checkLookup(curr.parameter) + if curr.parameter.segment != nil { + checkLookup(curr.parameter.segment) } } @@ -246,14 +246,7 @@ func testParamValues(router Router, t *testing.T) { reqPath := fmt.Sprintf("/users/%s/edit/%s", userID, status) router.AddRoute(method, path, func(w http.ResponseWriter, r *http.Request) { - requestPath := r.URL.Path - requestMethod := r.Method - - params, err := router.PathParams(requestMethod, requestPath) - - if err != nil { - t.Error("An error occurred while getting path parameters") - } + params := router.PathParams(r) if len(params) != 2 { t.Errorf("Received the wrong number of parameters. Expected 2, recieved %d", len(params))