diff --git a/executor/delete.go b/executor/delete.go index 3aa0932a07c22..97b3487ffa3f9 100644 --- a/executor/delete.go +++ b/executor/delete.go @@ -245,7 +245,8 @@ func (e *DeleteExec) removeRow(ctx sessionctx.Context, t table.Table, h kv.Handl if err != nil { return err } - err = e.onRemoveRowForFK(ctx, t, data) + tid := t.Meta().ID + err = onRemoveRowForFK(ctx, data, e.fkChecks[tid], e.fkCascades[tid]) if err != nil { return err } @@ -253,8 +254,7 @@ func (e *DeleteExec) removeRow(ctx sessionctx.Context, t table.Table, h kv.Handl return nil } -func (e *DeleteExec) onRemoveRowForFK(ctx sessionctx.Context, t table.Table, data []types.Datum) error { - fkChecks := e.fkChecks[t.Meta().ID] +func onRemoveRowForFK(ctx sessionctx.Context, data []types.Datum, fkChecks []*FKCheckExec, fkCascades []*FKCascadeExec) error { sc := ctx.GetSessionVars().StmtCtx for _, fkc := range fkChecks { err := fkc.deleteRowNeedToCheck(sc, data) @@ -262,7 +262,6 @@ func (e *DeleteExec) onRemoveRowForFK(ctx sessionctx.Context, t table.Table, dat return err } } - fkCascades := e.fkCascades[t.Meta().ID] for _, fkc := range fkCascades { err := fkc.onDeleteRow(sc, data) if err != nil { diff --git a/executor/fktest/foreign_key_test.go b/executor/fktest/foreign_key_test.go index a162bd22b96aa..8d6442f39fad4 100644 --- a/executor/fktest/foreign_key_test.go +++ b/executor/fktest/foreign_key_test.go @@ -2539,3 +2539,107 @@ func TestForeignKeyIssue39732(t *testing.T) { tk.MustExec("execute stmt1 using @a;") tk.MustQuery("select * from t1 order by id").Check(testkit.Rows()) } + +func TestForeignKeyOnReplaceIntoChildTable(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_foreign_key=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + tk.MustExec("create table t_data (id int, a int, b int)") + tk.MustExec("insert into t_data (id, a, b) values (1, 1, 1), (2, 2, 2);") + for _, ca := range foreignKeyTestCase1 { + tk.MustExec("drop table if exists t2;") + tk.MustExec("drop table if exists t1;") + for _, sql := range ca.prepareSQLs { + tk.MustExec(sql) + } + tk.MustExec("replace into t1 (id, a, b) values (1, 1, 1);") + tk.MustExec("replace into t2 (id, a, b) values (1, 1, 1)") + tk.MustGetDBError("replace into t1 (id, a, b) values (1, 2, 3);", plannercore.ErrRowIsReferenced2) + if !ca.notNull { + tk.MustExec("replace into t2 (id, a, b) values (2, null, 1)") + tk.MustExec("replace into t2 (id, a, b) values (3, 1, null)") + tk.MustExec("replace into t2 (id, a, b) values (4, null, null)") + } + tk.MustGetDBError("replace into t2 (id, a, b) values (5, 1, 0);", plannercore.ErrNoReferencedRow2) + tk.MustGetDBError("replace into t2 (id, a, b) values (6, 0, 1);", plannercore.ErrNoReferencedRow2) + tk.MustGetDBError("replace into t2 (id, a, b) values (7, 2, 2);", plannercore.ErrNoReferencedRow2) + // Test replace into from select. + tk.MustExec("delete from t2") + tk.MustExec("replace into t2 (id, a, b) select id, a, b from t_data where t_data.id=1") + tk.MustGetDBError("replace into t2 (id, a, b) select id, a, b from t_data where t_data.id=2", plannercore.ErrNoReferencedRow2) + + // Test in txn + tk.MustExec("delete from t2") + tk.MustExec("begin") + tk.MustExec("delete from t1 where a=1") + tk.MustGetDBError("replace into t2 (id, a, b) values (1, 1, 1)", plannercore.ErrNoReferencedRow2) + tk.MustExec("replace into t1 (id, a, b) values (2, 2, 2)") + tk.MustExec("replace into t2 (id, a, b) values (2, 2, 2)") + tk.MustGetDBError("replace into t1 (id, a, b) values (2, 2, 3);", plannercore.ErrRowIsReferenced2) + tk.MustExec("rollback") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 1 1")) + tk.MustQuery("select id, a, b from t2 order by id").Check(testkit.Rows()) + } + + // Case-10: test primary key is handle and contain foreign key column, and foreign key column has default value. + tk.MustExec("drop table if exists t2;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("set @@tidb_enable_clustered_index=0;") + tk.MustExec("create table t1 (id int,a int, primary key(id));") + tk.MustExec("create table t2 (id int key,a int not null default 0, index (a), foreign key fk(a) references t1(id));") + tk.MustExec("replace into t1 values (1, 1);") + tk.MustExec("replace into t2 values (1, 1);") + tk.MustGetDBError("replace into t2 (id) values (10);", plannercore.ErrNoReferencedRow2) + tk.MustGetDBError("replace into t2 values (3, 2);", plannercore.ErrNoReferencedRow2) + + // Case-11: test primary key is handle and contain foreign key column, and foreign key column doesn't have default value. + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t2 (id int key,a int, index (a), foreign key fk(a) references t1(id));") + tk.MustExec("replace into t2 values (1, 1);") + tk.MustExec("replace into t2 (id) values (10);") + tk.MustGetDBError("replace into t2 values (3, 2);", plannercore.ErrNoReferencedRow2) +} + +func TestForeignKeyOnReplaceInto(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + tk.MustExec("create table t1 (id int key, a int, index (a));") + tk.MustExec("create table t2 (id int key, a int, index (a), constraint fk_1 foreign key (a) references t1(a));") + tk.MustExec("replace into t1 values (1, 1);") + tk.MustExec("replace into t2 values (1, 1);") + tk.MustExec("replace into t2 (id) values (2);") + tk.MustGetDBError("replace into t2 values (1, 2);", plannercore.ErrNoReferencedRow2) + // Test fk check on replace into parent table. + tk.MustGetDBError("replace into t1 values (1, 2);", plannercore.ErrRowIsReferenced2) + // Test fk cascade delete on replace into parent table. + tk.MustExec("alter table t2 drop foreign key fk_1") + tk.MustExec("alter table t2 add constraint fk_1 foreign key (a) references t1(a) on delete cascade") + tk.MustExec("replace into t1 values (1, 2);") + tk.MustQuery("select id, a from t1").Check(testkit.Rows("1 2")) + tk.MustQuery("select * from t2").Check(testkit.Rows("2 ")) + // Test fk cascade delete on replace into parent table. + tk.MustExec("alter table t2 drop foreign key fk_1") + tk.MustExec("alter table t2 add constraint fk_1 foreign key (a) references t1(a) on delete set null") + tk.MustExec("delete from t2") + tk.MustExec("delete from t1") + tk.MustExec("replace into t1 values (1, 1);") + tk.MustExec("replace into t2 values (1, 1);") + tk.MustExec("replace into t1 values (1, 2);") + tk.MustQuery("select id, a from t1").Check(testkit.Rows("1 2")) + tk.MustQuery("select id, a from t2").Check(testkit.Rows("1 ")) + + // Test cascade delete in self table by replace into statement. + tk.MustExec("drop table t1,t2") + tk.MustExec("create table t1 (id int key, name varchar(10), leader int, index(leader), foreign key (leader) references t1(id) ON DELETE CASCADE);") + tk.MustExec("replace into t1 values (1, 'boss', null), (10, 'l1_a', 1), (11, 'l1_b', 1), (12, 'l1_c', 1)") + tk.MustExec("replace into t1 values (100, 'l2_a1', 10), (101, 'l2_a2', 10), (102, 'l2_a3', 10)") + tk.MustExec("replace into t1 values (110, 'l2_b1', 11), (111, 'l2_b2', 11), (112, 'l2_b3', 11)") + tk.MustExec("replace into t1 values (120, 'l2_c1', 12), (121, 'l2_c2', 12), (122, 'l2_c3', 12)") + tk.MustExec("replace into t1 values (1000,'l3_a1', 100)") + tk.MustExec("replace into t1 values (1, 'new-boss', null)") + tk.MustQuery("select id from t1 order by id").Check(testkit.Rows("1")) +} diff --git a/executor/replace.go b/executor/replace.go index 158a620fb300e..bfc70ebc4451c 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -92,6 +92,10 @@ func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle if err != nil { return false, err } + err = onRemoveRowForFK(e.ctx, oldRow, e.fkChecks, e.fkCascades) + if err != nil { + return false, err + } e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1) return false, nil } @@ -277,3 +281,18 @@ func (e *ReplaceExec) setMessage() { stmtCtx.SetMessage(msg) } } + +// GetFKChecks implements WithForeignKeyTrigger interface. +func (e *ReplaceExec) GetFKChecks() []*FKCheckExec { + return e.fkChecks +} + +// GetFKCascades implements WithForeignKeyTrigger interface. +func (e *ReplaceExec) GetFKCascades() []*FKCascadeExec { + return e.fkCascades +} + +// HasFKCascades implements WithForeignKeyTrigger interface. +func (e *ReplaceExec) HasFKCascades() bool { + return len(e.fkCascades) > 0 +} diff --git a/planner/core/foreign_key.go b/planner/core/foreign_key.go index d63a5f489b7ac..f6cec0fa8f069 100644 --- a/planner/core/foreign_key.go +++ b/planner/core/foreign_key.go @@ -159,6 +159,17 @@ func (p *Insert) buildOnInsertFKTriggers(ctx sessionctx.Context, is infoschema.I if len(referredFKCascades) > 0 { fkCascades = append(fkCascades, referredFKCascades...) } + } else if p.IsReplace { + referredFKChecks, referredFKCascades, err := p.buildOnReplaceReferredFKTriggers(ctx, is, dbName, tblInfo) + if err != nil { + return err + } + if len(referredFKChecks) > 0 { + fkChecks = append(fkChecks, referredFKChecks...) + } + if len(referredFKCascades) > 0 { + fkCascades = append(fkCascades, referredFKCascades...) + } } for _, fk := range tblInfo.ForeignKeys { if fk.Version < 1 { @@ -186,6 +197,25 @@ func (p *Insert) buildOnDuplicateUpdateColumns() map[string]struct{} { return m } +func (p *Insert) buildOnReplaceReferredFKTriggers(ctx sessionctx.Context, is infoschema.InfoSchema, dbName string, tblInfo *model.TableInfo) ([]*FKCheck, []*FKCascade, error) { + referredFKs := is.GetTableReferredForeignKeys(dbName, tblInfo.Name.L) + fkChecks := make([]*FKCheck, 0, len(referredFKs)) + fkCascades := make([]*FKCascade, 0, len(referredFKs)) + for _, referredFK := range referredFKs { + fkCheck, fkCascade, err := buildOnDeleteOrUpdateFKTrigger(ctx, is, referredFK, FKCascadeOnDelete) + if err != nil { + return nil, nil, err + } + if fkCheck != nil { + fkChecks = append(fkChecks, fkCheck) + } + if fkCascade != nil { + fkCascades = append(fkCascades, fkCascade) + } + } + return fkChecks, fkCascades, nil +} + func (updt *Update) buildOnUpdateFKTriggers(ctx sessionctx.Context, is infoschema.InfoSchema, tblID2table map[int64]table.Table) error { if !ctx.GetSessionVars().ForeignKeyChecks { return nil