Skip to content

Commit

Permalink
Add TestMiddlewaresConsistency (#8373)
Browse files Browse the repository at this point in the history
* Add TestMiddlewaresConsistency

Signed-off-by: Marco Pracucci <marco@pracucci.com>

* Remove useless assignment

Signed-off-by: Marco Pracucci <marco@pracucci.com>

* Add PR number to CHANGELOG

Signed-off-by: Marco Pracucci <marco@pracucci.com>

---------

Signed-off-by: Marco Pracucci <marco@pracucci.com>
  • Loading branch information
pracucci authored Jun 14, 2024
1 parent dee6b35 commit 023e3a6
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 82 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
* [ENHANCEMENT] Distributor: add `insight=true` to remote-write and OTLP write handlers when the HTTP response status code is 4xx. #8294
* [ENHANCEMENT] Ingester: reduce locked time while matching postings for a label, improving the write latency and compaction speed. #8327
* [ENHANCEMENT] Ingester: reduce the amount of locks taken during the Head compaction's garbage-collection process, improving the write latency and compaction speed. #8327
* [ENHANCEMENT] Query-frontend: log the start, end time and matchers for remote read requests to the query stats logs. #8326 #8370
* [ENHANCEMENT] Query-frontend: log the start, end time and matchers for remote read requests to the query stats logs. #8326 #8370 #8373
* [BUGFIX] Distributor: prometheus retry on 5xx and 429 errors, while otlp collector only retry on 429, 502, 503 and 504, mapping other 5xx errors to the retryable ones in otlp endpoint. #8324 #8339
* [BUGFIX] Distributor: make OTLP endpoint return marshalled proto bytes as response body for 4xx/5xx errors. #8227
* [BUGFIX] Rules: improve error handling when querier is local to the ruler. #7567
Expand Down
38 changes: 25 additions & 13 deletions pkg/frontend/querymiddleware/instrumentation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
)

// instrumentMiddleware is the handler returned by newInstrumentMiddleware. Its Do method
// runs the next handler inside instrument.CollectedRequest, so that the execution is
// reported to durationCol under the configured name.
type instrumentMiddleware struct {
	// next is the downstream handler invoked by Do.
	next MetricsQueryHandler
	// name is the identifier passed to instrument.CollectedRequest for each request.
	name string
	// durationCol is the collector passed to instrument.CollectedRequest.
	durationCol instrument.Collector
}

// newInstrumentMiddleware can be inserted into the middleware chain to expose timing information.
func newInstrumentMiddleware(name string, metrics *instrumentMiddlewareMetrics) MetricsQueryMiddleware {
var durationCol instrument.Collector
Expand All @@ -27,21 +33,27 @@ func newInstrumentMiddleware(name string, metrics *instrumentMiddlewareMetrics)
}

return MetricsQueryMiddlewareFunc(func(next MetricsQueryHandler) MetricsQueryHandler {
return HandlerFunc(func(ctx context.Context, req MetricsQueryRequest) (Response, error) {
var resp Response
err := instrument.CollectedRequest(ctx, name, durationCol, instrument.ErrorCode, func(ctx context.Context) error {
sp := opentracing.SpanFromContext(ctx)
if sp != nil {
req.AddSpanTags(sp)
}
return &instrumentMiddleware{
next: next,
name: name,
durationCol: durationCol,
}
})
}

func (h *instrumentMiddleware) Do(ctx context.Context, req MetricsQueryRequest) (Response, error) {
var resp Response
err := instrument.CollectedRequest(ctx, h.name, h.durationCol, instrument.ErrorCode, func(ctx context.Context) error {
sp := opentracing.SpanFromContext(ctx)
if sp != nil {
req.AddSpanTags(sp)
}

var err error
resp, err = next.Do(ctx, req)
return err
})
return resp, err
})
var err error
resp, err = h.next.Do(ctx, req)
return err
})
return resp, err
}

