diff --git a/pkg/receive/handler.go b/pkg/receive/handler.go index 771b85cce88..36d6019b00f 100644 --- a/pkg/receive/handler.go +++ b/pkg/receive/handler.go @@ -12,18 +12,20 @@ import ( stdlog "log" "net" "net/http" + "path" "sort" "strconv" "sync" "time" + "github.com/mwitkow/go-conntrack" + "github.com/opentracing/opentracing-go" + "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/gogo/protobuf/proto" "github.com/jpillora/backoff" "github.com/klauspost/compress/s2" - "github.com/mwitkow/go-conntrack" - "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -403,6 +405,13 @@ func (h *Handler) handleRequest(ctx context.Context, rep uint64, tenant string, return h.forward(ctx, tenant, r, wreq) } +func (h *Handler) isTenantValid(tenant string, err error, w http.ResponseWriter) { + if tenant != path.Base(tenant) { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) { var err error span, ctx := tracing.StartSpan(r.Context(), "receive_http") @@ -422,6 +431,8 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) { } } + h.isTenantValid(tenant, err, w) + tLogger := log.With(h.logger, "tenant", tenant) writeGate := h.Limiter.WriteGate() diff --git a/pkg/receive/handler_test.go b/pkg/receive/handler_test.go index d363f0deb75..96cb0c0fd32 100644 --- a/pkg/receive/handler_test.go +++ b/pkg/receive/handler_test.go @@ -809,6 +809,10 @@ func makeRequest(h *Handler, tenant string, wreq *prompb.WriteRequest) (*httptes return rec, nil } +func TestIsTenantValid(h *Handler, tenant string) { + +} + type addrGen struct{ n int } func (a *addrGen) newAddr() string { @@ -1090,6 +1094,24 @@ func Heap(dir string) (err error) { return pprof.WriteHeapProfile(f) } +func TestValidTenant(t *testing.T) { + for _, tcase := range []struct { + name string + tenant string + err error + }{ + { + name: "test malicious tenant", + tenant: "/etc/foo", + }, + } { + t.Run(tcase.name, func(t *testing.T) { + h := NewHandler(nil, &Options{}) + h.isTenantValid(tcase.tenant, tcase.err) + }) + } +} + func TestRelabel(t *testing.T) { for _, tcase := range []struct { name string