diff --git a/tenant/resolver.go b/tenant/resolver.go index e5fbea252..72517b082 100644 --- a/tenant/resolver.go +++ b/tenant/resolver.go @@ -2,6 +2,7 @@ package tenant import ( "context" + "errors" "net/http" "strings" @@ -59,14 +60,36 @@ func NewSingleResolver() *SingleResolver { type SingleResolver struct { } +// containsUnsafePathSegments will return true if the string is a directory +// reference like `.` and `..` or if any path separator character like `/` and +// `\` can be found. +func containsUnsafePathSegments(id string) bool { + // handle the relative reference to current and parent path. + if id == "." || id == ".." { + return true + } + + return strings.ContainsAny(id, "\\/") +} + +var errInvalidTenantID = errors.New("invalid tenant ID") + func (t *SingleResolver) TenantID(ctx context.Context) (string, error) { //lint:ignore faillint wrapper around upstream method - return user.ExtractOrgID(ctx) + id, err := user.ExtractOrgID(ctx) + if err != nil { + return "", err + } + + if containsUnsafePathSegments(id) { + return "", errInvalidTenantID + } + + return id, nil } func (t *SingleResolver) TenantIDs(ctx context.Context) ([]string, error) { - //lint:ignore faillint wrapper around upstream method - orgID, err := user.ExtractOrgID(ctx) + orgID, err := t.TenantID(ctx) if err != nil { return nil, err } @@ -109,6 +132,9 @@ func (t *MultiResolver) TenantIDs(ctx context.Context) ([]string, error) { if err := ValidTenantID(orgID); err != nil { return nil, err } + if containsUnsafePathSegments(orgID) { + return nil, errInvalidTenantID + } } return NormalizeTenantIDs(orgIDs), nil diff --git a/tenant/resolver_test.go b/tenant/resolver_test.go index 69559263b..4d2da2416 100644 --- a/tenant/resolver_test.go +++ b/tenant/resolver_test.go @@ -64,6 +64,18 @@ var commonResolverTestCases = []resolverTestCase{ tenantID: "tenant-a", tenantIDs: []string{"tenant-a"}, }, + { + name: "parent-dir", + headerValue: strptr(".."), + errTenantID: errInvalidTenantID, + errTenantIDs: errInvalidTenantID, + }, + { + name: "current-dir", + headerValue: strptr("."), + errTenantID: errInvalidTenantID, + errTenantIDs: errInvalidTenantID, + }, } func TestSingleResolver(t *testing.T) { @@ -75,6 +87,18 @@ func TestSingleResolver(t *testing.T) { tenantID: "tenant-a|tenant-b", tenantIDs: []string{"tenant-a|tenant-b"}, }, + { + name: "containing-forward-slash", + headerValue: strptr("forward/slash"), + errTenantID: errInvalidTenantID, + errTenantIDs: errInvalidTenantID, + }, + { + name: "containing-backward-slash", + headerValue: strptr(`backward\slash`), + errTenantID: errInvalidTenantID, + errTenantIDs: errInvalidTenantID, + }, }...) { t.Run(tc.name, tc.test(r)) } @@ -101,6 +125,24 @@ func TestMultiResolver(t *testing.T) { errTenantID: user.ErrTooManyOrgIDs, tenantIDs: []string{"tenant-a", "tenant-b"}, }, + { + name: "multi-tenant-with-relative-path", + headerValue: strptr("tenant-a|tenant-b|.."), + errTenantID: errInvalidTenantID, + errTenantIDs: errInvalidTenantID, + }, + { + name: "containing-forward-slash", + headerValue: strptr("forward/slash"), + errTenantID: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"}, + errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"}, + }, + { + name: "containing-backward-slash", + headerValue: strptr(`backward\slash`), + errTenantID: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"}, + errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"}, + }, }...) { t.Run(tc.name, tc.test(r)) }