// instrumentMiddlewareMetrics holds the metrics tracked by newInstrumentMiddleware.
Expand Down
150 changes: 82 additions & 68 deletions pkg/frontend/querymiddleware/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,24 +217,6 @@ func newQueryTripperware(
// Experimental functions can only be enabled globally, and not on a per-engine basis.
parser.EnableExperimentalFunctions = engineExperimentalFunctionsEnabled

// Metric used to keep track of each middleware execution duration.
metrics := newInstrumentMiddlewareMetrics(registerer)
queryBlockerMiddleware := newQueryBlockerMiddleware(limits, log, registerer)
queryStatsMiddleware := newQueryStatsMiddleware(registerer, engine)

remoteReadMiddleware := []MetricsQueryMiddleware{
// Empty for now.
}

queryRangeMiddleware := []MetricsQueryMiddleware{
// Track query range statistics. Added first before any subsequent middleware modifies the request.
queryStatsMiddleware,
newLimitsMiddleware(limits, log),
queryBlockerMiddleware,
newInstrumentMiddleware("step_align", metrics),
newStepAlignMiddleware(limits, log, registerer),
}

var c cache.Cache
if cfg.CacheResults || cfg.cardinalityBasedShardingEnabled() {
var err error
Expand All @@ -251,6 +233,83 @@ func newQueryTripperware(
cacheKeyGenerator = NewDefaultCacheKeyGenerator(codec, cfg.SplitQueriesByInterval)
}

queryRangeMiddleware, queryInstantMiddleware, remoteReadMiddleware := newQueryMiddlewares(cfg, log, limits, codec, c, cacheKeyGenerator, cacheExtractor, engine, registerer)

return func(next http.RoundTripper) http.RoundTripper {
queryrange := newLimitedParallelismRoundTripper(next, codec, limits, queryRangeMiddleware...)
instant := newLimitedParallelismRoundTripper(next, codec, limits, queryInstantMiddleware...)
remoteRead := newRemoteReadRoundTripper(next, remoteReadMiddleware...)

// Wrap next for cardinality, labels queries and all other queries.
// That attempts to parse "start" and "end" from the HTTP request and set them in the request's QueryDetails.
// range and instant queries have more accurate logic for query details.
next = newQueryDetailsStartEndRoundTripper(next)
cardinality := next
activeSeries := next
activeNativeHistogramMetrics := next
labels := next

// Inject the cardinality and labels query cache roundtripper only if the query results cache is enabled.
if cfg.CacheResults {
cardinality = newCardinalityQueryCacheRoundTripper(c, cacheKeyGenerator, limits, cardinality, log, registerer)
labels = newLabelsQueryCacheRoundTripper(c, cacheKeyGenerator, limits, labels, log, registerer)
}

if cfg.ShardActiveSeriesQueries {
activeSeries = newShardActiveSeriesMiddleware(activeSeries, cfg.UseActiveSeriesDecoder, limits, log)
activeNativeHistogramMetrics = newShardActiveNativeHistogramMetricsMiddleware(activeNativeHistogramMetrics, limits, log)
}

return RoundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case IsRangeQuery(r.URL.Path):
return queryrange.RoundTrip(r)
case IsInstantQuery(r.URL.Path):
return instant.RoundTrip(r)
case IsCardinalityQuery(r.URL.Path):
return cardinality.RoundTrip(r)
case IsActiveSeriesQuery(r.URL.Path):
return activeSeries.RoundTrip(r)
case IsActiveNativeHistogramMetricsQuery(r.URL.Path):
return activeNativeHistogramMetrics.RoundTrip(r)
case IsLabelsQuery(r.URL.Path):
return labels.RoundTrip(r)
case IsRemoteReadQuery(r.URL.Path):
return remoteRead.RoundTrip(r)
default:
return next.RoundTrip(r)
}
})
}, nil
}

