diff --git a/router.go b/router.go index d4be689..251da5c 100644 --- a/router.go +++ b/router.go @@ -126,7 +126,11 @@ func (r Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { endpoint, params, err := r.getEndpoint(method, path) if err != nil { - handler = r.NotFoundHandler + if r.NotFoundHandler != nil { + handler = r.NotFoundHandler + } else { + handler = NotFoundHandler + } } else { handler = endpoint.callback ctx := context.WithValue(context.Background(), paramsKey, params) @@ -179,7 +183,10 @@ func addSegment(curr *segment, key string) (seg *segment) { // is returned. If there is no parameter child, nil is returned. isParam is true if the parameter child is // being returned. func getChild(key string, curr *segment) (child *segment, param string) { - if seg, ok := curr.children[key]; ok { // is there an exact match? + if curr == nil { + return + + } else if seg, ok := curr.children[key]; ok { // is there an exact match? child = seg } else if curr.parameter.segment != nil { // could this be a parameter? @@ -207,7 +214,6 @@ func (r *Router) getEndpoint(method string, path string) (end *endpoint, params if seg == nil { err = errors.New("route not found") - return } diff --git a/router_test.go b/router_test.go index 4288b0c..3c845e2 100644 --- a/router_test.go +++ b/router_test.go @@ -199,19 +199,14 @@ func testCustomNotFound(router Router, t *testing.T) { expectedBody := "Forbidden" expectedCode := 401 - path := "/actual/path" - - router.AddRoute(http.MethodPatch, path, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - w.Write([]byte("Not found.")) - }) + path := "/gibberish/forbidden" router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedCode) w.Write([]byte(expectedBody)) }) - err := matchAndCheckRoute(&router, http.MethodPatch, "/gibberish/forbidden", expectedBody, expectedCode) + err := matchAndCheckRoute(&router, http.MethodPatch, path, expectedBody, expectedCode) if err != nil { t.Error("Did not call the custom handler.", err) @@ -227,12 +222,7 @@ func testDefaultNotFound(router Router, t *testing.T) { expectedCode := 404 path := "/gibberish" - router.AddRoute(http.MethodDelete, path, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(expectedCode) - w.Write([]byte(expectedBody)) - }) - - err := matchAndCheckRoute(&router, http.MethodDelete, "/gibberish", expectedBody, expectedCode) + err := matchAndCheckRoute(&router, http.MethodDelete, path, expectedBody, expectedCode) if err != nil { t.Error("Did not find the expected callback handler", err)