diff --git a/pkg/distributor/query.go b/pkg/distributor/query.go index 75a86c5d7d1..64c28ce2c08 100644 --- a/pkg/distributor/query.go +++ b/pkg/distributor/query.go @@ -334,10 +334,7 @@ func (d *Distributor) queryIngesterStream(ctx context.Context, replicationSet ri } quorumConfig := d.queryQuorumConfig(ctx, replicationSet) - quorumConfig.IsTerminalError = func(err error) bool { - _, isLimitError := err.(validation.LimitError) - return isLimitError - } + quorumConfig.IsTerminalError = validation.IsLimitError results, err := ring.DoUntilQuorumWithoutSuccessfulContextCancellation(ctx, replicationSet, quorumConfig, queryIngester, cleanup) if err != nil { diff --git a/pkg/ingester/client/streaming.go b/pkg/ingester/client/streaming.go index 6a6faa0360c..2b5a69cd376 100644 --- a/pkg/ingester/client/streaming.go +++ b/pkg/ingester/client/streaming.go @@ -233,7 +233,7 @@ func (s *SeriesChunksStreamReader) readNextBatch(seriesIndex uint64) error { select { case err, haveError := <-s.errorChan: if haveError { - if _, ok := err.(validation.LimitError); ok { + if validation.IsLimitError(err) { return err } return fmt.Errorf("attempted to read series at index %v from ingester chunks stream, but the stream has failed: %w", seriesIndex, err) diff --git a/pkg/querier/block_streaming.go b/pkg/querier/block_streaming.go index 1206c73b5c5..6a02880cf1b 100644 --- a/pkg/querier/block_streaming.go +++ b/pkg/querier/block_streaming.go @@ -259,10 +259,10 @@ func (s *storeGatewayStreamReader) readStream(log *spanlogger.SpanLogger) error } totalChunks += numChunks if err := s.queryLimiter.AddChunks(numChunks); err != nil { - return validation.LimitError(err.Error()) + return err } if err := s.queryLimiter.AddChunkBytes(chunkBytes); err != nil { - return validation.LimitError(err.Error()) + return err } s.stats.AddFetchedChunks(uint64(numChunks)) @@ -365,7 +365,7 @@ func (s *storeGatewayStreamReader) readNextBatch(seriesIndex uint64) error { select { case err, haveError := <-s.errorChan: if haveError { - if _, ok := err.(validation.LimitError); ok { + if validation.IsLimitError(err) { return err } return errors.Wrapf(err, "attempted to read series at index %v from store-gateway chunks stream, but the stream has failed", seriesIndex) diff --git a/pkg/querier/blocks_store_queryable.go b/pkg/querier/blocks_store_queryable.go index a0571946b24..80c353ca0b5 100644 --- a/pkg/querier/blocks_store_queryable.go +++ b/pkg/querier/blocks_store_queryable.go @@ -51,7 +51,6 @@ import ( "github.com/grafana/mimir/pkg/util/limiter" util_log "github.com/grafana/mimir/pkg/util/log" "github.com/grafana/mimir/pkg/util/spanlogger" - "github.com/grafana/mimir/pkg/util/validation" ) const ( @@ -831,9 +830,8 @@ func (q *blocksStoreQuerier) fetchSeriesFromStores(ctx context.Context, sp *stor if ss := resp.GetStreamingSeries(); ss != nil { for _, s := range ss.Series { // Add series fingerprint to query limiter; will return error if we are over the limit - limitErr := queryLimiter.AddSeries(s.Labels) - if limitErr != nil { - return validation.LimitError(limitErr.Error()) + if limitErr := queryLimiter.AddSeries(s.Labels); limitErr != nil { + return limitErr } } myStreamingSeries = append(myStreamingSeries, ss.Series...) diff --git a/pkg/querier/error_translate_queryable_test.go b/pkg/querier/error_translate_queryable_test.go index 28e148190f8..18b3b1d12b1 100644 --- a/pkg/querier/error_translate_queryable_test.go +++ b/pkg/querier/error_translate_queryable_test.go @@ -44,7 +44,7 @@ func TestApiStatusCodes(t *testing.T) { }, { - err: validation.LimitError("limit exceeded"), + err: validation.NewLimitError("limit exceeded"), expectedString: "limit exceeded", expectedCode: 422, }, diff --git a/pkg/querier/errors.go b/pkg/querier/errors.go index d733f2befc5..180a497dbfb 100644 --- a/pkg/querier/errors.go +++ b/pkg/querier/errors.go @@ -18,7 +18,7 @@ var ( ) func NewMaxQueryLengthError(actualQueryLen, maxQueryLength time.Duration) validation.LimitError { - return validation.LimitError(globalerror.MaxQueryLength.MessageWithPerTenantLimitConfig( + return validation.NewLimitError(globalerror.MaxQueryLength.MessageWithPerTenantLimitConfig( fmt.Sprintf("the query time range exceeds the limit (query length: %s, limit: %s)", actualQueryLen, maxQueryLength), validation.MaxPartialQueryLengthFlag)) } diff --git a/pkg/storegateway/bucket_test.go b/pkg/storegateway/bucket_test.go index 92d86f252b5..44ffb00d647 100644 --- a/pkg/storegateway/bucket_test.go +++ b/pkg/storegateway/bucket_test.go @@ -270,7 +270,7 @@ func TestBlockLabelNames(t *testing.T) { slices.Sort(jNotFooLabelNames) sl := NewLimiter(math.MaxUint64, promauto.With(nil).NewCounter(prometheus.CounterOpts{Name: "test"}), func(limit uint64) validation.LimitError { - return validation.LimitError(fmt.Sprintf("exceeded unlimited limit of %v", limit)) + return validation.NewLimitError(fmt.Sprintf("exceeded unlimited limit of %v", limit)) }) newTestBucketBlock := prepareTestBlock(test.NewTB(t), appendTestSeries(series)) diff --git a/pkg/storegateway/limiter_test.go b/pkg/storegateway/limiter_test.go index f995db64f70..115c67f6065 100644 --- a/pkg/storegateway/limiter_test.go +++ b/pkg/storegateway/limiter_test.go @@ -22,7 +22,7 @@ import ( func TestLimiter(t *testing.T) { c := promauto.With(nil).NewCounter(prometheus.CounterOpts{}) l := NewLimiter(10, c, func(limit uint64) validation.LimitError { - return validation.LimitError(fmt.Sprintf("limit of %v exceeded", limit)) + return validation.NewLimitError(fmt.Sprintf("limit of %v exceeded", limit)) }) assert.NoError(t, l.Reserve(5)) diff --git a/pkg/util/limiter/errors.go b/pkg/util/limiter/errors.go index 9e406927eeb..3597939b7cb 100644 --- a/pkg/util/limiter/errors.go +++ b/pkg/util/limiter/errors.go @@ -33,7 +33,7 @@ var ( ) func limitError(format string, limit uint64) validation.LimitError { - return validation.LimitError(fmt.Sprintf(format, limit)) + return validation.NewLimitError(fmt.Sprintf(format, limit)) } func NewMaxSeriesHitLimitError(maxSeriesPerQuery uint64) validation.LimitError { diff --git a/pkg/util/limiter/query_limiter.go b/pkg/util/limiter/query_limiter.go index f6a193d92ba..fcd61b88869 100644 --- a/pkg/util/limiter/query_limiter.go +++ b/pkg/util/limiter/query_limiter.go @@ -13,6 +13,7 @@ import ( "github.com/grafana/mimir/pkg/mimirpb" "github.com/grafana/mimir/pkg/querier/stats" + "github.com/grafana/mimir/pkg/util/validation" ) type queryLimiterCtxKey struct{} @@ -73,7 +74,7 @@ func QueryLimiterFromContextWithFallback(ctx context.Context) *QueryLimiter { } // AddSeries adds the input series and returns an error if the limit is reached. -func (ql *QueryLimiter) AddSeries(seriesLabels []mimirpb.LabelAdapter) error { +func (ql *QueryLimiter) AddSeries(seriesLabels []mimirpb.LabelAdapter) validation.LimitError { // If the max series is unlimited just return without managing map if ql.maxSeriesPerQuery == 0 { return nil @@ -106,7 +107,7 @@ func (ql *QueryLimiter) uniqueSeriesCount() int { } // AddChunkBytes adds the input chunk size in bytes and returns an error if the limit is reached. -func (ql *QueryLimiter) AddChunkBytes(chunkSizeInBytes int) error { +func (ql *QueryLimiter) AddChunkBytes(chunkSizeInBytes int) validation.LimitError { if ql.maxChunkBytesPerQuery == 0 { return nil } @@ -124,7 +125,7 @@ func (ql *QueryLimiter) AddChunkBytes(chunkSizeInBytes int) error { return nil } -func (ql *QueryLimiter) AddChunks(count int) error { +func (ql *QueryLimiter) AddChunks(count int) validation.LimitError { if ql.maxChunksPerQuery == 0 { return nil } @@ -142,7 +143,7 @@ func (ql *QueryLimiter) AddChunks(count int) error { return nil } -func (ql *QueryLimiter) AddEstimatedChunks(count int) error { +func (ql *QueryLimiter) AddEstimatedChunks(count int) validation.LimitError { if ql.maxEstimatedChunksPerQuery == 0 { return nil } diff --git a/pkg/util/validation/limits.go b/pkg/util/validation/limits.go index cebca84cac9..f8b45149179 100644 --- a/pkg/util/validation/limits.go +++ b/pkg/util/validation/limits.go @@ -60,13 +60,31 @@ const ( MinCompactorPartialBlockDeletionDelay = 4 * time.Hour ) -// LimitError are errors that do not comply with the limits specified. -type LimitError string +// LimitError is a marker interface for the errors that do not comply with the specified limits. +type LimitError interface { + error + limitError() +} + +type limitErr string -func (e LimitError) Error() string { +// limitErr implements error and LimitError interfaces +func (e limitErr) Error() string { return string(e) } +// limitErr implements LimitError interface +func (e limitErr) limitError() {} + +func NewLimitError(msg string) LimitError { + return limitErr(msg) +} + +func IsLimitError(err error) bool { + var limitErr LimitError + return errors.As(err, &limitErr) +} + // Limits describe all the limits for users; can be used to describe global default // limits via flags, or per-user limits via yaml config. type Limits struct { diff --git a/pkg/util/validation/limits_test.go b/pkg/util/validation/limits_test.go index 58b6fffa791..833d8b3026e 100644 --- a/pkg/util/validation/limits_test.go +++ b/pkg/util/validation/limits_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/relabel" "github.com/stretchr/testify/assert" @@ -880,3 +881,29 @@ func TestExtensionMarshalling(t *testing.T) { require.Contains(t, string(val), `{"user":{"test_extension_struct":{"foo":42},"test_extension_string":"default string extension value","request_rate":0,"request_burst_size":0,`) }) } + +func TestIsLimitError(t *testing.T) { + const msg = "this is an error" + testCases := map[string]struct { + err error + expectedOutcome bool + }{ + "a random error is not a LimitError": { + err: errors.New(msg), + expectedOutcome: false, + }, + "errors implementing LimitError interface are LimitErrors": { + err: NewLimitError(msg), + expectedOutcome: true, + }, + "wrapped LimitErrors are LimitErrors": { + err: errors.Wrap(NewLimitError(msg), "wrapped"), + expectedOutcome: true, + }, + } + for testName, testData := range testCases { + t.Run(testName, func(t *testing.T) { + require.Equal(t, testData.expectedOutcome, IsLimitError(testData.err)) + }) + } +}