// newQueryMiddlewares creates and returns the middlewares that should be injected for each type of request
// handled by the query-frontend.
func newQueryMiddlewares(
cfg Config,
log log.Logger,
limits Limits,
codec Codec,
cacheClient cache.Cache,
cacheKeyGenerator CacheKeyGenerator,
cacheExtractor Extractor,
engine *promql.Engine,
registerer prometheus.Registerer,
) (queryRangeMiddleware, queryInstantMiddleware, remoteReadMiddleware []MetricsQueryMiddleware) {
// Metric used to keep track of each middleware execution duration.
metrics := newInstrumentMiddlewareMetrics(registerer)
queryBlockerMiddleware := newQueryBlockerMiddleware(limits, log, registerer)
queryStatsMiddleware := newQueryStatsMiddleware(registerer, engine)

queryRangeMiddleware = append(queryRangeMiddleware,
// Track query range statistics. Added first before any subsequent middleware modifies the request.
queryStatsMiddleware,
newLimitsMiddleware(limits, log),
queryBlockerMiddleware,
newInstrumentMiddleware("step_align", metrics),
newStepAlignMiddleware(limits, log, registerer),
)

// Inject the middleware to split requests by interval + results cache (if at least one of the two is enabled).
if cfg.SplitQueriesByInterval > 0 || cfg.CacheResults {
shouldCache := func(r MetricsQueryRequest) bool {
Expand All @@ -263,7 +322,7 @@ func newQueryTripperware(
cfg.SplitQueriesByInterval,
limits,
codec,
c,
cacheClient,
cacheKeyGenerator,
cacheExtractor,
shouldCache,
Expand All @@ -272,20 +331,20 @@ func newQueryTripperware(
))
}

queryInstantMiddleware := []MetricsQueryMiddleware{
queryInstantMiddleware = append(queryInstantMiddleware,
// Track query range statistics. Added first before any subsequent middleware modifies the request.
queryStatsMiddleware,
newLimitsMiddleware(limits, log),
newSplitInstantQueryByIntervalMiddleware(limits, log, engine, registerer),
queryBlockerMiddleware,
}
)

if cfg.ShardedQueries {
// Inject the cardinality estimation middleware after time-based splitting and
// before query-sharding so that it can operate on the partial queries that are
// considered for sharding.
if cfg.cardinalityBasedShardingEnabled() {
cardinalityEstimationMiddleware := newCardinalityEstimationMiddleware(c, log, registerer)
cardinalityEstimationMiddleware := newCardinalityEstimationMiddleware(cacheClient, log, registerer)
queryRangeMiddleware = append(
queryRangeMiddleware,
newInstrumentMiddleware("cardinality_estimation", metrics),
Expand Down Expand Up @@ -324,52 +383,7 @@ func newQueryTripperware(
queryInstantMiddleware = append(queryInstantMiddleware, newInstrumentMiddleware("retry", metrics), newRetryMiddleware(log, cfg.MaxRetries, retryMiddlewareMetrics))
}

return func(next http.RoundTripper) http.RoundTripper {
queryrange := newLimitedParallelismRoundTripper(next, codec, limits, queryRangeMiddleware...)
instant := newLimitedParallelismRoundTripper(next, codec, limits, queryInstantMiddleware...)
remoteRead := newRemoteReadRoundTripper(next, remoteReadMiddleware...)

// Wrap next for cardinality, labels queries and all other queries.
// That attempts to parse "start" and "end" from the HTTP request and set them in the request's QueryDetails.
// range and instant queries have more accurate logic for query details.
next = newQueryDetailsStartEndRoundTripper(next)
cardinality := next
activeSeries := next
activeNativeHistogramMetrics := next
labels := next

// Inject the cardinality and labels query cache roundtripper only if the query results cache is enabled.
if cfg.CacheResults {
cardinality = newCardinalityQueryCacheRoundTripper(c, cacheKeyGenerator, limits, cardinality, log, registerer)
labels = newLabelsQueryCacheRoundTripper(c, cacheKeyGenerator, limits, labels, log, registerer)
}

if cfg.ShardActiveSeriesQueries {
activeSeries = newShardActiveSeriesMiddleware(activeSeries, cfg.UseActiveSeriesDecoder, limits, log)
activeNativeHistogramMetrics = newShardActiveNativeHistogramMetricsMiddleware(activeNativeHistogramMetrics, limits, log)
}

return RoundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case IsRangeQuery(r.URL.Path):
return queryrange.RoundTrip(r)
case IsInstantQuery(r.URL.Path):
return instant.RoundTrip(r)
case IsCardinalityQuery(r.URL.Path):
return cardinality.RoundTrip(r)
case IsActiveSeriesQuery(r.URL.Path):
return activeSeries.RoundTrip(r)
case IsActiveNativeHistogramMetricsQuery(r.URL.Path):
return activeNativeHistogramMetrics.RoundTrip(r)
case IsLabelsQuery(r.URL.Path):
return labels.RoundTrip(r)
case IsRemoteReadQuery(r.URL.Path):
return remoteRead.RoundTrip(r)
default:
return next.RoundTrip(r)
}
})
}, nil
return
}

// newQueryDetailsStartEndRoundTripper parses "start" and "end" parameters from the query and sets same fields in the QueryDetails in the context.
Expand Down
97 changes: 97 additions & 0 deletions pkg/frontend/querymiddleware/roundtrip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"slices"
"strconv"
"strings"
"testing"
"time"

"github.com/go-kit/log"
"github.com/grafana/dskit/flagext"
"github.com/grafana/dskit/middleware"
"github.com/grafana/dskit/user"
"github.com/prometheus/client_golang/api"
Expand Down Expand Up @@ -436,6 +439,100 @@ func TestTripperware_Metrics(t *testing.T) {
}
}

