diff --git a/executor/executor.go b/executor/executor.go index c69a6d21dacc3..77805cc455915 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -941,6 +941,7 @@ func newLockCtx(seVars *variable.SessionVars, lockWaitTime int64) *kv.LockCtx { LockKeysDuration: &seVars.StmtCtx.LockKeysDuration, LockKeysCount: &seVars.StmtCtx.LockKeysCount, LockExpired: &seVars.TxnCtx.LockExpire, + CheckKeyExists: seVars.StmtCtx.CheckKeyExists, } } @@ -1677,6 +1678,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { } sc.TblInfo2UnionScan = make(map[*model.TableInfo]bool) + sc.CheckKeyExists = make(map[string]struct{}) errCount, warnCount := vars.StmtCtx.NumErrorWarnings() vars.SysErrorCount = errCount vars.SysWarningCount = warnCount diff --git a/kv/kv.go b/kv/kv.go index 71e002e493f58..da0751d691fae 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -58,6 +58,8 @@ const ( TaskID // CollectRuntimeStats is used to enable collect runtime stats. CollectRuntimeStats + // CheckExist map for key existence check. + CheckExists ) // Priority value for transaction priority. @@ -226,6 +228,7 @@ type LockCtx struct { Values map[string]ReturnedValue ValuesLock sync.Mutex LockExpired *uint32 + CheckKeyExists map[string]struct{} } // ReturnedValue pairs the Value and AlreadyLocked flag for PessimisticLock return values result. diff --git a/kv/union_store.go b/kv/union_store.go index 6b96bc5d78d47..567464a7c7b5f 100644 --- a/kv/union_store.go +++ b/kv/union_store.go @@ -107,6 +107,10 @@ func (us *unionStore) Get(ctx context.Context, k Key) ([]byte, error) { e, ok := us.opts.Get(PresumeKeyNotExistsError) if ok { us.keyExistErrs[string(k)] = e.(*existErrInfo) + if val, ok := us.opts.Get(CheckExists); ok { + checkExistMap := val.(map[string]struct{}) + checkExistMap[string(k)] = struct{}{} + } } return nil, ErrNotExist } diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index fa83f40457417..19996a9a15d0f 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -1353,3 +1353,185 @@ func (s *testPessimisticSuite) TestPointGetWithDeleteInMem(c *C) { tk2.MustQuery("select * from uk where c1 = 10").Check(testkit.Rows("10 77")) tk.MustExec("drop table if exists uk") } + +func (s *testPessimisticSuite) TestInsertDupKeyAfterLock(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop database if exists test_db") + tk.MustExec("create database test_db") + tk.MustExec("use test_db") + tk2.MustExec("use test_db") + tk2.MustExec("drop table if exists t1") + tk2.MustExec("create table t1(c1 int primary key, c2 int, c3 int, unique key uk(c2));") + tk2.MustExec("insert into t1 values(1, 2, 3);") + tk2.MustExec("insert into t1 values(10, 20, 30);") + + // Test insert after lock. + tk.MustExec("begin pessimistic") + err := tk.ExecToErr("update t1 set c2 = 20 where c1 = 1;") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 30")) + + tk.MustExec("begin pessimistic") + tk.MustExec("select * from t1 for update") + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 30")) + + tk.MustExec("begin pessimistic") + tk.MustExec("select * from t1 where c2 = 2 for update") + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 30")) + + // Test insert after insert. + tk.MustExec("begin pessimistic") + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("insert into t1 values(5, 6, 7)") + err = tk.ExecToErr("insert into t1 values(6, 6, 7);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "5 6 7", "10 20 30")) + + // Test insert after delete. + tk.MustExec("begin pessimistic") + tk.MustExec("delete from t1 where c2 > 2") + tk.MustExec("insert into t1 values(10, 20, 500);") + err = tk.ExecToErr("insert into t1 values(20, 20, 30);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(1, 20, 30);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 500")) + + // Test range. + tk.MustExec("begin pessimistic") + err = tk.ExecToErr("update t1 set c2 = 20 where c1 >= 1 and c1 < 5;") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("update t1 set c2 = 20 where c1 >= 1 and c1 < 50;") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 500")) + + // Test select for update after dml. + tk.MustExec("begin pessimistic") + tk.MustExec("insert into t1 values(5, 6, 7)") + tk.MustExec("select * from t1 where c1 = 5 for update") + tk.MustExec("select * from t1 where c1 = 6 for update") + tk.MustExec("select * from t1 for update") + err = tk.ExecToErr("insert into t1 values(7, 6, 7)") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(5, 8, 6)") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("select * from t1 where c1 = 5 for update") + tk.MustExec("select * from t1 where c2 = 8 for update") + tk.MustExec("select * from t1 for update") + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "5 6 7", "10 20 500")) + + // Test optimistic for update. + tk.MustExec("begin optimistic") + tk.MustQuery("select * from t1 where c1 = 1 for update").Check(testkit.Rows("1 2 3")) + tk.MustExec("insert into t1 values(10, 10, 10)") + err = tk.ExecToErr("commit") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) +} + +func (s *testPessimisticSuite) TestInsertDupKeyAfterLockBatchPointGet(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop database if exists test_db") + tk.MustExec("create database test_db") + tk.MustExec("use test_db") + tk2.MustExec("use test_db") + tk2.MustExec("drop table if exists t1") + tk2.MustExec("create table t1(c1 int primary key, c2 int, c3 int, unique key uk(c2));") + tk2.MustExec("insert into t1 values(1, 2, 3);") + tk2.MustExec("insert into t1 values(10, 20, 30);") + + // Test insert after lock. + tk.MustExec("begin pessimistic") + err := tk.ExecToErr("update t1 set c2 = 20 where c1 in (1);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 30")) + + tk.MustExec("begin pessimistic") + tk.MustExec("select * from t1 for update") + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 30")) + + tk.MustExec("begin pessimistic") + tk.MustExec("select * from t1 where c2 in (2) for update") + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 30")) + + // Test insert after insert. + tk.MustExec("begin pessimistic") + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("insert into t1 values(5, 6, 7)") + err = tk.ExecToErr("insert into t1 values(6, 6, 7);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "5 6 7", "10 20 30")) + + // Test insert after delete. + tk.MustExec("begin pessimistic") + tk.MustExec("delete from t1 where c2 > 2") + tk.MustExec("insert into t1 values(10, 20, 500);") + err = tk.ExecToErr("insert into t1 values(20, 20, 30);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(1, 20, 30);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 500")) + + // Test range. + tk.MustExec("begin pessimistic") + err = tk.ExecToErr("update t1 set c2 = 20 where c1 >= 1 and c1 < 5;") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("update t1 set c2 = 20 where c1 >= 1 and c1 < 50;") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(1, 15, 300);") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "10 20 500")) + + // Test select for update after dml. + tk.MustExec("begin pessimistic") + tk.MustExec("insert into t1 values(5, 6, 7)") + tk.MustExec("select * from t1 where c1 in (5, 6) for update") + tk.MustExec("select * from t1 where c1 = 6 for update") + tk.MustExec("select * from t1 for update") + err = tk.ExecToErr("insert into t1 values(7, 6, 7)") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + err = tk.ExecToErr("insert into t1 values(5, 8, 6)") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + tk.MustExec("select * from t1 where c2 = 8 for update") + tk.MustExec("select * from t1 where c1 in (5, 8) for update") + tk.MustExec("select * from t1 for update") + tk.MustExec("commit") + tk2.MustQuery("select * from t1").Check(testkit.Rows("1 2 3", "5 6 7", "10 20 500")) + + // Test optimistic for update. + tk.MustExec("begin optimistic") + tk.MustQuery("select * from t1 where c1 in (1) for update").Check(testkit.Rows("1 2 3")) + tk.MustExec("insert into t1 values(10, 10, 10)") + err = tk.ExecToErr("commit") + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) +} diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 49fc488a9fd10..b17b49f34d042 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -152,7 +152,8 @@ type StatementContext struct { LockKeysDuration time.Duration LockKeysCount int32 TblInfo2UnionScan map[*model.TableInfo]bool - TaskID uint64 // unique ID for an execution of a statement + TaskID uint64 // unique ID for an execution of a statement + CheckKeyExists map[string]struct{} // mark the keys needs to check for existence for pessimistic locks. } // StmtHints are SessionVars related sql hints. diff --git a/store/tikv/txn.go b/store/tikv/txn.go index b4c6aff9716ea..0aa4466dde148 100644 --- a/store/tikv/txn.go +++ b/store/tikv/txn.go @@ -55,7 +55,7 @@ type tikvTxn struct { startTime time.Time // Monotonic timestamp for recording txn time consuming. commitTS uint64 lockKeys [][]byte - lockedMap map[string]struct{} + lockedMap map[string]bool mu sync.Mutex // For thread-safe LockKeys function. setCnt int64 vars *kv.Variables @@ -88,7 +88,7 @@ func newTikvTxnWithStartTS(store *tikvStore, startTS uint64, replicaReadSeed uin return &tikvTxn{ snapshot: snapshot, us: kv.NewUnionStore(snapshot), - lockedMap: map[string]struct{}{}, + lockedMap: make(map[string]bool), store: store, startTS: startTS, startTime: time.Now(), @@ -213,6 +213,8 @@ func (txn *tikvTxn) SetOption(opt kv.Option, val interface{}) { txn.snapshot.keyOnly = val.(bool) case kv.SnapshotTS: txn.snapshot.setSnapshotTS(val.(uint64)) + case kv.CheckExists: + txn.us.SetOption(kv.CheckExists, val.(map[string]struct{})) } } @@ -361,9 +363,23 @@ func (txn *tikvTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keysInput }() txn.mu.Lock() for _, key := range keysInput { - if _, ok := txn.lockedMap[string(key)]; !ok { + // The value of lockedMap is only used by pessimistic transactions. + valueExist, locked := txn.lockedMap[string(key)] + _, checkKeyExists := lockCtx.CheckKeyExists[string(key)] + if !locked { keys = append(keys, key) - } else if lockCtx.ReturnValues { + } else if txn.IsPessimistic() { + if checkKeyExists && valueExist { + existErrInfo := txn.us.GetKeyExistErrInfo(key) + if existErrInfo == nil { + logutil.Logger(ctx).Error("key exist error not found", zap.Uint64("connID", txn.committer.connID), + zap.Stringer("key", key)) + return errors.Errorf("conn %d, existErr for key:%s should not be nil", txn.committer.connID, key) + } + return existErrInfo.Err() + } + } + if lockCtx.ReturnValues && locked { // An already locked key can not return values, we add an entry to let the caller get the value // in other ways. lockCtx.Values[string(key)] = kv.ReturnedValue{AlreadyLocked: true} @@ -438,7 +454,15 @@ func (txn *tikvTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keysInput txn.mu.Lock() txn.lockKeys = append(txn.lockKeys, keys...) for _, key := range keys { - txn.lockedMap[string(key)] = struct{}{} + // PointGet and BatchPointGet will return value in pessimistic lock response, the value may not exists. + // For other lock modes, the locked key values always exist. + if lockCtx.ReturnValues { + val, _ := lockCtx.Values[string(key)] + valExists := len(val.Value) > 0 + txn.lockedMap[string(key)] = valExists + } else { + txn.lockedMap[string(key)] = true + } } txn.dirty = true txn.mu.Unlock() diff --git a/table/tables/tables.go b/table/tables/tables.go index 9940a324d8ea4..9e7ef3f18a91f 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -663,6 +663,7 @@ func (t *TableCommon) addIndices(sctx sessionctx.Context, recordID int64, r []ty } existErrInfo := kv.NewExistErrInfo(v.Meta().Name.String(), entryKey) txn.SetOption(kv.PresumeKeyNotExistsError, existErrInfo) + txn.SetOption(kv.CheckExists, sctx.GetSessionVars().StmtCtx.CheckKeyExists) dupErr = existErrInfo.Err() } if dupHandle, err := v.Create(sctx, rm, indexVals, recordID, opts...); err != nil { @@ -1219,6 +1220,7 @@ func CheckHandleExists(ctx context.Context, sctx sessionctx.Context, t table.Tab recordKey := t.RecordKey(recordID) existErrInfo := kv.NewExistErrInfo("PRIMARY", strconv.Itoa(int(recordID))) txn.SetOption(kv.PresumeKeyNotExistsError, existErrInfo) + txn.SetOption(kv.CheckExists, sctx.GetSessionVars().StmtCtx.CheckKeyExists) defer txn.DelOption(kv.PresumeKeyNotExistsError) _, err = txn.Get(ctx, recordKey) if err == nil { diff --git a/util/mock/context.go b/util/mock/context.go index 089edaafb171c..418c3bac1109d 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -278,6 +278,7 @@ func NewContext() *Context { sctx.sessionVars.StmtCtx.TimeZone = time.UTC sctx.sessionVars.StmtCtx.MemTracker = memory.NewTracker(stringutil.StringerStr("mock.NewContext"), -1) sctx.sessionVars.StmtCtx.DiskTracker = disk.NewTracker(stringutil.StringerStr("mock.NewContext"), -1) + sctx.sessionVars.StmtCtx.CheckKeyExists = make(map[string]struct{}) sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor() if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil { panic(err)