Skip to content

Commit

Permalink
Use thread safe chunk iterator in sharding tests (#7504)
Browse files Browse the repository at this point in the history
The PromQL engine currently uses pooling of histogram pointers, optimized
for reading chunks where the iterator returns a copy of the data.
However, in the sharding tests we run multiple engines on the same data
with an iterator from promql.StorageSeries that does not copy the histograms.
As a result, histogram pointers can be reused across goroutines, which
leads to a data race.

Signed-off-by: György Krajcsovits <[email protected]>
  • Loading branch information
krajorama authored Dec 19, 2024
1 parent 7fa48db commit 8d0af42
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 24 deletions.
108 changes: 88 additions & 20 deletions pkg/frontend/querymiddleware/querysharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/promql/parser"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/tsdb/chunkenc"
"github.com/prometheus/prometheus/util/annotations"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -623,7 +624,7 @@ func TestQuerySharding_Correctness(t *testing.T) {
},
}

series := make([]*promql.StorageSeries, 0, numSeries+(numConvHistograms*len(histogramBuckets))+numNativeHistograms)
series := make([]storage.Series, 0, numSeries+(numConvHistograms*len(histogramBuckets))+numNativeHistograms)
seriesID := 0

// Add counter series.
Expand Down Expand Up @@ -791,7 +792,7 @@ func TestQuerySharding_NonMonotonicHistogramBuckets(t *testing.T) {
`histogram_quantile(1, sum by(le) (rate(metric_histogram_bucket[1m])))`,
}

