diff --git a/router.go b/router.go index 0f5bfe2..99567df 100644 --- a/router.go +++ b/router.go @@ -32,12 +32,16 @@ type segment struct { type Router struct { routes []route lookup *segment + + NotFoundHandler http.Handler } -// NewRouter is a constructor for Router. -func NewRouter() (r Router) { - return -} +// NotFoundHandler is the default function for handling routes that are not found. If you wish to +// provide your own handler for this, simply set it on the router. +var NotFoundHandler http.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) // AddRoute registers a new handler function to a path and http.HandlerFunc. If a path and // method already have a callback registered to them, and error is returned. @@ -57,12 +61,10 @@ func (r *Router) AddRoute(method string, path string, callback http.HandlerFunc) continue } - var seg segment - if child, ok := curr.children[key]; !ok { - seg = *newSegment(curr.path, key) - curr.children[key] = &seg - curr = &seg + seg := newSegment(curr.path, key) + curr.children[key] = seg + curr = seg } else { curr = child } @@ -98,10 +100,9 @@ func (r *Router) Handler(req *http.Request) (h http.Handler, pattern string) { segments := strings.Split(path, "/") keys := setupKeys(segments) - // TODO: make this a named function somewhere. Maybe allow a custom version. - h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) + if r.NotFoundHandler == nil { + h = NotFoundHandler + } for _, v := range keys { if v == "/" { diff --git a/router_test.go b/router_test.go index a52b5f0..ac1062c 100644 --- a/router_test.go +++ b/router_test.go @@ -47,42 +47,23 @@ func TestAddRouter(t *testing.T) { func TestHandler(t *testing.T) { router := Router{} - request, err := http.NewRequest(http.MethodGet, "http://example.com/items", nil) - rr := httptest.NewRecorder() + path := "/items" expectedBody := "I am /items" + expectedCode := 200 - if err != nil { - t.Error("Could not create request") - } - - router.AddRoute(http.MethodGet, "/items", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + router.AddRoute(http.MethodGet, path, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedCode) w.Write([]byte(expectedBody)) }) + err := matchAndCheckRoute(&router, http.MethodGet, path, expectedBody, expectedCode) + + if err != nil { + t.Error("Did not find the expected callback handler", err) + } + checkLookup(router.lookup) - h, pattern := router.Handler(request) - - if pattern != "/items" { - t.Errorf("The recovered patter does not match: %s", pattern) - } - - h.ServeHTTP(rr, request) - - if rr.Code != 200 { - t.Errorf("The returned callback did not write 200 to the header. Found %d", rr.Code) - } - - body, _ := ioutil.ReadAll(rr.Body) - - if string(body) != string([]byte(expectedBody)) { - t.Errorf( - "The returned callback did not write the expected body. Expected: %s. Actual: %s", - expectedBody, - string(body), - ) - } } func checkLookup(curr *segment) { @@ -93,6 +74,47 @@ func checkLookup(curr *segment) { } } +func matchAndCheckRoute(r *Router, method string, path string, expectedBody string, expectedCode int) (err error) { + request, err := http.NewRequest(method, path, nil) + rr := httptest.NewRecorder() + + if err != nil { + err = fmt.Errorf("Could not create request") + + return + } + + h, pattern := r.Handler(request) + + if pattern != "/items" { + err = fmt.Errorf("The recovered patter does not match: %s", pattern) + + return + } + + h.ServeHTTP(rr, request) + + if rr.Code != expectedCode { + err = fmt.Errorf("The returned callback did not write 200 to the header. Found %d", rr.Code) + + return + } + + body, _ := ioutil.ReadAll(rr.Body) + + if string(body) != string([]byte(expectedBody)) { + err = fmt.Errorf( + "The returned callback did not write the expected body. Expected: %s. Actual: %s", + expectedBody, + string(body), + ) + + return + } + + return +} + func addAndCheckRoute(r *Router, method string, path string, callback http.HandlerFunc, routeCounter *int) (err error) { err = r.AddRoute(method, path, callback) @@ -132,12 +154,3 @@ func addAndCheckRoute(r *Router, method string, path string, callback http.Handl return } - -// func TestHandle(t *testing.T) { -// r := NewRouter() - -// request, _ := http.NewRequest(http.MethodGet, "http://example.domain/api", nil) -// var writer http.ResponseWriter - -// r.Handle(writer, request) -// }