diff --git a/middleware/strip.go b/middleware/strip.go index ce8ebfcc..1368fa7a 100644 --- a/middleware/strip.go +++ b/middleware/strip.go @@ -60,3 +60,11 @@ func RedirectSlashes(next http.Handler) http.Handler { } return http.HandlerFunc(fn) } + +// StripPrefix is a middleware that will strip the provided prefix from the +// request path before handing the request over to the next handler. +func StripPrefix(prefix string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.StripPrefix(prefix, next) + } +} diff --git a/middleware/strip_test.go b/middleware/strip_test.go index 51fa9393..3ff8e87d 100644 --- a/middleware/strip_test.go +++ b/middleware/strip_test.go @@ -237,3 +237,38 @@ func TestStripSlashesWithNilContext(t *testing.T) { t.Fatal(resp) } } + +func TestStripPrefix(t *testing.T) { + r := chi.NewRouter() + + r.Use(StripPrefix("/api")) + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("api root")) + }) + + r.Get("/accounts", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("api accounts")) + }) + + r.Get("/accounts/{accountID}", func(w http.ResponseWriter, r *http.Request) { + accountID := chi.URLParam(r, "accountID") + w.Write([]byte(accountID)) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, resp := testRequest(t, ts, "GET", "/api/", nil); resp != "api root" { + t.Fatalf("got: %q, want: %q", resp, "api root") + } + if _, resp := testRequest(t, ts, "GET", "/api/accounts", nil); resp != "api accounts" { + t.Fatalf("got: %q, want: %q", resp, "api accounts") + } + if _, resp := testRequest(t, ts, "GET", "/api/accounts/admin", nil); resp != "admin" { + t.Fatalf("got: %q, want: %q", resp, "admin") + } + if _, resp := testRequest(t, ts, "GET", "/api-nope/", nil); resp != "404 page not found\n" { + t.Fatalf("got: %q, want: %q", resp, "404 page not found\n") + } +}