series := []*promql.StorageSeries{}
series := []storage.Series{}
for i := 0; i < 100; i++ {
series = append(series, newSeries(labels.FromStrings(labels.MetricName, "metric_histogram_bucket", "app", strconv.Itoa(i), "le", "10"), start.Add(-lookbackDelta), end, step, arithmeticSequence(1)))
series = append(series, newSeries(labels.FromStrings(labels.MetricName, "metric_histogram_bucket", "app", strconv.Itoa(i), "le", "20"), start.Add(-lookbackDelta), end, step, arithmeticSequence(3)))
Expand Down Expand Up @@ -913,7 +914,7 @@ func TestQueryshardingDeterminism(t *testing.T) {
)

labelsForShard := labelsForShardsGenerator([]labels.Label{{Name: labels.MetricName, Value: "metric"}}, shards)
storageSeries := []*promql.StorageSeries{
storageSeries := []storage.Series{
newSeries(labelsForShard(0), from, to, step, constant(evilFloatA)),
newSeries(labelsForShard(1), from, to, step, constant(evilFloatA)),
newSeries(labelsForShard(2), from, to, step, constant(evilFloatB)),
Expand Down Expand Up @@ -1066,7 +1067,7 @@ func TestQuerySharding_FunctionCorrectness(t *testing.T) {
}

t.Run("floats", func(t *testing.T) {
queryableFloats := storageSeriesQueryable([]*promql.StorageSeries{
queryableFloats := storageSeriesQueryable([]storage.Series{
newSeries(labels.FromStrings("__name__", "bar1", "baz", "blip", "bar", "blop", "foo", "barr"), start.Add(-lookbackDelta), end, step, factor(5)),
newSeries(labels.FromStrings("__name__", "bar1", "baz", "blip", "bar", "blop", "foo", "bazz"), start.Add(-lookbackDelta), end, step, factor(7)),
newSeries(labels.FromStrings("__name__", "bar1", "baz", "blip", "bar", "blap", "foo", "buzz"), start.Add(-lookbackDelta), end, step, factor(12)),
Expand All @@ -1078,7 +1079,7 @@ func TestQuerySharding_FunctionCorrectness(t *testing.T) {
})

t.Run("native histograms", func(t *testing.T) {
queryableNativeHistograms := storageSeriesQueryable([]*promql.StorageSeries{
queryableNativeHistograms := storageSeriesQueryable([]storage.Series{
newNativeHistogramSeries(labels.FromStrings("__name__", "bar1", "baz", "blip", "bar", "blop", "foo", "barr"), start.Add(-lookbackDelta), end, step, factor(5)),
newNativeHistogramSeries(labels.FromStrings("__name__", "bar1", "baz", "blip", "bar", "blop", "foo", "bazz"), start.Add(-lookbackDelta), end, step, factor(7)),
newNativeHistogramSeries(labels.FromStrings("__name__", "bar1", "baz", "blip", "bar", "blap", "foo", "buzz"), start.Add(-lookbackDelta), end, step, factor(12)),
Expand Down Expand Up @@ -1548,7 +1549,7 @@ func TestQuerySharding_ShouldReturnErrorInCorrectFormat(t *testing.T) {
queryablePrometheusExecErr = storage.QueryableFunc(func(int64, int64) (storage.Querier, error) {
return nil, apierror.Newf(apierror.TypeExec, "expanding series: %s", querier.NewMaxQueryLengthError(744*time.Hour, 720*time.Hour))
})
queryable = storageSeriesQueryable([]*promql.StorageSeries{
queryable = storageSeriesQueryable([]storage.Series{
newSeries(labels.FromStrings("__name__", "bar1"), start.Add(-lookbackDelta), end, step, factor(5)),
})
queryableSlow = newMockShardedQueryable(
Expand Down Expand Up @@ -1654,7 +1655,7 @@ func TestQuerySharding_EngineErrorMapping(t *testing.T) {
engine = newEngine()
)

series := make([]*promql.StorageSeries, 0, numSeries)
series := make([]storage.Series, 0, numSeries)
for i := 0; i < numSeries; i++ {
series = append(series, newSeries(newTestCounterLabels(i), start.Add(-lookbackDelta), end, step, factor(float64(i)*0.1)))
}
Expand Down Expand Up @@ -1752,7 +1753,7 @@ func TestQuerySharding_ShouldUseCardinalityEstimate(t *testing.T) {
func TestQuerySharding_Annotations(t *testing.T) {
numSeries := 10
endTime := 100
storageSeries := make([]*promql.StorageSeries, 0, numSeries)
storageSeries := make([]storage.Series, 0, numSeries)
floats := make([]promql.FPoint, 0, endTime)
for i := 0; i < endTime; i++ {
floats = append(floats, promql.FPoint{
Expand Down Expand Up @@ -2253,14 +2254,14 @@ func (h *downstreamHandler) Do(ctx context.Context, r MetricsQueryRequest) (Resp
return resp, nil
}

func storageSeriesQueryable(series []*promql.StorageSeries) storage.Queryable {
func storageSeriesQueryable(series []storage.Series) storage.Queryable {
return storage.QueryableFunc(func(int64, int64) (storage.Querier, error) {
return &querierMock{series: series}, nil
})
}

type querierMock struct {
series []*promql.StorageSeries
series []storage.Series
}

func (m *querierMock) Select(_ context.Context, sorted bool, _ *storage.SelectHints, matchers ...*labels.Matcher) storage.SeriesSet {
Expand All @@ -2270,7 +2271,7 @@ func (m *querierMock) Select(_ context.Context, sorted bool, _ *storage.SelectHi
}

// Filter series by label matchers.
var filtered []*promql.StorageSeries
var filtered []storage.Series

for _, series := range m.series {
if seriesMatches(series, matchers...) {
Expand Down Expand Up @@ -2301,7 +2302,7 @@ func (m *querierMock) LabelNames(context.Context, *storage.LabelHints, ...*label

func (m *querierMock) Close() error { return nil }

func seriesMatches(series *promql.StorageSeries, matchers ...*labels.Matcher) bool {
func seriesMatches(series storage.Series, matchers ...*labels.Matcher) bool {
for _, m := range matchers {
if !m.Matches(series.Labels().Get(m.Name)) {
return false
Expand All @@ -2311,12 +2312,12 @@ func seriesMatches(series *promql.StorageSeries, matchers ...*labels.Matcher) bo
return true
}

func filterSeriesByShard(series []*promql.StorageSeries, shard *sharding.ShardSelector) []*promql.StorageSeries {
func filterSeriesByShard(series []storage.Series, shard *sharding.ShardSelector) []storage.Series {
if shard == nil {
return series
}

var filtered []*promql.StorageSeries
var filtered []storage.Series

for _, s := range series {
if labels.StableHash(s.Labels())%shard.ShardCount == shard.ShardIndex {
Expand All @@ -2327,15 +2328,15 @@ func filterSeriesByShard(series []*promql.StorageSeries, shard *sharding.ShardSe
return filtered
}

func newSeries(metric labels.Labels, from, to time.Time, step time.Duration, gen generator) *promql.StorageSeries {
func newSeries(metric labels.Labels, from, to time.Time, step time.Duration, gen generator) storage.Series {
return newSeriesInner(metric, from, to, step, gen, false)
}

func newNativeHistogramSeries(metric labels.Labels, from, to time.Time, step time.Duration, gen generator) *promql.StorageSeries {
func newNativeHistogramSeries(metric labels.Labels, from, to time.Time, step time.Duration, gen generator) storage.Series {
return newSeriesInner(metric, from, to, step, gen, true)
}

func newSeriesInner(metric labels.Labels, from, to time.Time, step time.Duration, gen generator, histogram bool) *promql.StorageSeries {
func newSeriesInner(metric labels.Labels, from, to time.Time, step time.Duration, gen generator, histogram bool) storage.Series {
var (
floats []promql.FPoint
histograms []promql.HPoint
Expand Down Expand Up @@ -2367,7 +2368,7 @@ func newSeriesInner(metric labels.Labels, from, to time.Time, step time.Duration
}
}

return promql.NewStorageSeries(promql.Series{
return NewThreadSafeStorageSeries(promql.Series{
Metric: metric,
Floats: floats,
Histograms: histograms,
Expand Down Expand Up @@ -2478,10 +2479,10 @@ func constant(value float64) generator {

type seriesIteratorMock struct {
idx int
series []*promql.StorageSeries
series []storage.Series
}

func newSeriesIteratorMock(series []*promql.StorageSeries) *seriesIteratorMock {
func newSeriesIteratorMock(series []storage.Series) *seriesIteratorMock {
return &seriesIteratorMock{
idx: -1,
series: series,
Expand Down Expand Up @@ -2509,6 +2510,73 @@ func (i *seriesIteratorMock) Warnings() annotations.Annotations {
return nil
}

// ThreadSafeStorageSeries wraps a promql.StorageSeries for use in the sharding
// tests. Usually series are read by a single engine in a single goroutine, but
// in sharding tests we have multiple engines in multiple goroutines. Thus we
// need a series iterator that doesn't share pointers between goroutines.
type ThreadSafeStorageSeries struct {
// storageSeries is the wrapped series that Labels and Iterator delegate to.
storageSeries *promql.StorageSeries
}

// NewThreadSafeStorageSeries returns a ThreadSafeStorageSeries that wraps the
// promql.StorageSeries built from the given promql.Series.
func NewThreadSafeStorageSeries(series promql.Series) *ThreadSafeStorageSeries {
	wrapped := promql.NewStorageSeries(series)
	return &ThreadSafeStorageSeries{storageSeries: wrapped}
}

// Labels returns the labels of the wrapped storage series.
func (ss *ThreadSafeStorageSeries) Labels() labels.Labels {
	lbls := ss.storageSeries.Labels()
	return lbls
}

// Iterator returns a new ThreadSafeStorageSeriesIterator over the data of the
// wrapped series. When the iterator passed in is itself a
// *ThreadSafeStorageSeriesIterator, its underlying iterator is handed to the
// wrapped series for reuse; any other iterator is ignored and nil is passed.
// Per the original comment, in case of multiple samples with the same
// timestamp the float samples are returned first (this ordering comes from
// the wrapped promql.StorageSeries iterator).
func (ss *ThreadSafeStorageSeries) Iterator(it chunkenc.Iterator) chunkenc.Iterator {
	// A nil interface value means "nothing to reuse" for the wrapped series.
	var reusable chunkenc.Iterator
	if prev, ok := it.(*ThreadSafeStorageSeriesIterator); ok {
		reusable = prev.underlying
	}
	inner := ss.storageSeries.Iterator(reusable)
	return &ThreadSafeStorageSeriesIterator{underlying: inner}
}

// ThreadSafeStorageSeriesIterator wraps a chunkenc.Iterator and forwards
// positioning and float reads to it. Float histograms are returned as copies
// (see AtFloatHistogram) so callers never receive the pointer held by the
// underlying iterator.
type ThreadSafeStorageSeriesIterator struct {
// underlying is the wrapped iterator that calls are forwarded to.
underlying chunkenc.Iterator
}

// Seek forwards the seek to the underlying iterator and returns the value
// type it reports.
func (ssi *ThreadSafeStorageSeriesIterator) Seek(t int64) chunkenc.ValueType {
	valueType := ssi.underlying.Seek(t)
	return valueType
}

// At returns the timestamp and float value reported by the underlying
// iterator.
func (ssi *ThreadSafeStorageSeriesIterator) At() (t int64, v float64) {
	t, v = ssi.underlying.At()
	return t, v
}

// AtHistogram is not supported by this iterator and always panics.
func (ssi *ThreadSafeStorageSeriesIterator) AtHistogram(*histogram.Histogram) (int64, *histogram.Histogram) {
	unsupported := errors.New("storageSeriesIterator: AtHistogram not supported")
	panic(unsupported)
}

// AtFloatHistogram returns the timestamp and the float histogram at the
// current position. Unlike the underlying iterator, the histogram handed back
// is always a copy, so the caller can modify it without affecting the
// underlying series: the data is copied into toFH when toFH is non-nil, and
// into a newly created histogram otherwise.
func (ssi *ThreadSafeStorageSeriesIterator) AtFloatHistogram(toFH *histogram.FloatHistogram) (int64, *histogram.FloatHistogram) {
	// Pass nil so the underlying iterator never writes into the caller's
	// histogram; the copy is done here instead.
	ts, source := ssi.underlying.AtFloatHistogram(nil)
	if toFH != nil {
		source.CopyTo(toFH)
		return ts, toFH
	}
	return ts, source.Copy()
}

// AtT returns the timestamp reported by the underlying iterator.
func (ssi *ThreadSafeStorageSeriesIterator) AtT() int64 {
	ts := ssi.underlying.AtT()
	return ts
}

// Next advances the underlying iterator and returns the value type it
// reports.
func (ssi *ThreadSafeStorageSeriesIterator) Next() chunkenc.ValueType {
	valueType := ssi.underlying.Next()
	return valueType
}

// Err returns the error reported by the underlying iterator, so that a
// failure in the wrapped iterator is not hidden by this wrapper.
func (ssi *ThreadSafeStorageSeriesIterator) Err() error {
	return ssi.underlying.Err()
}

// newEngine creates and return a new promql.Engine used for testing.
func newEngine() *promql.Engine {
return promql.NewEngine(promql.EngineOpts{
Expand Down
4 changes: 2 additions & 2 deletions pkg/frontend/querymiddleware/split_and_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/model/histogram"
"github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/promql/parser"
"github.com/prometheus/prometheus/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -931,7 +931,7 @@ func TestSplitAndCacheMiddleware_ResultsCacheFuzzy(t *testing.T) {
step := 2 * time.Minute

// Generate series.
series := make([]*promql.StorageSeries, 0, numSeries)
series := make([]storage.Series, 0, numSeries)
for i := 0; i < numSeries; i++ {
series = append(series, newSeries(newTestCounterLabels(i), minTime, maxTime, step, factor(float64(i))))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
"github.com/grafana/dskit/user"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/promql/parser"
"github.com/prometheus/prometheus/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -434,7 +434,7 @@ func TestInstantQuerySplittingCorrectness(t *testing.T) {
},
}

series := make([]*promql.StorageSeries, 0, numSeries+(numConvHistograms*len(histogramBuckets))+numNativeHistograms)
series := make([]storage.Series, 0, numSeries+(numConvHistograms*len(histogramBuckets))+numNativeHistograms)
seriesID := 0
end := start.Add(30 * time.Minute)

Expand Down

0 comments on commit 8d0af42

Please sign in to comment.