Fix messed up default 404.

This commit is contained in:
2020-12-15 00:38:09 -08:00
parent 05f90673c5
commit 19d616162f
2 changed files with 12 additions and 16 deletions

View File

@ -126,7 +126,11 @@ func (r Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
endpoint, params, err := r.getEndpoint(method, path) endpoint, params, err := r.getEndpoint(method, path)
if err != nil { if err != nil {
if r.NotFoundHandler != nil {
handler = r.NotFoundHandler handler = r.NotFoundHandler
} else {
handler = NotFoundHandler
}
} else { } else {
handler = endpoint.callback handler = endpoint.callback
ctx := context.WithValue(context.Background(), paramsKey, params) ctx := context.WithValue(context.Background(), paramsKey, params)
@ -179,7 +183,10 @@ func addSegment(curr *segment, key string) (seg *segment) {
// is returned. If there is no parameter child, nil is returned. isParam is true if the parameter child is // is returned. If there is no parameter child, nil is returned. isParam is true if the parameter child is
// being returned. // being returned.
func getChild(key string, curr *segment) (child *segment, param string) { func getChild(key string, curr *segment) (child *segment, param string) {
if seg, ok := curr.children[key]; ok { // is there an exact match? if curr == nil {
return
} else if seg, ok := curr.children[key]; ok { // is there an exact match?
child = seg child = seg
} else if curr.parameter.segment != nil { // could this be a parameter? } else if curr.parameter.segment != nil { // could this be a parameter?
@ -207,7 +214,6 @@ func (r *Router) getEndpoint(method string, path string) (end *endpoint, params
if seg == nil { if seg == nil {
err = errors.New("route not found") err = errors.New("route not found")
return return
} }

View File

@ -199,19 +199,14 @@ func testCustomNotFound(router Router, t *testing.T) {
expectedBody := "Forbidden" expectedBody := "Forbidden"
expectedCode := 401 expectedCode := 401
path := "/actual/path" path := "/gibberish/forbidden"
router.AddRoute(http.MethodPatch, path, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte("Not found."))
})
router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedCode) w.WriteHeader(expectedCode)
w.Write([]byte(expectedBody)) w.Write([]byte(expectedBody))
}) })
err := matchAndCheckRoute(&router, http.MethodPatch, "/gibberish/forbidden", expectedBody, expectedCode) err := matchAndCheckRoute(&router, http.MethodPatch, path, expectedBody, expectedCode)
if err != nil { if err != nil {
t.Error("Did not call the custom handler.", err) t.Error("Did not call the custom handler.", err)
@ -227,12 +222,7 @@ func testDefaultNotFound(router Router, t *testing.T) {
expectedCode := 404 expectedCode := 404
path := "/gibberish" path := "/gibberish"
router.AddRoute(http.MethodDelete, path, func(w http.ResponseWriter, r *http.Request) { err := matchAndCheckRoute(&router, http.MethodDelete, path, expectedBody, expectedCode)
w.WriteHeader(expectedCode)
w.Write([]byte(expectedBody))
})
err := matchAndCheckRoute(&router, http.MethodDelete, "/gibberish", expectedBody, expectedCode)
if err != nil { if err != nil {
t.Error("Did not find the expected callback handler", err) t.Error("Did not find the expected callback handler", err)