// TestMiddlewaresConsistency ensures that we don't forget to add a middleware to a given type of request
// (e.g. range query, remote read, ...) when a new middleware is added. By default, it expects every
// middleware to be added to each type of request, and it allows exceptions to be declared when we
// intentionally don't want a given middleware to be used for a specific type of request.
func TestMiddlewaresConsistency(t *testing.T) {
	cfg := Config{}
	flagext.DefaultValues(&cfg)
	cfg.CacheResults = true
	cfg.ShardedQueries = true

	// Ensure all features are enabled, so that we assert on all middlewares.
	require.NotZero(t, cfg.CacheResults)
	require.NotZero(t, cfg.ShardedQueries)
	require.NotZero(t, cfg.SplitQueriesByInterval)
	require.NotZero(t, cfg.MaxRetries)

	queryRangeMiddlewares, queryInstantMiddlewares, remoteReadMiddlewares := newQueryMiddlewares(
		cfg,
		log.NewNopLogger(),
		mockLimits{
			alignQueriesWithStep: true,
		},
		newTestPrometheusCodec(),
		nil,
		nil,
		nil,
		promql.NewEngine(promql.EngineOpts{}),
		nil,
	)

	// For each type of request: the middlewares actually configured, and the names of the
	// middlewares that are intentionally NOT expected for that type of request.
	middlewaresByRequestType := map[string]struct {
		instances  []MetricsQueryMiddleware
		exceptions []string
	}{
		"instant query": {
			instances:  queryInstantMiddlewares,
			exceptions: []string{"splitAndCacheMiddleware", "stepAlignMiddleware"},
		},
		"range query": {
			instances:  queryRangeMiddlewares,
			exceptions: []string{"splitInstantQueryByIntervalMiddleware"},
		},
		"remote read": {
			instances:  remoteReadMiddlewares,
			exceptions: []string{"instrumentMiddleware", "limitsMiddleware", "queryBlockerMiddleware", "querySharding", "queryStatsMiddleware", "retry", "splitAndCacheMiddleware", "splitInstantQueryByIntervalMiddleware", "stepAlignMiddleware"},
		},
	}

	// Utility to get the name of the (possibly pointed-to) type of a value.
	// The local is named typ to avoid shadowing the *testing.T parameter.
	getName := func(v any) string {
		typ := reflect.TypeOf(v)
		if typ.Kind() == reflect.Ptr {
			typ = typ.Elem()
		}
		return typ.Name()
	}

	// Utility to get the sorted, unique names of middlewares. A middleware is identified
	// by the type name of the handler it returns when wrapping a mock handler.
	getMiddlewareNames := func(middlewares []MetricsQueryMiddleware) []string {
		var names []string
		for _, middleware := range middlewares {
			handler := middleware.Wrap(&mockHandler{})
			names = append(names, getName(handler))
		}

		// Unique names.
		slices.Sort(names)
		return slices.Compact(names)
	}

	// Compute the names once per request type, and collect the (unique) names of all middlewares.
	namesByRequestType := make(map[string][]string, len(middlewaresByRequestType))
	var allNames []string
	for requestType, middlewares := range middlewaresByRequestType {
		names := getMiddlewareNames(middlewares.instances)
		namesByRequestType[requestType] = names
		allNames = append(allNames, names...)
	}
	slices.Sort(allNames)
	allNames = slices.Compact(allNames)

	// Ensure that each request type uses all middlewares, except its declared exceptions.
	for requestType, middlewares := range middlewaresByRequestType {
		t.Run(requestType, func(t *testing.T) {
			actualNames := namesByRequestType[requestType]
			expectedNames := slices.DeleteFunc(slices.Clone(allNames), func(name string) bool {
				return slices.Contains(middlewares.exceptions, name)
			})

			assert.ElementsMatch(t, expectedNames, actualNames)
		})
	}
}

func TestConfig_Validate(t *testing.T) {
tests := map[string]struct {
config Config
Expand Down

0 comments on commit 023e3a6

Please sign in to comment.