From 9f9a449e4e57212e4e4db1e3c4830d3f3a5db6bb Mon Sep 17 00:00:00 2001 From: Marco Pracucci Date: Mon, 17 Jul 2023 15:27:43 +0200 Subject: [PATCH 1/5] Fix cardinality and label names/values requests handling on POST requests Signed-off-by: Marco Pracucci --- CHANGELOG.md | 2 +- pkg/cardinality/request.go | 56 +++--- pkg/cardinality/request_test.go | 35 ++-- .../cardinality_query_cache.go | 11 +- .../cardinality_query_cache_test.go | 6 +- .../querymiddleware/generic_query_cache.go | 15 +- .../generic_query_cache_test.go | 167 +++++++++++------- .../querymiddleware/labels_query_cache.go | 33 ++-- .../labels_query_cache_test.go | 15 +- pkg/frontend/transport/handler.go | 19 +- pkg/util/http.go | 46 +++++ pkg/util/http_test.go | 60 +++++++ 12 files changed, 313 insertions(+), 152 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0a9d7a8897..a2293af832f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ * [CHANGE] Querier: `-query-frontend.cache-unaligned-requests` has been moved from a global flag to a per-tenant override. #5312 * [CHANGE] Ingester: removed `cortex_ingester_shipper_dir_syncs_total` and `cortex_ingester_shipper_dir_sync_failures_total` metrics. The former metric was not much useful, and the latter was never incremented. #5396 * [FEATURE] Cardinality API: Add a new `count_method` parameter which enables counting active series #5136 -* [FEATURE] Query-frontend: added experimental support to cache cardinality, label names and label values query responses. The cache will be used when `-query-frontend.cache-results` is enabled, and `-query-frontend.results-cache-ttl-for-cardinality-query` or `-query-frontend.results-cache-ttl-for-labels-query` set to a value greater than 0. The following metrics have been added to track the query results cache hit ratio per `request_type`: #5212 #5235 #5426 +* [FEATURE] Query-frontend: added experimental support to cache cardinality, label names and label values query responses. 
The cache will be used when `-query-frontend.cache-results` is enabled, and `-query-frontend.results-cache-ttl-for-cardinality-query` or `-query-frontend.results-cache-ttl-for-labels-query` set to a value greater than 0. The following metrics have been added to track the query results cache hit ratio per `request_type`: #5212 #5235 #5426 #5524 * `cortex_frontend_query_result_cache_requests_total{request_type="query_range|cardinality|label_names_and_values"}` * `cortex_frontend_query_result_cache_hits_total{request_type="query_range|cardinality|label_names_and_values"}` * [FEATURE] Added `-.s3.list-objects-version` flag to configure the S3 list objects version. diff --git a/pkg/cardinality/request.go b/pkg/cardinality/request.go index 955acd3e2eb..8a00fcd4f57 100644 --- a/pkg/cardinality/request.go +++ b/pkg/cardinality/request.go @@ -5,6 +5,7 @@ package cardinality import ( "fmt" "net/http" + "net/url" "strconv" "strings" @@ -60,22 +61,26 @@ func (r *LabelNamesRequest) String() string { // DecodeLabelNamesRequest decodes the input http.Request into a LabelNamesRequest. // The input http.Request can either be a GET or POST with URL-encoded parameters. func DecodeLabelNamesRequest(r *http.Request) (*LabelNamesRequest, error) { + if err := r.ParseForm(); err != nil { + return nil, err + } + + return DecodeLabelNamesRequestFromValues(r.Form) +} + +// DecodeLabelNamesRequestFromValues is like DecodeLabelNamesRequest but takes url.Values in input. 
+func DecodeLabelNamesRequestFromValues(values url.Values) (*LabelNamesRequest, error) { var ( parsed = &LabelNamesRequest{} err error ) - err = r.ParseForm() - if err != nil { - return nil, err - } - - parsed.Matchers, err = extractSelector(r) + parsed.Matchers, err = extractSelector(values) if err != nil { return nil, err } - parsed.Limit, err = extractLimit(r) + parsed.Limit, err = extractLimit(values) if err != nil { return nil, err } @@ -126,31 +131,36 @@ func (r *LabelValuesRequest) String() string { // DecodeLabelValuesRequest decodes the input http.Request into a LabelValuesRequest. // The input http.Request can either be a GET or POST with URL-encoded parameters. func DecodeLabelValuesRequest(r *http.Request) (*LabelValuesRequest, error) { + if err := r.ParseForm(); err != nil { + return nil, err + } + + return DecodeLabelValuesRequestFromValues(r.Form) +} + +// DecodeLabelValuesRequestFromValues is like DecodeLabelValuesRequest but takes url.Values in input. +func DecodeLabelValuesRequestFromValues(values url.Values) (*LabelValuesRequest, error) { var ( parsed = &LabelValuesRequest{} err error ) - if err = r.ParseForm(); err != nil { - return nil, err - } - - parsed.LabelNames, err = extractLabelNames(r) + parsed.LabelNames, err = extractLabelNames(values) if err != nil { return nil, err } - parsed.Matchers, err = extractSelector(r) + parsed.Matchers, err = extractSelector(values) if err != nil { return nil, err } - parsed.Limit, err = extractLimit(r) + parsed.Limit, err = extractLimit(values) if err != nil { return nil, err } - parsed.CountMethod, err = extractCountMethod(r) + parsed.CountMethod, err = extractCountMethod(values) if err != nil { return nil, err } @@ -159,8 +169,8 @@ func DecodeLabelValuesRequest(r *http.Request) (*LabelValuesRequest, error) { } // extractSelector parses and gets selector query parameter containing a single matcher -func extractSelector(r *http.Request) (matchers []*labels.Matcher, err error) { - selectorParams := 
r.Form["selector"] +func extractSelector(values url.Values) (matchers []*labels.Matcher, err error) { + selectorParams := values["selector"] if len(selectorParams) == 0 { return nil, nil } @@ -187,8 +197,8 @@ func extractSelector(r *http.Request) (matchers []*labels.Matcher, err error) { } // extractLimit parses and validates request param `limit` if it's defined, otherwise returns default value. -func extractLimit(r *http.Request) (limit int, err error) { - limitParams := r.Form["limit"] +func extractLimit(values url.Values) (limit int, err error) { + limitParams := values["limit"] if len(limitParams) == 0 { return defaultLimit, nil } @@ -209,8 +219,8 @@ func extractLimit(r *http.Request) (limit int, err error) { } // extractLabelNames parses and gets label_names query parameter containing an array of label values -func extractLabelNames(r *http.Request) ([]model.LabelName, error) { - labelNamesParams := r.Form["label_names[]"] +func extractLabelNames(values url.Values) ([]model.LabelName, error) { + labelNamesParams := values["label_names[]"] if len(labelNamesParams) == 0 { return nil, fmt.Errorf("'label_names[]' param is required") } @@ -231,8 +241,8 @@ func extractLabelNames(r *http.Request) ([]model.LabelName, error) { } // extractCountMethod parses and validates request param `count_method` if it's defined, otherwise returns default value. 
-func extractCountMethod(r *http.Request) (countMethod CountMethod, err error) { - countMethodParams := r.Form["count_method"] +func extractCountMethod(values url.Values) (countMethod CountMethod, err error) { + countMethodParams := values["count_method"] if len(countMethodParams) == 0 { return defaultCountMethod, nil } diff --git a/pkg/cardinality/request_test.go b/pkg/cardinality/request_test.go index bfceac1e5de..facd9f17f45 100644 --- a/pkg/cardinality/request_test.go +++ b/pkg/cardinality/request_test.go @@ -19,7 +19,7 @@ func TestDecodeLabelNamesRequest(t *testing.T) { params = url.Values{ "selector": []string{`{second!="2",first="1"}`}, "limit": []string{"100"}, - }.Encode() + } expected = &LabelNamesRequest{ Matchers: []*labels.Matcher{ @@ -30,8 +30,8 @@ func TestDecodeLabelNamesRequest(t *testing.T) { } ) - t.Run("GET request", func(t *testing.T) { - req, err := http.NewRequest("GET", "http://localhost?"+params, nil) + t.Run("DecodeLabelNamesRequest() with GET request", func(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost?"+params.Encode(), nil) require.NoError(t, err) actual, err := DecodeLabelNamesRequest(req) @@ -40,8 +40,8 @@ func TestDecodeLabelNamesRequest(t *testing.T) { assert.Equal(t, expected, actual) }) - t.Run("POST request", func(t *testing.T) { - req, err := http.NewRequest("POST", "http://localhost/", strings.NewReader(params)) + t.Run("DecodeLabelNamesRequest() with POST request", func(t *testing.T) { + req, err := http.NewRequest("POST", "http://localhost/", strings.NewReader(params.Encode())) require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -50,6 +50,13 @@ func TestDecodeLabelNamesRequest(t *testing.T) { assert.Equal(t, expected, actual) }) + + t.Run("DecodeLabelNamesRequestFromValues()", func(t *testing.T) { + actual, err := DecodeLabelNamesRequestFromValues(params) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) } func 
TestLabelNamesRequest_String(t *testing.T) { @@ -71,7 +78,7 @@ func TestDecodeLabelValuesRequest(t *testing.T) { "label_names[]": []string{"metric_2", "metric_1"}, "count_method": []string{"active"}, "limit": []string{"100"}, - }.Encode() + } expected = &LabelValuesRequest{ LabelNames: []model.LabelName{"metric_1", "metric_2"}, @@ -84,9 +91,8 @@ func TestDecodeLabelValuesRequest(t *testing.T) { } ) - t.Run("GET request", func(t *testing.T) { - - req, err := http.NewRequest("GET", "http://localhost?"+params, nil) + t.Run("DecodeLabelValuesRequest() GET request", func(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost?"+params.Encode(), nil) require.NoError(t, err) actual, err := DecodeLabelValuesRequest(req) @@ -95,8 +101,8 @@ func TestDecodeLabelValuesRequest(t *testing.T) { assert.Equal(t, expected, actual) }) - t.Run("POST request", func(t *testing.T) { - req, err := http.NewRequest("POST", "http://localhost/", strings.NewReader(params)) + t.Run("DecodeLabelValuesRequest() POST request", func(t *testing.T) { + req, err := http.NewRequest("POST", "http://localhost/", strings.NewReader(params.Encode())) require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -105,6 +111,13 @@ func TestDecodeLabelValuesRequest(t *testing.T) { assert.Equal(t, expected, actual) }) + + t.Run("DecodeLabelValuesRequestFromValues() GET request", func(t *testing.T) { + actual, err := DecodeLabelValuesRequestFromValues(params) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) } func TestLabelValuesRequest_String(t *testing.T) { diff --git a/pkg/frontend/querymiddleware/cardinality_query_cache.go b/pkg/frontend/querymiddleware/cardinality_query_cache.go index 71b63a90e09..6c544493511 100644 --- a/pkg/frontend/querymiddleware/cardinality_query_cache.go +++ b/pkg/frontend/querymiddleware/cardinality_query_cache.go @@ -5,6 +5,7 @@ package querymiddleware import ( "errors" "net/http" + "net/url" "strings" "time" @@ 
-36,10 +37,10 @@ func (c *cardinalityQueryCache) getTTL(userID string) time.Duration { return c.limits.ResultsCacheTTLForCardinalityQuery(userID) } -func (c *cardinalityQueryCache) parseRequest(req *http.Request) (*genericQueryRequest, error) { +func (c *cardinalityQueryCache) parseRequest(path string, values url.Values) (*genericQueryRequest, error) { switch { - case strings.HasSuffix(req.URL.Path, cardinalityLabelNamesPathSuffix): - parsed, err := cardinality.DecodeLabelNamesRequest(req) + case strings.HasSuffix(path, cardinalityLabelNamesPathSuffix): + parsed, err := cardinality.DecodeLabelNamesRequestFromValues(values) if err != nil { return nil, err } @@ -48,8 +49,8 @@ func (c *cardinalityQueryCache) parseRequest(req *http.Request) (*genericQueryRe cacheKey: parsed.String(), cacheKeyPrefix: cardinalityLabelNamesQueryCachePrefix, }, nil - case strings.HasSuffix(req.URL.Path, cardinalityLabelValuesPathSuffix): - parsed, err := cardinality.DecodeLabelValuesRequest(req) + case strings.HasSuffix(path, cardinalityLabelValuesPathSuffix): + parsed, err := cardinality.DecodeLabelValuesRequestFromValues(values) if err != nil { return nil, err } diff --git a/pkg/frontend/querymiddleware/cardinality_query_cache_test.go b/pkg/frontend/querymiddleware/cardinality_query_cache_test.go index fce159873f1..e7ee7bfafe5 100644 --- a/pkg/frontend/querymiddleware/cardinality_query_cache_test.go +++ b/pkg/frontend/querymiddleware/cardinality_query_cache_test.go @@ -107,12 +107,14 @@ func TestCardinalityQueryCache_RoundTrip_WithTenantFederation(t *testing.T) { func TestCardinalityQueryCache_RoundTrip(t *testing.T) { testGenericQueryCacheRoundTrip(t, newCardinalityQueryCacheRoundTripper, "cardinality", map[string]testGenericQueryCacheRequestType{ "label names request": { - url: mustParseURL(t, `/prometheus/api/v1/cardinality/label_names?selector={job="test"}&limit=100`), + reqPath: "/prometheus/api/v1/cardinality/label_names", + reqData: url.Values{"selector": []string{`{job="test"}`}, 
"limit": []string{"100"}}, cacheKey: "user-1:job=\"test\"\x00100", hashedCacheKey: cardinalityLabelNamesQueryCachePrefix + cacheHashKey("user-1:job=\"test\"\x00100"), }, "label values request": { - url: mustParseURL(t, `/prometheus/api/v1/cardinality/label_values?selector={job="test"}&label_names[]=metric_1&label_names[]=metric_2&limit=100`), + reqPath: "/prometheus/api/v1/cardinality/label_values", + reqData: url.Values{"selector": []string{`{job="test"}`}, "label_names[]": []string{"metric_1", "metric_2"}, "limit": []string{"100"}}, cacheKey: "user-1:metric_1\x01metric_2\x00job=\"test\"\x00inmemory\x00100", hashedCacheKey: cardinalityLabelValuesQueryCachePrefix + cacheHashKey("user-1:metric_1\x01metric_2\x00job=\"test\"\x00inmemory\x00100"), }, diff --git a/pkg/frontend/querymiddleware/generic_query_cache.go b/pkg/frontend/querymiddleware/generic_query_cache.go index 805f02e660d..99cdb38bc80 100644 --- a/pkg/frontend/querymiddleware/generic_query_cache.go +++ b/pkg/frontend/querymiddleware/generic_query_cache.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net/http" + "net/url" "time" "github.com/go-kit/log" @@ -14,6 +15,7 @@ import ( "github.com/grafana/dskit/tenant" apierror "github.com/grafana/mimir/pkg/api/error" + "github.com/grafana/mimir/pkg/util" "github.com/grafana/mimir/pkg/util/spanlogger" "github.com/grafana/mimir/pkg/util/validation" ) @@ -28,8 +30,8 @@ type genericQueryRequest struct { } type genericQueryDelegate interface { - // parseRequest parses the input req and returns a genericQueryRequest, or an error if parsing fails. - parseRequest(req *http.Request) (*genericQueryRequest, error) + // parseRequest parses the input request and returns a genericQueryRequest, or an error if parsing fails. + parseRequest(path string, values url.Values) (*genericQueryRequest, error) // getTTL returns the cache TTL for the input userID. 
getTTL(userID string) time.Duration @@ -80,7 +82,14 @@ func (c *genericQueryCache) RoundTrip(req *http.Request) (*http.Response, error) } // Decode the request. - queryReq, err := c.delegate.parseRequest(req) + reqValues, err := util.ParseRequestFormWithoutConsumingBody(req) + if err != nil { + // This is considered a non-recoverable error, so we return error instead of passing + // the request to the downstream. + return nil, apierror.New(apierror.TypeBadData, err.Error()) + } + + queryReq, err := c.delegate.parseRequest(req.URL.Path, reqValues) if err != nil { // Logging as info because it's not an actionable error here. // We defer it to the downstream. diff --git a/pkg/frontend/querymiddleware/generic_query_cache_test.go b/pkg/frontend/querymiddleware/generic_query_cache_test.go index 23d9de5a362..7a936d307e3 100644 --- a/pkg/frontend/querymiddleware/generic_query_cache_test.go +++ b/pkg/frontend/querymiddleware/generic_query_cache_test.go @@ -26,7 +26,8 @@ import ( type newGenericQueryCacheFunc func(cache cache.Cache, limits Limits, next http.RoundTripper, logger log.Logger, reg prometheus.Registerer) http.RoundTripper type testGenericQueryCacheRequestType struct { - url *url.URL + reqPath string + reqData url.Values cacheKey string hashedCacheKey string } @@ -164,81 +165,114 @@ func testGenericQueryCacheRoundTrip(t *testing.T, newRoundTripper newGenericQuer for testName, testData := range tests { t.Run(testName, func(t *testing.T) { for reqName, reqData := range requestTypes { - t.Run(reqName, func(t *testing.T) { - // Mock the limits. - limits := multiTenantMockLimits{ - byTenant: map[string]mockLimits{ - userID: { - resultsCacheTTLForCardinalityQuery: testData.cacheTTL, - resultsCacheTTLForLabelsQuery: testData.cacheTTL, + for _, reqMethod := range []string{ /*http.MethodGet ,*/ http.MethodPost} { + t.Run(fmt.Sprintf("%s (%s)", reqName, reqMethod), func(t *testing.T) { + // Mock the limits. 
+ limits := multiTenantMockLimits{ + byTenant: map[string]mockLimits{ + userID: { + resultsCacheTTLForCardinalityQuery: testData.cacheTTL, + resultsCacheTTLForLabelsQuery: testData.cacheTTL, + }, }, - }, - } + } - // Mock the downstream. - downstreamCalled := false - downstream := RoundTripFunc(func(request *http.Request) (*http.Response, error) { - downstreamCalled = true - return testData.downstreamRes(), testData.downstreamErr - }) + var ( + req *http.Request + downstreamCalled = false + downstreamReqParams url.Values + err error + ) - // Create the request. - req := &http.Request{URL: reqData.url, Header: testData.reqHeader} - req = req.WithContext(user.InjectOrgID(context.Background(), userID)) + // Mock the downstream and capture the request. + downstream := RoundTripFunc(func(req *http.Request) (*http.Response, error) { + downstreamCalled = true - // Init the cache. - cacheBackend := cache.NewInstrumentedMockCache() - if testData.init != nil { - testData.init(t, cacheBackend, reqData.cacheKey, reqData.hashedCacheKey) - } - initialStoreCallsCount := cacheBackend.CountStoreCalls() + // Parse the request body. + require.NoError(t, req.ParseForm()) + downstreamReqParams = req.Form - reg := prometheus.NewPedanticRegistry() - rt := newRoundTripper(cacheBackend, limits, downstream, testutil.NewLogger(t), reg) - res, err := rt.RoundTrip(req) - require.NoError(t, err) + return testData.downstreamRes(), testData.downstreamErr + }) - // Assert on the downstream. - assert.Equal(t, testData.expectedDownstreamCalled, downstreamCalled) + // Create the request. 
+ switch reqMethod { + case http.MethodGet: + req, err = http.NewRequest(reqMethod, reqData.reqPath+"?"+reqData.reqData.Encode(), nil) + require.NoError(t, err) + case http.MethodPost: + req, err = http.NewRequest(reqMethod, reqData.reqPath, strings.NewReader(reqData.reqData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, err) + default: + t.Fatalf("unsupported HTTP method %q", reqMethod) + } - // Assert on the response received. - assert.Equal(t, testData.expectedStatusCode, res.StatusCode) - assert.Equal(t, testData.expectedHeader, res.Header) + for name, values := range testData.reqHeader { + for _, value := range values { + req.Header.Set(name, value) + } + } - actualBody, err := io.ReadAll(res.Body) - require.NoError(t, err) - assert.Equal(t, testData.expectedBody, actualBody) + // Inject the tenant ID in the request. + req = req.WithContext(user.InjectOrgID(context.Background(), userID)) - // Assert on the state of the cache. - if testData.expectedStoredToCache { - assert.Equal(t, initialStoreCallsCount+1, cacheBackend.CountStoreCalls()) + // Init the cache. + cacheBackend := cache.NewInstrumentedMockCache() + if testData.init != nil { + testData.init(t, cacheBackend, reqData.cacheKey, reqData.hashedCacheKey) + } + initialStoreCallsCount := cacheBackend.CountStoreCalls() + + reg := prometheus.NewPedanticRegistry() + rt := newRoundTripper(cacheBackend, limits, downstream, testutil.NewLogger(t), reg) + res, err := rt.RoundTrip(req) + require.NoError(t, err) - items := cacheBackend.GetItems() - require.Len(t, items, 1) - require.NotZero(t, items[reqData.hashedCacheKey]) + // Assert on the downstream. 
+ assert.Equal(t, testData.expectedDownstreamCalled, downstreamCalled) + if testData.expectedDownstreamCalled { + assert.Equal(t, reqData.reqData, downstreamReqParams) + } - cached := CachedHTTPResponse{} - require.NoError(t, cached.Unmarshal(items[reqData.hashedCacheKey].Data)) - assert.Equal(t, testData.expectedStatusCode, int(cached.StatusCode)) - assert.Equal(t, testData.expectedHeader, DecodeCachedHTTPResponse(&cached).Header) - assert.Equal(t, testData.expectedBody, cached.Body) - assert.Equal(t, reqData.cacheKey, cached.CacheKey) - assert.WithinDuration(t, time.Now().Add(testData.cacheTTL), items[reqData.hashedCacheKey].ExpiresAt, 5*time.Second) - } else { - assert.Equal(t, initialStoreCallsCount, cacheBackend.CountStoreCalls()) - } + // Assert on the response received. + assert.Equal(t, testData.expectedStatusCode, res.StatusCode) + assert.Equal(t, testData.expectedHeader, res.Header) - // Assert on metrics. - expectedRequestsCount := 0 - expectedHitsCount := 0 - if testData.expectedLookupFromCache { - expectedRequestsCount = 1 - if !testData.expectedDownstreamCalled { - expectedHitsCount = 1 + actualBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, testData.expectedBody, actualBody) + + // Assert on the state of the cache. 
+ if testData.expectedStoredToCache { + assert.Equal(t, initialStoreCallsCount+1, cacheBackend.CountStoreCalls()) + + items := cacheBackend.GetItems() + require.Len(t, items, 1) + require.NotZero(t, items[reqData.hashedCacheKey]) + + cached := CachedHTTPResponse{} + require.NoError(t, cached.Unmarshal(items[reqData.hashedCacheKey].Data)) + assert.Equal(t, testData.expectedStatusCode, int(cached.StatusCode)) + assert.Equal(t, testData.expectedHeader, DecodeCachedHTTPResponse(&cached).Header) + assert.Equal(t, testData.expectedBody, cached.Body) + assert.Equal(t, reqData.cacheKey, cached.CacheKey) + assert.WithinDuration(t, time.Now().Add(testData.cacheTTL), items[reqData.hashedCacheKey].ExpiresAt, 5*time.Second) + } else { + assert.Equal(t, initialStoreCallsCount, cacheBackend.CountStoreCalls()) } - } - assert.NoError(t, promtest.GatherAndCompare(reg, strings.NewReader(fmt.Sprintf(` + // Assert on metrics. + expectedRequestsCount := 0 + expectedHitsCount := 0 + if testData.expectedLookupFromCache { + expectedRequestsCount = 1 + if !testData.expectedDownstreamCalled { + expectedHitsCount = 1 + } + } + + assert.NoError(t, promtest.GatherAndCompare(reg, strings.NewReader(fmt.Sprintf(` # HELP cortex_frontend_query_result_cache_requests_total Total number of requests (or partial requests) looked up in the results cache. 
# TYPE cortex_frontend_query_result_cache_requests_total counter cortex_frontend_query_result_cache_requests_total{request_type="%s"} %d @@ -247,10 +281,11 @@ func testGenericQueryCacheRoundTrip(t *testing.T, newRoundTripper newGenericQuer # TYPE cortex_frontend_query_result_cache_hits_total counter cortex_frontend_query_result_cache_hits_total{request_type="%s"} %d `, requestTypeLabelValue, expectedRequestsCount, requestTypeLabelValue, expectedHitsCount)), - "cortex_frontend_query_result_cache_requests_total", - "cortex_frontend_query_result_cache_hits_total", - )) - }) + "cortex_frontend_query_result_cache_requests_total", + "cortex_frontend_query_result_cache_hits_total", + )) + }) + } } }) } diff --git a/pkg/frontend/querymiddleware/labels_query_cache.go b/pkg/frontend/querymiddleware/labels_query_cache.go index a31c4a90c7c..6858b18800f 100644 --- a/pkg/frontend/querymiddleware/labels_query_cache.go +++ b/pkg/frontend/querymiddleware/labels_query_cache.go @@ -5,6 +5,7 @@ package querymiddleware import ( "fmt" "net/http" + "net/url" "strings" "time" @@ -42,11 +43,7 @@ func (c *labelsQueryCache) getTTL(userID string) time.Duration { return c.limits.ResultsCacheTTLForLabelsQuery(userID) } -func (c *labelsQueryCache) parseRequest(req *http.Request) (*genericQueryRequest, error) { - if err := req.ParseForm(); err != nil { - return nil, err - } - +func (c *labelsQueryCache) parseRequest(path string, values url.Values) (*genericQueryRequest, error) { var ( cacheKeyPrefix string labelName string @@ -54,28 +51,28 @@ func (c *labelsQueryCache) parseRequest(req *http.Request) (*genericQueryRequest // Detect the request type switch { - case strings.HasSuffix(req.URL.Path, labelNamesPathSuffix): + case strings.HasSuffix(path, labelNamesPathSuffix): cacheKeyPrefix = labelNamesQueryCachePrefix - case labelValuesPathSuffix.MatchString(req.URL.Path): + case labelValuesPathSuffix.MatchString(path): cacheKeyPrefix = labelValuesQueryCachePrefix - labelName = 
labelValuesPathSuffix.FindStringSubmatch(req.URL.Path)[1] + labelName = labelValuesPathSuffix.FindStringSubmatch(path)[1] default: return nil, errors.New("unknown labels API endpoint") } // Both the label names and label values API endpoints support the same exact parameters (with the same defaults), // so in this function there's no distinction between the two. - startTime, err := parseRequestTimeParam(req, "start", util.PrometheusMinTime.UnixMilli()) + startTime, err := parseRequestTimeParam(values, "start", util.PrometheusMinTime.UnixMilli()) if err != nil { return nil, err } - endTime, err := parseRequestTimeParam(req, "end", util.PrometheusMaxTime.UnixMilli()) + endTime, err := parseRequestTimeParam(values, "end", util.PrometheusMaxTime.UnixMilli()) if err != nil { return nil, err } - matcherSets, err := parseRequestMatchersParam(req, "match[]") + matcherSets, err := parseRequestMatchersParam(values, "match[]") if err != nil { return nil, err } @@ -124,8 +121,12 @@ func generateLabelsQueryRequestCacheKey(startTime, endTime int64, labelName stri return b.String() } -func parseRequestTimeParam(req *http.Request, paramName string, defaultValue int64) (int64, error) { - value := req.FormValue(paramName) +func parseRequestTimeParam(values url.Values, paramName string, defaultValue int64) (int64, error) { + var value string + if len(values[paramName]) > 0 { + value = values[paramName][0] + } + if value == "" { return defaultValue, nil } @@ -138,10 +139,10 @@ func parseRequestTimeParam(req *http.Request, paramName string, defaultValue int return parsed, nil } -func parseRequestMatchersParam(req *http.Request, paramName string) ([][]*labels.Matcher, error) { - matcherSets := make([][]*labels.Matcher, 0, len(req.Form[paramName])) +func parseRequestMatchersParam(values url.Values, paramName string) ([][]*labels.Matcher, error) { + matcherSets := make([][]*labels.Matcher, 0, len(values[paramName])) - for _, value := range req.Form[paramName] { + for _, value := range 
values[paramName] { matchers, err := parser.ParseMetricSelector(value) if err != nil { diff --git a/pkg/frontend/querymiddleware/labels_query_cache_test.go b/pkg/frontend/querymiddleware/labels_query_cache_test.go index 7dc129ebba9..6ec646584fc 100644 --- a/pkg/frontend/querymiddleware/labels_query_cache_test.go +++ b/pkg/frontend/querymiddleware/labels_query_cache_test.go @@ -20,12 +20,14 @@ import ( func TestLabelsQueryCache_RoundTrip(t *testing.T) { testGenericQueryCacheRoundTrip(t, newLabelsQueryCacheRoundTripper, "label_names_and_values", map[string]testGenericQueryCacheRequestType{ "label names request": { - url: mustParseURL(t, `/prometheus/api/v1/labels?start=2023-07-05T01:00:00Z&end=2023-07-05T08:00:00Z&match[]={job="test_1"}&match[]={job!="test_2"}`), + reqPath: "/prometheus/api/v1/labels", + reqData: url.Values{"start": []string{"2023-07-05T01:00:00Z"}, "end": []string{"2023-07-05T08:00:00Z"}, "match[]": []string{`{job="test_1"}`, `{job!="test_2"}`}}, cacheKey: "user-1:1688515200000\x001688544000000\x00{job!=\"test_2\"},{job=\"test_1\"}", hashedCacheKey: labelNamesQueryCachePrefix + cacheHashKey("user-1:1688515200000\x001688544000000\x00{job!=\"test_2\"},{job=\"test_1\"}"), }, "label values request": { - url: mustParseURL(t, `/prometheus/api/v1/label/test/values?start=2023-07-05T01:00:00Z&end=2023-07-05T08:00:00Z&match[]={job="test_1"}&match[]={job!="test_2"}`), + reqPath: "/prometheus/api/v1/label/test/values", + reqData: url.Values{"start": []string{"2023-07-05T01:00:00Z"}, "end": []string{"2023-07-05T08:00:00Z"}, "match[]": []string{`{job="test_1"}`, `{job!="test_2"}`}}, cacheKey: "user-1:1688515200000\x001688544000000\x00test\x00{job!=\"test_2\"},{job=\"test_1\"}", hashedCacheKey: labelValuesQueryCachePrefix + cacheHashKey("user-1:1688515200000\x001688544000000\x00test\x00{job!=\"test_2\"},{job=\"test_1\"}"), }, @@ -142,11 +144,8 @@ func TestLabelsQueryCache_parseRequest(t *testing.T) { t.Run(testName, func(t *testing.T) { for requestTypeName, 
requestTypeData := range requestTypes { t.Run(requestTypeName, func(t *testing.T) { - req, err := http.NewRequest("GET", "http://localhost"+requestTypeData.requestPath+"?"+testData.params.Encode(), nil) - require.NoError(t, err) - c := &labelsQueryCache{} - actual, err := c.parseRequest(req) + actual, err := c.parseRequest(requestTypeData.requestPath, testData.params) require.NoError(t, err) assert.Equal(t, requestTypeData.expectedCacheKeyPrefix, actual.cacheKeyPrefix) @@ -325,7 +324,7 @@ func TestParseRequestMatchersParam(t *testing.T) { require.NoError(t, err) require.NoError(t, req.ParseForm()) - actual, err := parseRequestMatchersParam(req, paramName) + actual, err := parseRequestMatchersParam(req.Form, paramName) require.NoError(t, err) assert.Equal(t, testData.expected, actual) @@ -337,7 +336,7 @@ func TestParseRequestMatchersParam(t *testing.T) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") require.NoError(t, req.ParseForm()) - actual, err := parseRequestMatchersParam(req, "match[]") + actual, err := parseRequestMatchersParam(req.Form, "match[]") require.NoError(t, err) assert.Equal(t, testData.expected, actual) diff --git a/pkg/frontend/transport/handler.go b/pkg/frontend/transport/handler.go index e78fc198c40..cfd95459685 100644 --- a/pkg/frontend/transport/handler.go +++ b/pkg/frontend/transport/handler.go @@ -6,7 +6,6 @@ package transport import ( - "bytes" "context" "flag" "fmt" @@ -176,29 +175,15 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { r = r.WithContext(ctx) } + // Ensure to close the request body reader. defer func() { _ = r.Body.Close() }() - // Store the body contents, so we can read it multiple times. 
- bodyBytes, err := io.ReadAll(http.MaxBytesReader(w, r.Body, f.cfg.MaxBodySize)) + params, err := util.ParseRequestFormWithoutConsumingBody(r) if err != nil { - writeError(w, err) - return - } - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - - // Parse the form, as it's needed to build the activity for the activity-tracker. - if err := r.ParseForm(); err != nil { writeError(w, apierror.New(apierror.TypeBadData, err.Error())) return } - // Store a copy of the params and restore the request state. - // Restore the body, so it can be read again if it's used to forward the request through a roundtripper. - // Restore the Form and PostForm, to avoid subtle bugs in middlewares, as they were set by ParseForm. - params := copyValues(r.Form) - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - r.Form, r.PostForm = nil, nil - activityIndex := f.at.Insert(func() string { return httpRequestActivity(r, params) }) defer f.at.Delete(activityIndex) diff --git a/pkg/util/http.go b/pkg/util/http.go index 32a3430e85f..0797fea08b0 100644 --- a/pkg/util/http.go +++ b/pkg/util/http.go @@ -14,6 +14,7 @@ import ( "html/template" "io" "net/http" + "net/url" "strings" "github.com/go-kit/log" @@ -296,3 +297,48 @@ func SerializeProtoResponse(w http.ResponseWriter, resp proto.Message, compressi } return nil } + +// ParseRequestFormWithoutConsumingBody parses and returns the request parameters (query string and/or request body) +// from the input http.Request. If the request has a Body, the request's Body is replaced so that it can be consumed again. +func ParseRequestFormWithoutConsumingBody(r *http.Request) (url.Values, error) { + if r.Body == nil { + if err := r.ParseForm(); err != nil { + return nil, err + } + + return r.Form, nil + } + + // Close the original body reader. It's going to be replaced later in this function. + origBody := r.Body + defer func() { _ = origBody.Close() }() + + // Store the body contents, so we can read it multiple times. 
+ bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Parse the request data. + if err := r.ParseForm(); err != nil { + return nil, err + } + + // Store a copy of the params and restore the request state. + // Restore the body, so it can be read again if it's used to forward the request through a roundtripper. + // Restore the Form and PostForm, to avoid subtle bugs in middlewares, as they were set by ParseForm. + params := copyValues(r.Form) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + r.Form, r.PostForm = nil, nil + + return params, nil +} + +func copyValues(src url.Values) url.Values { + dst := make(url.Values, len(src)) + for k, vs := range src { + dst[k] = append([]string(nil), vs...) + } + return dst +} diff --git a/pkg/util/http_test.go b/pkg/util/http_test.go index 3ff9548459b..79a2aba3841 100644 --- a/pkg/util/http_test.go +++ b/pkg/util/http_test.go @@ -13,7 +13,9 @@ import ( "math/rand" "net/http" "net/http/httptest" + "net/url" "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -232,3 +234,61 @@ func TestNewMsgSizeTooLargeErr(t *testing.T) { assert.Equal(t, msg, err.Error()) } + +func TestParseRequestFormWithoutConsumingBody(t *testing.T) { + expected := url.Values{ + "first": []string{"a", "b"}, + "second": []string{"c"}, + } + + t.Run("GET request", func(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost/?"+expected.Encode(), nil) + require.NoError(t, err) + + actual, err := util.ParseRequestFormWithoutConsumingBody(req) + require.NoError(t, err) + assert.Equal(t, expected, actual) + + // Parsing the request again should get the expected values. 
+ require.NoError(t, req.ParseForm()) + assert.Equal(t, expected, req.Form) + }) + + t.Run("POST request", func(t *testing.T) { + origBody := newReadCloserObserver(io.NopCloser(strings.NewReader(expected.Encode()))) + req, err := http.NewRequest("POST", "http://localhost/", origBody) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + actual, err := util.ParseRequestFormWithoutConsumingBody(req) + require.NoError(t, err) + assert.Equal(t, expected, actual) + + // Since the original body has been consumed and discarded, it should have called Close() too. + assert.True(t, origBody.closeCalled) + + // The request should have been reset to a non-parsed state. + assert.Nil(t, req.Form) + assert.Nil(t, req.PostForm) + + // Parsing the request again should get the expected values. + require.NoError(t, req.ParseForm()) + assert.Equal(t, expected, req.Form) + }) +} + +type readCloserObserver struct { + io.ReadCloser + closeCalled bool +} + +func newReadCloserObserver(wrapped io.ReadCloser) *readCloserObserver { + return &readCloserObserver{ + ReadCloser: wrapped, + } +} + +func (o *readCloserObserver) Close() error { + o.closeCalled = true + return o.ReadCloser.Close() +} From e0abaa02393a96c05f53700cc0029c564d7df749 Mon Sep 17 00:00:00 2001 From: Marco Pracucci Date: Mon, 17 Jul 2023 15:36:31 +0200 Subject: [PATCH 2/5] Remove unused function Signed-off-by: Marco Pracucci --- pkg/frontend/transport/handler.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pkg/frontend/transport/handler.go b/pkg/frontend/transport/handler.go index cfd95459685..e9b9ad5deaf 100644 --- a/pkg/frontend/transport/handler.go +++ b/pkg/frontend/transport/handler.go @@ -365,11 +365,3 @@ func httpRequestActivity(request *http.Request, requestParams url.Values) string // This doesn't have to be pretty, just useful for debugging, so prioritize efficiency. 
return strings.Join([]string{tenantID, request.Method, request.URL.Path, params}, " ") } - -func copyValues(src url.Values) url.Values { - dst := make(url.Values, len(src)) - for k, vs := range src { - dst[k] = append([]string(nil), vs...) - } - return dst -} From a5438246c266badb165c574df932488cd70d57a0 Mon Sep 17 00:00:00 2001 From: Marco Pracucci Date: Mon, 17 Jul 2023 15:46:47 +0200 Subject: [PATCH 3/5] Fix Signed-off-by: Marco Pracucci --- pkg/frontend/transport/handler.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/frontend/transport/handler.go b/pkg/frontend/transport/handler.go index e9b9ad5deaf..091d85a916c 100644 --- a/pkg/frontend/transport/handler.go +++ b/pkg/frontend/transport/handler.go @@ -178,6 +178,9 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Ensure to close the request body reader. defer func() { _ = r.Body.Close() }() + // Limit the read body size. + r.Body = http.MaxBytesReader(w, r.Body, f.cfg.MaxBodySize) + params, err := util.ParseRequestFormWithoutConsumingBody(r) if err != nil { writeError(w, apierror.New(apierror.TypeBadData, err.Error())) From 2bae1d325ca4231116176e76b7a41ca0c2afe672 Mon Sep 17 00:00:00 2001 From: Marco Pracucci Date: Mon, 17 Jul 2023 17:14:39 +0200 Subject: [PATCH 4/5] Uncommented test case Signed-off-by: Marco Pracucci --- pkg/frontend/querymiddleware/generic_query_cache_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/frontend/querymiddleware/generic_query_cache_test.go b/pkg/frontend/querymiddleware/generic_query_cache_test.go index 7a936d307e3..0cd5a348693 100644 --- a/pkg/frontend/querymiddleware/generic_query_cache_test.go +++ b/pkg/frontend/querymiddleware/generic_query_cache_test.go @@ -165,7 +165,7 @@ func testGenericQueryCacheRoundTrip(t *testing.T, newRoundTripper newGenericQuer for testName, testData := range tests { t.Run(testName, func(t *testing.T) { for reqName, reqData := range requestTypes { - for _, reqMethod := range []string{ 
/*http.MethodGet ,*/ http.MethodPost} { + for _, reqMethod := range []string{http.MethodGet, http.MethodPost} { t.Run(fmt.Sprintf("%s (%s)", reqName, reqMethod), func(t *testing.T) { // Mock the limits. limits := multiTenantMockLimits{ From 0dd0fa95fd2e185f2b4a070ebd9ceea77786ac33 Mon Sep 17 00:00:00 2001 From: Marco Pracucci Date: Mon, 17 Jul 2023 17:18:57 +0200 Subject: [PATCH 5/5] Improved integration test Signed-off-by: Marco Pracucci --- integration/querier_label_name_values_test.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/integration/querier_label_name_values_test.go b/integration/querier_label_name_values_test.go index a3096de39f6..1a4fb64cdf1 100644 --- a/integration/querier_label_name_values_test.go +++ b/integration/querier_label_name_values_test.go @@ -13,6 +13,7 @@ import ( "github.com/grafana/dskit/test" "github.com/grafana/e2e" + e2ecache "github.com/grafana/e2e/cache" e2edb "github.com/grafana/e2e/db" "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/prompb" @@ -98,16 +99,26 @@ func TestQuerierLabelNamesAndValues(t *testing.T) { require.NoError(t, err) defer s.Close() + // Start dependencies. + memcached := e2ecache.NewMemcached() + consul := e2edb.NewConsul() + require.NoError(t, s.StartAndWaitReady(consul, memcached)) + // Set configuration. flags := mergeFlags(BlocksStorageFlags(), BlocksStorageS3Flags(), map[string]string{ "-querier.cardinality-analysis-enabled": "true", "-ingester.ring.replication-factor": "3", + + // Enable the cardinality results cache with a very short TTL just to exercise its code. + "-query-frontend.cache-results": "true", + "-query-frontend.results-cache.backend": "memcached", + "-query-frontend.results-cache.memcached.addresses": "dns+" + memcached.NetworkEndpoint(e2ecache.MemcachedPort), + "-query-frontend.results-cache-ttl-for-cardinality-query": "1ms", }) - // Start dependencies. - consul := e2edb.NewConsul() + // Start minio. 
minio := e2edb.NewMinio(9000, flags["-blocks-storage.s3.bucket-name"]) - require.NoError(t, s.StartAndWaitReady(consul, minio)) + require.NoError(t, s.StartAndWaitReady(minio)) // Start the query-frontend. queryFrontend := e2emimir.NewQueryFrontend("query-frontend", flags)