diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index 812660c32f8ff..c7eaed7f7c815 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -41,6 +41,7 @@ func (s *testSuite) createSelectNormal(batch, totalRows int, c *C) (*selectResul SetDesc(false). SetKeepOrder(false). SetFromSessionVars(variable.NewSessionVars()). + SetMemTracker(s.sctx, stringutil.StringerStr("testSuite.createSelectNormal")). Build() c.Assert(err, IsNil) @@ -94,6 +95,21 @@ func (s *testSuite) TestSelectNormal(c *C) { c.Assert(numAllRows, Equals, 2) err := response.Close() c.Assert(err, IsNil) + c.Assert(response.memTracker.BytesConsumed(), Equals, int64(0)) +} + +func (s *testSuite) TestSelectMemTracker(c *C) { + response, colTypes := s.createSelectNormal(2, 6, c, nil) + response.Fetch(context.TODO()) + + // Test Next. + chk := chunk.New(colTypes, 3, 3) + err := response.Next(context.TODO(), chk) + c.Assert(err, IsNil) + c.Assert(chk.IsFull(), Equals, true) + err = response.Close() + c.Assert(err, IsNil) + c.Assert(response.memTracker.BytesConsumed(), Equals, int64(0)) } func (s *testSuite) TestSelectNormalChunkSize(c *C) { @@ -101,6 +117,7 @@ func (s *testSuite) TestSelectNormalChunkSize(c *C) { response.Fetch(context.TODO()) s.testChunkSize(response, colTypes, c) c.Assert(response.Close(), IsNil) + c.Assert(response.memTracker.BytesConsumed(), Equals, int64(0)) } func (s *testSuite) createSelectStreaming(batch, totalRows int, c *C) (*streamResult, []*types.FieldType) { diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go index b64dd63218892..a2b472b5ad833 100644 --- a/distsql/request_builder_test.go +++ b/distsql/request_builder_test.go @@ -26,8 +26,10 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/ranger" + "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tipb/go-tipb" ) @@ -49,6 +51,9 @@ type testSuite struct { func (s *testSuite) SetUpSuite(c *C) { ctx := mock.NewContext() + ctx.GetSessionVars().StmtCtx = &stmtctx.StatementContext{ + MemTracker: memory.NewTracker(stringutil.StringerStr("testSuite"), variable.DefTiDBMemQuotaDistSQL), + } ctx.Store = &mock.Store{ Client: &mock.Client{ MockResponse: &mockResponse{ diff --git a/distsql/select_result.go b/distsql/select_result.go index b11afbee38f7e..c558d0a46cd59 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -66,8 +66,9 @@ type selectResult struct { fieldTypes []*types.FieldType ctx sessionctx.Context - selectResp *tipb.SelectResponse - respChkIdx int + selectResp *tipb.SelectResponse + selectRespSize int // record the selectResp.Size() when it is initialized. + respChkIdx int feedback *statistics.QueryFeedback partialCount int64 // number of partial results. @@ -99,20 +100,25 @@ func (r *selectResult) fetch(ctx context.Context) { if err != nil { result.err = err } else if resultSubset == nil { + // If the result is drained, the resultSubset would be nil return } else { result.result = resultSubset - if r.memTracker != nil { - r.memTracker.Consume(int64(resultSubset.MemSize())) - } + r.memConsume(int64(resultSubset.MemSize())) } select { case r.results <- result: case <-r.closed: // If selectResult called Close() already, make fetch goroutine exit. + if resultSubset != nil { + r.memConsume(-int64(resultSubset.MemSize())) + } return case <-ctx.Done(): + if resultSubset != nil { + r.memConsume(-int64(resultSubset.MemSize())) + } return } } @@ -157,24 +163,21 @@ func (r *selectResult) getSelectResp() error { if re.err != nil { return errors.Trace(re.err) } - if r.memTracker != nil && r.selectResp != nil { - r.memTracker.Consume(-int64(r.selectResp.Size())) + if r.selectResp != nil { + r.memConsume(-int64(r.selectRespSize)) } if re.result == nil { r.selectResp = nil return nil } - if r.memTracker != nil { - r.memTracker.Consume(-int64(re.result.MemSize())) - } + r.memConsume(-int64(re.result.MemSize())) r.selectResp = new(tipb.SelectResponse) err := r.selectResp.Unmarshal(re.result.GetData()) if err != nil { return errors.Trace(err) } - if r.memTracker != nil && r.selectResp != nil { - r.memTracker.Consume(int64(r.selectResp.Size())) - } + r.selectRespSize = r.selectResp.Size() + r.memConsume(int64(r.selectRespSize)) if err := r.selectResp.Error; err != nil { return terror.ClassTiKV.New(terror.ErrCode(err.Code), err.Msg) } @@ -207,13 +210,27 @@ func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) { return nil } +func (r *selectResult) memConsume(bytes int64) { + if r.memTracker != nil { + r.memTracker.Consume(bytes) + } +} + // Close closes selectResult. func (r *selectResult) Close() error { - // Close this channel tell fetch goroutine to exit. if r.feedback.Actual() >= 0 { metrics.DistSQLScanKeysHistogram.Observe(float64(r.feedback.Actual())) } metrics.DistSQLPartialCountHistogram.Observe(float64(r.partialCount)) + // Close this channel to tell the fetch goroutine to exit. close(r.closed) + for re := range r.results { + if re.result != nil { + r.memConsume(-int64(re.result.MemSize())) + } + } + if r.selectResp != nil { + r.memConsume(-int64(r.selectRespSize)) + } return r.resp.Close() }