Use request contex

This commit is contained in:
2020-12-09 22:34:43 -08:00
parent 08228b2aa7
commit df7e3ccc03
2 changed files with 44 additions and 53 deletions

View File

@ -1,11 +1,14 @@
package router
import (
"context"
"errors"
"net/http"
"strings"
)
type contextKey string
// Router is a replacement for the net/http DefaultServerMux. This version includes the
// ability to add path parameter in the given path.
//
@ -31,6 +34,12 @@ type endpoint struct {
pathParams []string
}
// parameter contains a pointer to a parameter segment and the name of the parameter.
type parameter struct {
name string
segment *segment
}
// route is not part of the tree, but is saved on the router to represent all the available
// routes in the tree.
type route struct {
@ -46,9 +55,11 @@ type route struct {
type segment struct {
children map[string]*segment
endpoints map[string]*endpoint
parameter *segment
parameter parameter
}
var paramKey = contextKey("params")
// NotFoundHandler is the default function for handling routes that are not found. If you wish to
// provide your own handler for this, simply set it on the router.
var NotFoundHandler http.Handler = http.HandlerFunc(
@ -103,42 +114,20 @@ func (r *Router) AddRoute(method string, path string, callback http.HandlerFunc)
// PathParams takes a path and returns the values for any path parameters
// in the path.
func (r *Router) PathParams(method string, reqPath string) (params map[string]string, err error) {
end, err := r.getEndpoint(method, reqPath)
path := end.path
reqParts := strings.Split(reqPath, "/")
reqKeys := setupKeys(reqParts)
pathParts := strings.Split(path, "/")
pathKeys := setupKeys(pathParts)
params = map[string]string{}
for i, pathKey := range pathKeys {
if isParameter(pathKey) {
name := pathKey[2:]
value := reqKeys[i][1:]
params[name] = value
}
}
func (r *Router) PathParams(req *http.Request) (params map[string]string) {
params = req.Context().Value(paramKey).(map[string]string)
return
}
// Get is a convinience method which calls Router.AddRoute with the "GET" method.
func (r *Router) Get(path string, callback http.HandlerFunc) {
r.AddRoute(http.MethodGet, path, callback)
}
// Handler returns the handler to use for the given request, consulting r.Method, r.URL.Path. It
// handler returns the handler to use for the given request, consulting r.Method, r.URL.Path. It
// always returns a non-nil handler.
//
// Handler also returns the registered pattern that matches the request.
// handler also returns a new context which contains any path parameters that are needed.
//
// If there is no registered handler that applies to the request, Handler returns a ``page not
// If there is no registered handler that applies to the request, handler returns a ``page not
// found'' handler and an empty pattern.
func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) {
func (r *Router) handler(req *http.Request) (h http.Handler, ctx context.Context) {
method := req.Method
path := req.URL.Path
@ -146,11 +135,11 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) {
h = NotFoundHandler
}
endpoint, err := r.getEndpoint(method, path)
endpoint, params, err := r.getEndpoint(method, path)
ctx = context.WithValue(context.Background(), paramKey, params)
if err == nil {
h = endpoint.callback
pattern = endpoint.path
}
return
@ -163,7 +152,10 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) {
// In the case of this router, all it needs to do is lookup the Handler that has been saved at a given
// path and then call its ServeHTTP.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handler, _ := r.Handler(req)
handler, ctx := r.handler(req)
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
return
@ -172,8 +164,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// addSegment create a new segment either as a child or as a parameter depending on whether the key
// qualifies as a parameter. A pointer to the created segment is then returned.
func addSegment(curr *segment, key string) (seg *segment) {
if curr.parameter != nil {
seg = curr.parameter
if curr.parameter.segment != nil {
seg = curr.parameter.segment
} else if child, ok := curr.children[key]; !ok { // child does not match...
var isParam bool
@ -181,7 +173,8 @@ func addSegment(curr *segment, key string) (seg *segment) {
seg, isParam = newSegment(key)
if isParam {
curr.parameter = seg
curr.parameter.segment = seg
curr.parameter.name = key[2:]
} else {
curr.children[key] = seg
@ -200,13 +193,13 @@ func addSegment(curr *segment, key string) (seg *segment) {
// child on the segment, then that child segment is returned. If it is not a match, then the parameter child
// is returned. If there is no parameter child, nil is returned. isParam is true if the parameter child is
// being returned.
func getChild(key string, curr *segment) (child *segment, isParam bool) {
func getChild(key string, curr *segment) (child *segment, param string) {
if seg, ok := curr.children[key]; ok { // is there an exact match?
child = seg
} else if curr.parameter != nil { // could this be a parameter?
child = curr.parameter
isParam = true
} else if curr.parameter.segment != nil { // could this be a parameter?
child = curr.parameter.segment
param = curr.parameter.name
}
return
@ -214,22 +207,27 @@ func getChild(key string, curr *segment) (child *segment, isParam bool) {
// getEndpoint takes a path and traverses the tree until it finds the endpoint associated with that path.
// If no endpoint if found, an error is returned.
func (r *Router) getEndpoint(method string, path string) (end *endpoint, err error) {
func (r *Router) getEndpoint(method string, path string) (end *endpoint, params map[string]string, err error) {
curr := r.root
segments := strings.Split(path, "/")
params = map[string]string{}
keys := setupKeys(segments)
for _, v := range keys {
if v == "/" {
for _, key := range keys {
if key == "/" {
continue
}
seg, _ := getChild(v, curr)
seg, paramName := getChild(key, curr)
if seg == nil {
return
}
if paramName != "" {
params[paramName] = key[1:]
}
curr = seg
}