From 3fd392a7b69015d5560fa94fbb6f21522698c36d Mon Sep 17 00:00:00 2001 From: CbcWestwolf <1004626265@qq.com> Date: Fri, 24 Feb 2023 13:15:48 +0800 Subject: [PATCH] Support check constraint (#2) --- ddl/BUILD.bazel | 2 + ddl/column_modify_test.go | 11 - ddl/constraint.go | 383 +++++++++++++++++++++++++++++++++ ddl/constraint_test.go | 320 +++++++++++++++++++++++++++ ddl/db_change_test.go | 42 ++++ ddl/db_test.go | 42 ---- ddl/ddl_api.go | 219 +++++++++++++++++-- ddl/ddl_worker.go | 6 + ddl/rollingback.go | 48 ++++- errno/errcode.go | 14 ++ errno/errname.go | 14 ++ errors.toml | 35 +++ executor/show.go | 19 ++ executor/showtest/show_test.go | 34 +++ infoschema/error.go | 2 + parser/model/ddl.go | 2 +- planner/core/preprocess.go | 1 + table/BUILD.bazel | 2 + table/constraint.go | 68 ++++++ table/table.go | 2 + table/tables/partition.go | 2 +- table/tables/tables.go | 76 ++++++- table/tables/tables_test.go | 67 ++++++ util/dbterror/ddl_terror.go | 8 + 24 files changed, 1345 insertions(+), 74 deletions(-) create mode 100644 ddl/constraint.go create mode 100644 ddl/constraint_test.go create mode 100644 table/constraint.go diff --git a/ddl/BUILD.bazel b/ddl/BUILD.bazel index 6dfaedd66846f..ed2112733d8ee 100644 --- a/ddl/BUILD.bazel +++ b/ddl/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "cluster.go", "column.go", "constant.go", + "constraint.go", "ddl.go", "ddl_algorithm.go", "ddl_api.go", @@ -152,6 +153,7 @@ go_test( "column_modify_test.go", "column_test.go", "column_type_change_test.go", + "constraint_test.go", "db_cache_test.go", "db_change_failpoints_test.go", "db_change_test.go", diff --git a/ddl/column_modify_test.go b/ddl/column_modify_test.go index 2055f9df9fa9c..207697718991b 100644 --- a/ddl/column_modify_test.go +++ b/ddl/column_modify_test.go @@ -704,17 +704,6 @@ func TestTransactionWithWriteOnlyColumn(t *testing.T) { tk.MustQuery("select a from t1").Check(testkit.Rows("2")) } -func TestColumnCheck(t *testing.T) { - store := testkit.CreateMockStoreWithSchemaLease(t, columnModifyLease) - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("drop table if exists column_check") - tk.MustExec("create table column_check (pk int primary key, a int check (a > 1))") - defer tk.MustExec("drop table if exists column_check") - require.Equal(t, uint16(1), tk.Session().GetSessionVars().StmtCtx.WarningCount()) - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|8231|CONSTRAINT CHECK is not supported")) -} - func TestModifyGeneratedColumn(t *testing.T) { store := testkit.CreateMockStoreWithSchemaLease(t, columnModifyLease) tk := testkit.NewTestKit(t, store) diff --git a/ddl/constraint.go b/ddl/constraint.go new file mode 100644 index 0000000000000..6c384bfbd26b8 --- /dev/null +++ b/ddl/constraint.go @@ -0,0 +1,383 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/format" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/dbterror" + "github.com/pingcap/tidb/util/sqlexec" +) + +func (w *worker) onAddCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + // Handle the rolling back job. + if job.IsRollingback() { + ver, err = onDropCheckConstraint(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + return ver, nil + } + + failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(ver, errors.New("occur an error before decode args")) + } + }) + + dbInfo, tblInfo, constraintInfoInMeta, constraintInfoInJob, err := checkAddCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + if constraintInfoInMeta == nil { + // It's first time to run add constraint job, so there is no constraint info in meta. + // Use the raw constraint info from job directly and modify table info here. + constraintInfoInJob.ID = allocateConstraintID(tblInfo) + // Reset constraint name according to real-time constraints name at this point. + constrNames := map[string]bool{} + for _, constr := range tblInfo.Constraints { + constrNames[constr.Name.L] = true + } + setNameForConstraintInfo(tblInfo.Name.L, constrNames, []*model.ConstraintInfo{constraintInfoInJob}) + // Double check the constraint dependency. + existedColsMap := make(map[string]struct{}) + cols := tblInfo.Columns + for _, v := range cols { + if v.State == model.StatePublic { + existedColsMap[v.Name.L] = struct{}{} + } + } + dependedCols := constraintInfoInJob.ConstraintCols + for _, k := range dependedCols { + if _, ok := existedColsMap[k.L]; !ok { + // The table constraint depended on a non-existed column. + return ver, dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constraintInfoInJob.Name, k) + } + } + + tblInfo.Constraints = append(tblInfo.Constraints, constraintInfoInJob) + constraintInfoInMeta = constraintInfoInJob + } + + originalState := constraintInfoInMeta.State + switch constraintInfoInMeta.State { + case model.StateNone: + // none -> write only + job.SchemaState = model.StateWriteOnly + constraintInfoInMeta.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != constraintInfoInMeta.State) + case model.StateWriteOnly: + // write only -> public + skipCheck := false + failpoint.Inject("mockPassAddConstraintCheck", func(val failpoint.Value) { + if val.(bool) { + skipCheck = true + } + }) + if !skipCheck { + err = w.addTableCheckConstraint(dbInfo, tblInfo, constraintInfoInMeta, job) + if err != nil { + return ver, errors.Trace(err) + } + } + constraintInfoInMeta.State = model.StatePublic + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != constraintInfoInMeta.State) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("constraint", constraintInfoInMeta.State) + } + + return ver, errors.Trace(err) +} + +// onDropCheckConstraint can be called from two case: +// 1: rollback in add constraint.(in rollback function the job.args will be changed) +// 2: user drop constraint ddl. +func onDropCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, constraintInfo, err := checkDropCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + originalState := constraintInfo.State + switch constraintInfo.State { + case model.StatePublic: + // public -> write only + job.SchemaState = model.StateWriteOnly + constraintInfo.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != constraintInfo.State) + case model.StateWriteOnly: + // write only -> None + // write only state constraint will still take effect to check the newly inserted data. + // So the depended column shouldn't be dropped even in this intermediate state. + constraintInfo.State = model.StateNone + // remove the constraint from tableInfo. + for i, constr := range tblInfo.Constraints { + if constr.Name.L == constraintInfo.Name.L { + tblInfo.Constraints = append(tblInfo.Constraints[0:i], tblInfo.Constraints[i+1:]...) + } + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != constraintInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + if job.IsRollingback() { + job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) + } else { + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + } + default: + err = dbterror.ErrInvalidDDLJob.GenWithStackByArgs("constraint", tblInfo.State) + } + return ver, errors.Trace(err) +} + +func (w *worker) onAlterCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + dbInfo, tblInfo, constraintInfo, enforced, err := checkAlterCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + // enforced will fetch table data and check the constraint. + if constraintInfo.Enforced != enforced && enforced { + skipCheck := false + failpoint.Inject("mockPassAlterConstraintCheck", func(val failpoint.Value) { + if val.(bool) { + skipCheck = true + } + }) + if !skipCheck { + err = w.addTableCheckConstraint(dbInfo, tblInfo, constraintInfo, job) + if err != nil { + // check constraint error will cancel the job, job state has been changed + // to cancelled in addTableCheckConstraint. + return ver, errors.Trace(err) + } + } + } + constraintInfo.Enforced = enforced + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + if err != nil { + // update version and tableInfo error will cause retry. + return ver, errors.Trace(err) + } + + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func checkDropCheckConstraint(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ConstraintInfo, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, errors.Trace(err) + } + + var constrName model.CIStr + err = job.DecodeArgs(&constrName) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, errors.Trace(err) + } + + // do the double-check with constraint existence. + constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + job.State = model.JobStateCancelled + return nil, nil, dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + return tblInfo, constraintInfo, nil +} + +func checkAddCheckConstraint(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ConstraintInfo, *model.ConstraintInfo, error) { + schemaID := job.SchemaID + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return nil, nil, nil, nil, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, nil, errors.Trace(err) + } + constraintInfo1 := &model.ConstraintInfo{} + err = job.DecodeArgs(constraintInfo1) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, nil, errors.Trace(err) + } + // do the double-check with constraint existence. + constraintInfo2 := tblInfo.FindConstraintInfoByName(constraintInfo1.Name.L) + if constraintInfo2 != nil { + if constraintInfo2.State == model.StatePublic { + // We already have a constraint with the same constraint name. + job.State = model.JobStateCancelled + return nil, nil, nil, nil, infoschema.ErrColumnExists.GenWithStackByArgs(constraintInfo1.Name) + } + // if not, that means constraint was in intermediate state. + } + return dbInfo, tblInfo, constraintInfo2, constraintInfo1, nil +} + +func checkAlterCheckConstraint(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ConstraintInfo, bool, error) { + schemaID := job.SchemaID + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + + var ( + enforced bool + constrName model.CIStr + ) + err = job.DecodeArgs(&constrName, &enforced) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, false, errors.Trace(err) + } + + // do the double-check with constraint existence. + constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + job.State = model.JobStateCancelled + return nil, nil, nil, false, dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + return dbInfo, tblInfo, constraintInfo, enforced, nil +} + +func allocateConstraintID(tblInfo *model.TableInfo) int64 { + tblInfo.MaxConstraintID++ + return tblInfo.MaxConstraintID +} + +func buildConstraintInfo(tblInfo *model.TableInfo, dependedCols []model.CIStr, constr *ast.Constraint, state model.SchemaState) (*model.ConstraintInfo, error) { + constraintName := model.NewCIStr(constr.Name) + if err := checkTooLongConstraint(constraintName); err != nil { + return nil, errors.Trace(err) + } + + // Restore check constraint expression to string. + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + + sb.Reset() + err := constr.Expr.Restore(restoreCtx) + if err != nil { + return nil, errors.Trace(err) + } + + // Create constraint info. + constraintInfo := &model.ConstraintInfo{ + Name: constraintName, + Table: tblInfo.Name, + ConstraintCols: dependedCols, + ExprString: sb.String(), + Enforced: constr.Enforced, + InColumn: constr.InColumn, + State: state, + } + + return constraintInfo, nil +} + +func checkTooLongConstraint(constr model.CIStr) error { + if len(constr.L) > mysql.MaxConstraintIdentifierLen { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(constr) + } + return nil +} + +// findDependedColsMapInExpr returns a set of string, which indicates +// the names of the columns that are depended by exprNode. +func findDependedColsMapInExpr(expr ast.ExprNode) map[string]struct{} { + colNames := FindColumnNamesInExpr(expr) + colsMap := make(map[string]struct{}, len(colNames)) + for _, depCol := range colNames { + colsMap[depCol.Name.L] = struct{}{} + } + return colsMap +} + +func (w *worker) addTableCheckConstraint(dbInfo *model.DBInfo, tableInfo *model.TableInfo, constr *model.ConstraintInfo, job *model.Job) error { + // Get sessionctx from ddl context resource pool in ddl worker. + var sctx sessionctx.Context + sctx, err := w.sessPool.get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.put(sctx) + + // If there is any row can't pass the check expression, the add constraint action will error. + // It's no need to construct expression node out and pull the chunk rows through it. Here we + // can let the check expression restored string as the filter in where clause directly. + // Prepare internal SQL to fetch data from physical table under this filter. + sql := fmt.Sprintf("select count(1) from `%s`.`%s` where not %s limit 1", dbInfo.Name.L, tableInfo.Name.L, constr.ExprString) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + return errors.Trace(err) + } + rowCount := rows[0].GetInt64(0) + if rowCount != 0 { + // If check constraint fail, the job state should be changed to canceled, otherwise it will tracked in. + job.State = model.JobStateCancelled + return dbterror.ErrCheckConstraintIsViolated.GenWithStackByArgs(constr.Name.L) + } + return nil +} + +func setNameForConstraintInfo(tableLowerName string, namesMap map[string]bool, infos []*model.ConstraintInfo) { + cnt := 1 + constraintPrefix := tableLowerName + "_chk_" + for _, constrInfo := range infos { + if constrInfo.Name.O == "" { + constrName := fmt.Sprintf("%s%d", constraintPrefix, cnt) + for { + // loop until find constrName that haven't been used. + if !namesMap[constrName] { + namesMap[constrName] = true + break + } + cnt++ + constrName = fmt.Sprintf("%s%d", constraintPrefix, cnt) + } + constrInfo.Name = model.NewCIStr(constrName) + } + } +} diff --git a/ddl/constraint_test.go b/ddl/constraint_test.go new file mode 100644 index 0000000000000..788ba964bad80 --- /dev/null +++ b/ddl/constraint_test.go @@ -0,0 +1,320 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl_test + +import ( + "sort" + "testing" + + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/testkit/external" + "github.com/stretchr/testify/require" +) + +func TestCreateTableWithCheckConstraints(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + + // Test column-type check constraint. + tk.MustExec("create table t(a int not null check(a>0))") + constraintTable := external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 1, len(constraintTable.Meta().Columns)) + require.Equal(t, 1, len(constraintTable.Meta().Constraints)) + constrs := constraintTable.Meta().Constraints + require.Equal(t, int64(1), constrs[0].ID) + require.True(t, constrs[0].InColumn) + require.True(t, constrs[0].Enforced) + require.Equal(t, "t", constrs[0].Table.L) + require.Equal(t, model.StatePublic, constrs[0].State) + require.Equal(t, 1, len(constrs[0].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[0].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[0].Name) + require.Equal(t, "`a` > 0", constrs[0].ExprString) + + tk.MustExec("drop table t") + tk.MustExec("create table t(a bigint key constraint my_constr check(a<10), b int constraint check(b > 1) not enforced)") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 2, len(constraintTable.Meta().Columns)) + constrs = constraintTable.Meta().Constraints + require.Equal(t, 2, len(constrs)) + require.Equal(t, int64(1), constrs[0].ID) + require.True(t, constrs[0].InColumn) + require.True(t, constrs[0].Enforced) + require.Equal(t, "t", constrs[0].Table.L) + require.Equal(t, model.StatePublic, constrs[0].State) + require.Equal(t, 1, len(constrs[0].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[0].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("my_constr"), constrs[0].Name) + require.Equal(t, "`a` < 10", constrs[0].ExprString) + + require.Equal(t, int64(2), constrs[1].ID) + require.True(t, constrs[1].InColumn) + require.False(t, constrs[1].Enforced) + require.Equal(t, 1, len(constrs[0].ConstraintCols)) + require.Equal(t, "t", constrs[1].Table.L) + require.Equal(t, model.StatePublic, constrs[1].State) + require.Equal(t, 1, len(constrs[1].ConstraintCols)) + require.Equal(t, model.NewCIStr("b"), constrs[1].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[1].Name) + require.Equal(t, "`b` > 1", constrs[1].ExprString) + + // Test table-type check constraint. + tk.MustExec("drop table t") + tk.MustExec("create table t(a int constraint check(a > 1) not enforced, constraint my_constr check(a < 10))") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 1, len(constraintTable.Meta().Columns)) + constrs = constraintTable.Meta().Constraints + require.Equal(t, 2, len(constrs)) + // table-type check constraint. + require.Equal(t, int64(1), constrs[0].ID) + require.False(t, constrs[0].InColumn) + require.True(t, constrs[0].Enforced) + require.Equal(t, "t", constrs[0].Table.L) + require.Equal(t, model.StatePublic, constrs[0].State) + require.Equal(t, 1, len(constrs[0].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[0].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("my_constr"), constrs[0].Name) + require.Equal(t, "`a` < 10", constrs[0].ExprString) + + // column-type check constraint. + require.Equal(t, int64(2), constrs[1].ID) + require.True(t, constrs[1].InColumn) + require.False(t, constrs[1].Enforced) + require.Equal(t, "t", constrs[1].Table.L) + require.Equal(t, model.StatePublic, constrs[1].State) + require.Equal(t, 1, len(constrs[1].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[1].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[1].Name) + require.Equal(t, "`a` > 1", constrs[1].ExprString) + + // Test column-type check constraint fail on dependency. + tk.MustExec("drop table t") + _, err := tk.Exec("create table t(a int not null check(b>0))") + require.Errorf(t, err, "[ddl:3813]Column check constraint 't_chk_1' references other column.") + + _, err = tk.Exec("create table t(a int not null check(b>a))") + require.Errorf(t, err, "[ddl:3813]Column check constraint 't_chk_1' references other column.") + + _, err = tk.Exec("create table t(a int not null check(a>0), b int, constraint check(c>b))") + require.Errorf(t, err, "[ddl:3820]Check constraint 't_chk_1' refers to non-existing column 'c'.") + + tk.MustExec("create table t(a int not null check(a>0), b int, constraint check(a>b))") + tk.MustExec("drop table t") + + tk.MustExec("create table t(a int not null check(a > '12345'))") + tk.MustExec("drop table t") + + tk.MustExec("create table t(a int not null primary key check(a > '12345'))") + tk.MustExec("drop table t") + + tk.MustExec("create table t(a varchar(10) not null primary key check(a > '12345'))") + tk.MustExec("drop table t") +} + +func TestAlterTableAddCheckConstraints(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + + tk.MustExec("create table t(a int not null check(a>0))") + // Add constraint with name. + tk.MustExec("alter table t add constraint haha check(a<10)") + constraintTable := external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 1, len(constraintTable.Meta().Columns)) + require.Equal(t, 2, len(constraintTable.Meta().Constraints)) + constrs := constraintTable.Meta().Constraints + require.Equal(t, int64(2), constrs[1].ID) + require.False(t, constrs[1].InColumn) + require.True(t, constrs[1].Enforced) + require.Equal(t, "t", constrs[1].Table.L) + require.Equal(t, model.StatePublic, constrs[1].State) + require.Equal(t, 1, len(constrs[1].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[1].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("haha"), constrs[1].Name) + require.Equal(t, "`a` < 10", constrs[1].ExprString) + + // Add constraint without name. + tk.MustExec("alter table t add constraint check(a<11) not enforced") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 1, len(constraintTable.Meta().Columns)) + require.Equal(t, 3, len(constraintTable.Meta().Constraints)) + constrs = constraintTable.Meta().Constraints + require.Equal(t, int64(3), constrs[2].ID) + require.False(t, constrs[2].InColumn) + require.False(t, constrs[2].Enforced) + require.Equal(t, "t", constrs[2].Table.L) + require.Equal(t, model.StatePublic, constrs[2].State) + require.Equal(t, 1, len(constrs[2].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[2].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("t_chk_2"), constrs[2].Name) + require.Equal(t, "`a` < 11", constrs[2].ExprString) + + // Add constraint with the name has already existed. + _, err := tk.Exec("alter table t add constraint haha check(a)") + require.Errorf(t, err, "[schema:3822]Duplicate check constraint name 'haha'.") + + // Add constraint with the unknown column. + _, err = tk.Exec("alter table t add constraint check(b)") + require.Errorf(t, err, "[ddl:3820]Check constraint 't_chk_3' refers to non-existing column 'b'.") + + tk.MustExec("alter table t add constraint check(a*2 < a+1) not enforced") +} + +func TestAlterTableDropCheckConstraints(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + + tk.MustExec("create table t(a int not null check(a>0), b int, constraint haha check(a < b), check(a 0", constrs[2].ExprString) + + // Drop a non-exist constraint + _, err := tk.Exec("alter table t drop constraint not_exist_constraint") + require.Errorf(t, err, "[ddl:3940]Constraint 'not_exist_constraint' does not exist") + + tk.MustExec("alter table t drop constraint haha") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 2, len(constraintTable.Meta().Columns)) + require.Equal(t, 2, len(constraintTable.Meta().Constraints)) + constrs = constraintTable.Meta().Constraints + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[0].Name) + require.Equal(t, model.NewCIStr("t_chk_2"), constrs[1].Name) + + tk.MustExec("alter table t drop constraint t_chk_2") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 2, len(constraintTable.Meta().Columns)) + require.Equal(t, 1, len(constraintTable.Meta().Constraints)) + constrs = constraintTable.Meta().Constraints + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[0].Name) +} + +func TestAlterTableAlterCheckConstraints(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + + tk.MustExec("create table t(a int not null check(a>0) not enforced, b int, constraint haha check(a < b))") + constraintTable := external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 2, len(constraintTable.Meta().Columns)) + require.Equal(t, 2, len(constraintTable.Meta().Constraints)) + constrs := constraintTable.Meta().Constraints + + require.Equal(t, int64(1), constrs[0].ID) + require.False(t, constrs[0].InColumn) + require.True(t, constrs[0].Enforced) + require.Equal(t, "t", constrs[0].Table.L) + require.Equal(t, model.StatePublic, constrs[0].State) + require.Equal(t, 2, len(constrs[0].ConstraintCols)) + sort.Slice(constrs[0].ConstraintCols, func(i, j int) bool { + return constrs[0].ConstraintCols[i].L < constrs[0].ConstraintCols[j].L + }) + require.Equal(t, model.NewCIStr("a"), constrs[0].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("b"), constrs[0].ConstraintCols[1]) + require.Equal(t, model.NewCIStr("haha"), constrs[0].Name) + require.Equal(t, "`a` < `b`", constrs[0].ExprString) + + require.Equal(t, int64(2), constrs[1].ID) + require.True(t, constrs[1].InColumn) + require.False(t, constrs[1].Enforced) + require.Equal(t, "t", constrs[1].Table.L) + require.Equal(t, model.StatePublic, constrs[1].State) + require.Equal(t, 1, len(constrs[1].ConstraintCols)) + require.Equal(t, model.NewCIStr("a"), constrs[1].ConstraintCols[0]) + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[1].Name) + require.Equal(t, "`a` > 0", constrs[1].ExprString) + + // Alter constraint alter constraint with unknown name. + _, err := tk.Exec("alter table t alter constraint unknown not enforced") + require.Errorf(t, err, "[ddl:3940]Constraint 'unknown' does not exist") + + // Alter table alter constraint with user name. + tk.MustExec("alter table t alter constraint haha not enforced") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 2, len(constraintTable.Meta().Columns)) + require.Equal(t, 2, len(constraintTable.Meta().Constraints)) + constrs = constraintTable.Meta().Constraints + require.False(t, constrs[0].Enforced) + require.Equal(t, model.NewCIStr("haha"), constrs[0].Name) + require.False(t, constrs[1].Enforced) + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[1].Name) + + // Alter table alter constraint with generated name. + tk.MustExec("alter table t alter constraint t_chk_1 enforced") + constraintTable = external.GetTableByName(t, tk, "test", "t") + require.Equal(t, 2, len(constraintTable.Meta().Columns)) + require.Equal(t, 2, len(constraintTable.Meta().Constraints)) + constrs = constraintTable.Meta().Constraints + require.False(t, constrs[0].Enforced) + require.Equal(t, model.NewCIStr("haha"), constrs[0].Name) + require.True(t, constrs[1].Enforced) + require.Equal(t, model.NewCIStr("t_chk_1"), constrs[1].Name) + + // Alter table alter constraint will violate check. + // Here a=1, b=0 doesn't satisfy "a < b" constraint. + // Since "a1), b int, constraint a_b check(a StateWriteOnly -> StatePublic + // Node in StateWriteOnly and StatePublic should check the constraint check. + _, checkErr = tk1.Exec("insert into t (a, b) values(5,6) ") + // Don't do the assert in the callback function. + } + } + callback.OnJobUpdatedExported.Store(&onJobUpdatedExportedFunc) + d.SetHook(callback) + tk.MustExec("alter table t add constraint cc check ( b < 5 )") + require.Errorf(t, err, "[table:3819]Check constraint 'cc' is violated.") + + tk.MustExec("alter table t drop constraint cc") + require.Errorf(t, err, "[table:3819]Check constraint 'cc' is violated.") + tk.MustExec("drop table if exists t") +} + func TestTwoStates(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, 200*time.Millisecond) tk := testkit.NewTestKit(t, store) diff --git a/ddl/db_test.go b/ddl/db_test.go index edc891ad16ccf..89d5a3cd86cd3 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -44,7 +44,6 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" - "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/external" "github.com/pingcap/tidb/types" @@ -350,19 +349,6 @@ func TestIssue23473(t *testing.T) { require.True(t, mysql.HasNoDefaultValueFlag(tbl.Cols()[0].GetFlag())) } -func TestDropCheck(t *testing.T) { - store := testkit.CreateMockStoreWithSchemaLease(t, dbTestLease) - - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("drop table if exists drop_check") - tk.MustExec("create table drop_check (pk int primary key)") - defer tk.MustExec("drop table if exists drop_check") - tk.MustExec("alter table drop_check drop check crcn") - require.Equal(t, uint16(1), tk.Session().GetSessionVars().StmtCtx.WarningCount()) - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|8231|DROP CHECK is not supported")) -} - func TestAlterOrderBy(t *testing.T) { store := testkit.CreateMockStoreWithSchemaLease(t, dbTestLease) @@ -505,34 +491,6 @@ func TestSelectInViewFromAnotherDB(t *testing.T) { tk.MustExec("select test_db2.v.a from test_db2.v") } -func TestAddConstraintCheck(t *testing.T) { - store := testkit.CreateMockStoreWithSchemaLease(t, dbTestLease) - - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("drop table if exists add_constraint_check") - tk.MustExec("create table add_constraint_check (pk int primary key, a int)") - defer tk.MustExec("drop table if exists add_constraint_check") - tk.MustExec("alter table add_constraint_check add constraint crn check (a > 1)") - require.Equal(t, uint16(1), tk.Session().GetSessionVars().StmtCtx.WarningCount()) - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|8231|ADD CONSTRAINT CHECK is not supported")) -} - -func TestCreateTableIgnoreCheckConstraint(t *testing.T) { - store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) - - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("drop table if exists table_constraint_check") - tk.MustExec("CREATE TABLE admin_user (enable bool, CHECK (enable IN (0, 1)));") - require.Equal(t, uint16(1), tk.Session().GetSessionVars().StmtCtx.WarningCount()) - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|8231|CONSTRAINT CHECK is not supported")) - tk.MustQuery("show create table admin_user").Check(testkit.RowsWithSep("|", ""+ - "admin_user CREATE TABLE `admin_user` (\n"+ - " `enable` tinyint(1) DEFAULT NULL\n"+ - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) -} - func TestAutoConvertBlobTypeByLength(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 19f1ce98dde73..c2afd27e45fd9 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1151,7 +1151,17 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o case ast.ColumnOptionFulltext: ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.GenWithStackByArgs()) case ast.ColumnOptionCheck: - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedConstraintCheck.GenWithStackByArgs("CONSTRAINT CHECK")) + // Check the column CHECK constraint dependency lazily, after fill all the name. + // Extract column constraint from column option. + constraint := &ast.Constraint{ + Tp: ast.ConstraintCheck, + Expr: v.Expr, + Enforced: v.Enforced, + Name: v.ConstraintName, + InColumn: true, + InColumnName: colDef.Name.Name.O, + } + constraints = append(constraints, constraint) } } } @@ -1684,6 +1694,26 @@ func checkDuplicateConstraint(namesMap map[string]bool, name string, foreign boo return nil } +func setEmptyCheckConstraintName(tableLowerName string, namesMap map[string]bool, constrs []*ast.Constraint) { + cnt := 1 + constraintPrefix := tableLowerName + "_chk_" + for _, constr := range constrs { + if constr.Name == "" { + constrName := fmt.Sprintf("%s%d", constraintPrefix, cnt) + for { + // loop until find constrName that haven't been used. + if !namesMap[constrName] { + namesMap[constrName] = true + break + } + cnt++ + constrName = fmt.Sprintf("%s%d", constraintPrefix, cnt) + } + constr.Name = constrName + } + } +} + func setEmptyConstraintName(namesMap map[string]bool, constr *ast.Constraint) { if constr.Name == "" && len(constr.Keys) > 0 { var colName string @@ -1711,7 +1741,7 @@ func setEmptyConstraintName(namesMap map[string]bool, constr *ast.Constraint) { } } -func checkConstraintNames(constraints []*ast.Constraint) error { +func checkConstraintNames(tableName model.CIStr, constraints []*ast.Constraint) error { constrNames := map[string]bool{} fkNames := map[string]bool{} @@ -1730,13 +1760,20 @@ func checkConstraintNames(constraints []*ast.Constraint) error { } } + checkConstraints := make([]*ast.Constraint, 0, len(constraints)) // Set empty constraint names. for _, constr := range constraints { + if constr.Tp == ast.ConstraintCheck { + checkConstraints = append(checkConstraints, constr) + } if constr.Tp != ast.ConstraintForeignKey { setEmptyConstraintName(constrNames, constr) } } - + // Set check constraint name under its order. + if len(checkConstraints) > 0 { + setEmptyCheckConstraintName(tableName.L, constrNames, checkConstraints) + } return nil } @@ -1838,11 +1875,14 @@ func BuildTableInfo( Charset: charset, Collate: collate, } + // existedColsMap is used to check existence of the depended column. + existedColsMap := make(map[string]struct{}, len(cols)) tblColumns := make([]*table.Column, 0, len(cols)) for _, v := range cols { v.ID = AllocateColumnID(tbInfo) tbInfo.Columns = append(tbInfo.Columns, v.ToInfo()) tblColumns = append(tblColumns, table.ToColumn(v.ToInfo())) + existedColsMap[v.Name.L] = struct{}{} } foreignKeyID := tbInfo.MaxForeignKeyID for _, constr := range constraints { @@ -1912,10 +1952,6 @@ func BuildTableInfo( ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.GenWithStackByArgs()) continue } - if constr.Tp == ast.ConstraintCheck { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedConstraintCheck.GenWithStackByArgs("CONSTRAINT CHECK")) - continue - } var ( indexName = constr.Name @@ -1932,6 +1968,43 @@ func BuildTableInfo( unique = true } + if constr.Tp == ast.ConstraintCheck { + // Since column check constraint dependency has been done in columnDefToCol. + // Here do the table check constraint dependency check, table constraint + // can only refer the columns in defined columns of the table. + // Refer: https://dev.mysql.com/doc/refman/8.0/en/create-table-check-constraints.html + var dependedCols []model.CIStr + dependedColsMap := findDependedColsMapInExpr(constr.Expr) + if !constr.InColumn { + dependedCols = make([]model.CIStr, 0, len(dependedColsMap)) + for k := range dependedColsMap { + if _, ok := existedColsMap[k]; !ok { + // The table constraint depended on a non-existed column. + return nil, dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constr.Name, k) + } + dependedCols = append(dependedCols, model.NewCIStr(k)) + } + } else { + // Check the column-type constraint dependency. + if len(dependedColsMap) != 1 { + return nil, dbterror.ErrColumnCheckConstraintReferOther.GenWithStackByArgs(constr.Name) + } + if _, ok := dependedColsMap[constr.InColumnName]; !ok { + return nil, dbterror.ErrColumnCheckConstraintReferOther.GenWithStackByArgs(constr.Name) + } + dependedCols = []model.CIStr{model.NewCIStr(constr.InColumnName)} + } + + // build constraint meta info. + constraintInfo, err := buildConstraintInfo(tbInfo, dependedCols, constr, model.StatePublic) + if err != nil { + return nil, errors.Trace(err) + } + constraintInfo.ID = allocateConstraintID(tbInfo) + tbInfo.Constraints = append(tbInfo.Constraints, constraintInfo) + continue + } + // build index info. idxInfo, err := BuildIndexInfo( ctx, @@ -2274,7 +2347,7 @@ func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh if err != nil { return nil, errors.Trace(err) } - err = checkConstraintNames(newConstraints) + err = checkConstraintNames(s.Table.Name, newConstraints) if err != nil { return nil, errors.Trace(err) } @@ -3300,6 +3373,10 @@ func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast err = d.dropIndex(sctx, ident, model.NewCIStr(spec.Name), spec.IfExists) case ast.AlterTableDropPrimaryKey: err = d.dropIndex(sctx, ident, model.NewCIStr(mysql.PrimaryKeyName), spec.IfExists) + case ast.AlterTableDropCheck: + err = d.DropCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name)) + case ast.AlterTableAlterCheck: + err = d.AlterCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name), spec.Constraint.Enforced) case ast.AlterTableRenameIndex: err = d.RenameIndex(sctx, ident, spec) case ast.AlterTableDropPartition, ast.AlterTableDropFirstPartition: @@ -3344,7 +3421,7 @@ func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast case ast.ConstraintFulltext: sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt) case ast.ConstraintCheck: - sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedConstraintCheck.GenWithStackByArgs("ADD CONSTRAINT CHECK")) + err = d.CreateCheckConstraint(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint) default: // Nothing to do now. } @@ -3439,10 +3516,6 @@ func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast err = d.OrderByColumns(sctx, ident) case ast.AlterTableIndexInvisible: err = d.AlterIndexVisibility(sctx, ident, spec.IndexName, spec.Visibility) - case ast.AlterTableAlterCheck: - sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedConstraintCheck.GenWithStackByArgs("ALTER CHECK")) - case ast.AlterTableDropCheck: - sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedConstraintCheck.GenWithStackByArgs("DROP CHECK")) case ast.AlterTableWithValidation: sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedAlterTableWithValidation) case ast.AlterTableWithoutValidation: @@ -6127,6 +6200,126 @@ func GetName4AnonymousIndex(t table.Table, colName model.CIStr, idxName model.CI return indexName } +func (d *ddl) CreateCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr, constr *ast.Constraint) error { + schema, t, err := d.getSchemaAndTableByIdent(ctx, ti) + if err != nil { + return errors.Trace(err) + } + + if err = checkTooLongConstraint(constrName); err != nil { + return errors.Trace(err) + } + + if constraintInfo := t.Meta().FindConstraintInfoByName(constrName.L); constraintInfo != nil { + return infoschema.ErrCheckConstraintDupName.GenWithStackByArgs(constrName.L) + } + + // allocate the temporary constraint name for dependency-check-error-output below. + constrNames := map[string]bool{} + for _, constr := range t.Meta().Constraints { + constrNames[constr.Name.L] = true + } + setEmptyCheckConstraintName(t.Meta().Name.L, constrNames, []*ast.Constraint{constr}) + + // existedColsMap can be used to check the existence of depended. + existedColsMap := make(map[string]struct{}) + cols := t.Cols() + for _, v := range cols { + existedColsMap[v.Name.L] = struct{}{} + } + + dependedColsMap := findDependedColsMapInExpr(constr.Expr) + dependedCols := make([]model.CIStr, 0, len(dependedColsMap)) + for k := range dependedColsMap { + if _, ok := existedColsMap[k]; !ok { + // The table constraint depended on a non-existed column. + return dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constr.Name, k) + } + dependedCols = append(dependedCols, model.NewCIStr(k)) + } + + // build constraint meta info. + tblInfo := t.Meta() + constraintInfo, err := buildConstraintInfo(tblInfo, dependedCols, constr, model.StateNone) + if err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + Type: model.ActionAddCheckConstraint, + BinlogInfo: &model.HistoryInfo{}, + Args: []interface{}{constraintInfo}, + Priority: ctx.GetSessionVars().DDLReorgPriority, + } + + err = d.DoDDLJob(ctx, job) + err = d.callHookOnChanged(job, err) + return errors.Trace(err) +} + +func (d *ddl) DropCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr) error { + is := d.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists) + } + t, err := is.TableByName(ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + + constraintInfo := t.Meta().FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + return dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + Type: model.ActionDropCheckConstraint, + BinlogInfo: &model.HistoryInfo{}, + Args: []interface{}{constrName}, + } + + err = d.DoDDLJob(ctx, job) + err = d.callHookOnChanged(job, err) + return errors.Trace(err) +} + +func (d *ddl) AlterCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr, enforced bool) error { + is := d.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists) + } + t, err := is.TableByName(ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + + constraintInfo := t.Meta().FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + return dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + Type: model.ActionAlterCheckConstraint, + BinlogInfo: &model.HistoryInfo{}, + Args: []interface{}{constrName, enforced}, + } + + err = d.DoDDLJob(ctx, job) + err = d.callHookOnChanged(job, err) + return errors.Trace(err) +} + func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption) error { if indexOption != nil && indexOption.PrimaryKeyTp == model.PrimaryKeyTypeClustered { diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index dac2f01216edb..a08d16f8e47b0 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -1302,6 +1302,12 @@ func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, ver, err = onTTLInfoChange(d, t, job) case model.ActionAlterTTLRemove: ver, err = onTTLInfoRemove(d, t, job) + case model.ActionAddCheckConstraint: + ver, err = w.onAddCheckConstraint(d, t, job) + case model.ActionDropCheckConstraint: + ver, err = onDropCheckConstraint(d, t, job) + case model.ActionAlterCheckConstraint: + ver, err = w.onAlterCheckConstraint(d, t, job) default: // Invalid job, cancel it. job.State = model.JobStateCancelled diff --git a/ddl/rollingback.go b/ddl/rollingback.go index c6f75442479b6..a68fe36e47369 100644 --- a/ddl/rollingback.go +++ b/ddl/rollingback.go @@ -240,6 +240,48 @@ func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isP return } +func rollingbackAddConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + job.State = model.JobStateRollingback + _, tblInfo, constrInfoInMeta, _, err := checkAddCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + if constrInfoInMeta == nil { + // Add constraint hasn't stored constraint info into meta, so we can cancel the job + // directly without further rollback action. + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + // Add constraint has stored constraint info into meta, that means the job has at least + // arrived write only state. + originalState := constrInfoInMeta.State + constrInfoInMeta.State = model.StateWriteOnly + job.SchemaState = model.StateWriteOnly + + job.Args = []interface{}{constrInfoInMeta.Name} + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != constrInfoInMeta.State) + if err != nil { + return ver, errors.Trace(err) + } + return ver, dbterror.ErrCancelledDDLJob +} + +func rollingbackDropConstraint(t *meta.Meta, job *model.Job) (ver int64, err error) { + _, constrInfoInMeta, err := checkDropCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + // StatePublic means when the job is not running yet. + if constrInfoInMeta.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + // Can not rollback like drop other element, so just continue to drop constraint. + job.State = model.JobStateRunning + return ver, nil +} + func needNotifyAndStopReorgWorker(job *model.Job) bool { if job.SchemaState == model.StateWriteReorganization && job.SnapshotVer != 0 { // If the value of SnapshotVer isn't zero, it means the reorg workers have been started. @@ -379,6 +421,10 @@ func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) ver, err = rollingbackRenameIndex(t, job) case model.ActionTruncateTable: ver, err = rollingbackTruncateTable(t, job) + case model.ActionAddCheckConstraint: + ver, err = rollingbackAddConstraint(d, t, job) + case model.ActionDropCheckConstraint: + ver, err = rollingbackDropConstraint(t, job) case model.ActionModifyColumn: ver, err = rollingbackModifyColumn(w, d, t, job) case model.ActionDropForeignKey: @@ -389,7 +435,7 @@ func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) model.ActionModifySchemaCharsetAndCollate, model.ActionRepairTable, model.ActionModifyTableAutoIdCache, model.ActionAlterIndexVisibility, model.ActionExchangeTablePartition, model.ActionModifySchemaDefaultPlacement, - model.ActionRecoverSchema: + model.ActionRecoverSchema, model.ActionAlterCheckConstraint: ver, err = cancelOnlyNotHandledJob(job, model.StateNone) case model.ActionMultiSchemaChange: err = rollingBackMultiSchemaChange(job) diff --git a/errno/errcode.go b/errno/errcode.go index a72c6b2ec5e64..4d04da3af3a3a 100644 --- a/errno/errcode.go +++ b/errno/errcode.go @@ -913,6 +913,18 @@ const ( ErrDefValGeneratedNamedFunctionIsNotAllowed = 3770 ErrFKIncompatibleColumns = 3780 ErrFunctionalIndexRowValueIsNotAllowed = 3800 + ErrNonBooleanExprForCheckConstraint = 3812 + ErrColumnCheckConstraintReferencesOtherColumn = 3813 + ErrCheckConstraintNamedFunctionIsNotAllowed = 3814 + ErrCheckConstraintFunctionIsNotAllowed = 3815 + ErrCheckConstraintVariables = 3816 + ErrCheckConstraintRowValue = 3817 + ErrCheckConstraintRefersAutoIncrementColumn = 3818 + ErrCheckConstraintViolated = 3819 + ErrTableCheckConstraintReferUnknown = 3820 + ErrCheckConstraintNotFound = 3821 + ErrCheckConstraintDupName = 3822 + ErrCheckConstraintClauseUsingFKReferActionColumn = 3823 ErrDependentByFunctionalIndex = 3837 ErrCannotConvertString = 3854 ErrInvalidJSONValueForFuncIndex = 3903 @@ -920,7 +932,9 @@ const ( ErrFunctionalIndexDataIsTooLong = 3907 ErrFunctionalIndexNotApplicable = 3909 ErrDynamicPrivilegeNotRegistered = 3929 + ErrConstraintNotFound = 3940 ErUserAccessDeniedForUserAccountBlockedByPasswordLock = 3955 + ErrDependentByCheckConstraint = 3959 ErrTableWithoutPrimaryKey = 3750 // MariaDB errors. ErrOnlyOneDefaultPartionAllowed = 4030 diff --git a/errno/errname.go b/errno/errname.go index 8eb0cec9d8775..fd6443311e344 100644 --- a/errno/errname.go +++ b/errno/errname.go @@ -837,6 +837,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{ ErrGeneratedColumnRefAutoInc: mysql.Message("Generated column '%s' cannot refer to auto-increment column.", nil), ErrAccountHasBeenLocked: mysql.Message("Access denied for user '%s'@'%s'. Account is locked.", nil), ErUserAccessDeniedForUserAccountBlockedByPasswordLock: mysql.Message("Access denied for user '%s'@'%s'. Account is blocked for %s day(s) (%s day(s) remaining) due to %d consecutive failed logins.", nil), + ErrDependentByCheckConstraint: mysql.Message("Check constraint '%s' uses column '%s', hence column cannot be dropped or renamed.", nil), ErrWarnConflictingHint: mysql.Message("Hint %s is ignored as conflicting/duplicated.", nil), ErrUnresolvedHintName: mysql.Message("Unresolved name '%s' for %s hint", nil), ErrForeignKeyCascadeDepthExceeded: mysql.Message("Foreign key cascade delete/update exceeds max depth of %v.", nil), @@ -913,8 +914,21 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{ ErrJSONValueOutOfRangeForFuncIndex: mysql.Message("Out of range JSON value for CAST for expression index '%s'", nil), ErrFunctionalIndexDataIsTooLong: mysql.Message("Data too long for expression index '%s'", nil), ErrFunctionalIndexNotApplicable: mysql.Message("Cannot use expression index '%s' due to type or collation conversion", nil), + ErrNonBooleanExprForCheckConstraint: mysql.Message("An expression of non-boolean type specified to a check constraint '%s'.", nil), + ErrColumnCheckConstraintReferencesOtherColumn: mysql.Message("Column check constraint '%s' references other column.", nil), + ErrCheckConstraintNamedFunctionIsNotAllowed: mysql.Message("An expression of a check constraint '%s' contains disallowed function: %s.", nil), + ErrCheckConstraintFunctionIsNotAllowed: mysql.Message("An expression of a check constraint '%s' contains disallowed function.", nil), + ErrCheckConstraintVariables: mysql.Message("An expression of a check constraint '%s' cannot refer to a user or system variable.", nil), + ErrCheckConstraintRowValue: mysql.Message("Check constraint '%s' cannot refer to a row value.", nil), + ErrCheckConstraintRefersAutoIncrementColumn: mysql.Message("Check constraint '%s' cannot refer to an auto-increment column.", nil), + ErrCheckConstraintViolated: mysql.Message("Check constraint '%s' is violated.", nil), + ErrTableCheckConstraintReferUnknown: mysql.Message("Check constraint '%s' refers to non-existing column '%s'.", nil), + ErrCheckConstraintNotFound: mysql.Message("Check constraint '%s' is not found in the table.", nil), + ErrCheckConstraintDupName: mysql.Message("Duplicate check constraint name '%s'.", nil), + ErrCheckConstraintClauseUsingFKReferActionColumn: mysql.Message("Column '%s' cannot be used in a check constraint '%s': needed in a foreign key constraint '%s' referential action.", nil), ErrUnsupportedConstraintCheck: mysql.Message("%s is not supported", nil), ErrDynamicPrivilegeNotRegistered: mysql.Message("Dynamic privilege '%s' is not registered with the server.", nil), + ErrConstraintNotFound: mysql.Message("Constraint '%s' does not exist.", nil), ErrIllegalPrivilegeLevel: mysql.Message("Illegal privilege level specified for %s", nil), ErrCTERecursiveRequiresUnion: mysql.Message("Recursive Common Table Expression '%s' should contain a UNION", nil), ErrCTERecursiveRequiresNonRecursiveFirst: mysql.Message("Recursive Common Table Expression '%s' should have one or more non-recursive query blocks followed by one or more recursive ones", nil), diff --git a/errors.toml b/errors.toml index 36865561a3c1e..1ccc91ecd4cff 100644 --- a/errors.toml +++ b/errors.toml @@ -1096,6 +1096,11 @@ error = ''' JSON column '%-.192s' cannot be used in key specification. ''' +["ddl:3184"] +error = ''' +Invalid encryption option. +''' + ["ddl:3505"] error = ''' Too long enumeration/set value for column %s. @@ -1171,11 +1176,31 @@ error = ''' Expression of expression index '%s' cannot refer to a row value ''' +["ddl:3813"] +error = ''' +Column check constraint '%s' references other column. +''' + +["ddl:3819"] +error = ''' +Check constraint '%s' is violated. +''' + +["ddl:3820"] +error = ''' +Check constraint '%s' refers to non-existing column '%s'. +''' + ["ddl:3837"] error = ''' Column '%s' has an expression index dependency and cannot be dropped or renamed ''' +["ddl:3940"] +error = ''' +Constraint '%s' does not exist. +''' + ["ddl:4135"] error = ''' Sequence '%-.64s.%-.64s' has run out @@ -2451,6 +2476,11 @@ error = ''' Unable to create or change a table without a primary key, when the system variable 'sql_require_primary_key' is set. Add a primary key to the table or unset this variable to avoid this message. Note that tables without a primary key can cause performance problems in row-based replication, so please consult your DBA before changing this setting. ''' +["schema:3822"] +error = ''' +Duplicate check constraint name '%s'. +''' + ["schema:4139"] error = ''' Unknown SEQUENCE: '%-.300s' @@ -2556,6 +2586,11 @@ error = ''' Found a row not matching the given partition set ''' +["table:3819"] +error = ''' +Check constraint '%s' is violated. +''' + ["table:4135"] error = ''' Sequence '%-.64s.%-.64s' has run out diff --git a/executor/show.go b/executor/show.go index 2206610cfc517..0ae2b9b15a1e7 100644 --- a/executor/show.go +++ b/executor/show.go @@ -1150,6 +1150,25 @@ func ConstructResultOfShowCreateTable(ctx sessionctx.Context, tableInfo *model.T } } + publicConstraints := make([]*model.ConstraintInfo, 0, len(tableInfo.Indices)) + for _, constr := range tableInfo.Constraints { + if constr.State == model.StatePublic { + publicConstraints = append(publicConstraints, constr) + } + } + if len(publicConstraints) > 0 { + buf.WriteString(",\n") + } + for i, constrInfo := range publicConstraints { + fmt.Fprintf(buf, "CONSTRAINT %s CHECK ((%s))", stringutil.Escape(constrInfo.Name.O, sqlMode), constrInfo.ExprString) + if !constrInfo.Enforced { + buf.WriteString(" /*!80016 NOT ENFORCED */") + } + if i != len(publicConstraints)-1 { + buf.WriteString(",\n") + } + } + buf.WriteString("\n") buf.WriteString(") ENGINE=InnoDB") diff --git a/executor/showtest/show_test.go b/executor/showtest/show_test.go index 11b2db520d723..267c40daecadf 100644 --- a/executor/showtest/show_test.go +++ b/executor/showtest/show_test.go @@ -1554,6 +1554,40 @@ func TestShowClusterConfig(t *testing.T) { require.EqualError(t, tk.QueryToErr("show config"), confErr.Error()) } +func TestShowCheckConstraint(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists t") + // Create table with check constraint + tk.MustExec("create table t(a int check (a>1), b int, constraint my_constr check(a<10))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` int(11) DEFAULT NULL,\n" + + "CONSTRAINT `my_constr` CHECK ((`a` < 10)),\n" + + "CONSTRAINT `t_chk_1` CHECK ((`a` > 1))\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + // Alter table add constraint. + tk.MustExec("alter table t add constraint my_constr2 check (a 1)),\n" + + "CONSTRAINT `my_constr2` CHECK ((`a` < `b`)) /*!80016 NOT ENFORCED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + // Alter table drop constraint. + tk.MustExec("alter table t drop constraint t_chk_1") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` int(11) DEFAULT NULL,\n" + + "CONSTRAINT `my_constr` CHECK ((`a` < 10)),\n" + + "CONSTRAINT `my_constr2` CHECK ((`a` < `b`)) /*!80016 NOT ENFORCED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustExec("drop table if exists t") +} + func TestInvisibleCoprCacheConfig(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/infoschema/error.go b/infoschema/error.go index a7e4929a35bcc..f2ec28f6facd2 100644 --- a/infoschema/error.go +++ b/infoschema/error.go @@ -50,6 +50,8 @@ var ( ErrNonuniqTable = dbterror.ClassSchema.NewStd(mysql.ErrNonuniqTable) // ErrMultiplePriKey returns for multiple primary keys. ErrMultiplePriKey = dbterror.ClassSchema.NewStd(mysql.ErrMultiplePriKey) + // ErrCheckConstraintDupName returns for duplicate constraint names. + ErrCheckConstraintDupName = dbterror.ClassSchema.NewStd(mysql.ErrCheckConstraintDupName) // ErrTooManyKeyParts returns for too many key parts. ErrTooManyKeyParts = dbterror.ClassSchema.NewStd(mysql.ErrTooManyKeyParts) // ErrForeignKeyNotExists returns for foreign key not exists. diff --git a/parser/model/ddl.go b/parser/model/ddl.go index c9b36a9e9ef3a..52e1534a503dd 100644 --- a/parser/model/ddl.go +++ b/parser/model/ddl.go @@ -828,7 +828,7 @@ func (job *Job) IsRollbackable() bool { ActionTruncateTable, ActionAddForeignKey, ActionRenameTable, ActionModifyTableCharsetAndCollate, ActionTruncateTablePartition, ActionModifySchemaCharsetAndCollate, ActionRepairTable, - ActionModifyTableAutoIdCache, ActionModifySchemaDefaultPlacement: + ActionModifyTableAutoIdCache, ActionModifySchemaDefaultPlacement, ActionDropCheckConstraint: return job.SchemaState == StateNone case ActionMultiSchemaChange: return job.MultiSchemaInfo.Revertible diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index c5a2a9999db90..f8a186c12d0f7 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -682,6 +682,7 @@ func isConstraintKeyTp(constraints []*ast.Constraint, colDef *ast.ColumnDef) boo for _, c := range constraints { // ignore constraint check if c.Tp == ast.ConstraintCheck { + // TODO continue } if c.Keys[0].Expr != nil { diff --git a/table/BUILD.bazel b/table/BUILD.bazel index a1f2feab60722..577def58a0af5 100644 --- a/table/BUILD.bazel +++ b/table/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "table", srcs = [ "column.go", + "constraint.go", "index.go", "table.go", ], @@ -27,6 +28,7 @@ go_library( "//util/dbterror", "//util/hack", "//util/logutil", + "//util/mock", "//util/sqlexec", "//util/timeutil", "@com_github_opentracing_opentracing_go//:opentracing-go", diff --git a/table/constraint.go b/table/constraint.go new file mode 100644 index 0000000000000..72d530af7bbd7 --- /dev/null +++ b/table/constraint.go @@ -0,0 +1,68 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package table + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/mock" + "go.uber.org/zap" +) + +// Constraint provides meta and map dependency describing a table constraint. +type Constraint struct { + *model.ConstraintInfo + + ConstraintExpr expression.Expression +} + +// ToConstraint converts model.ConstraintInfo to Constraint +func ToConstraint(constraintInfo *model.ConstraintInfo, tblInfo *model.TableInfo) (*Constraint, error) { + ctx := mock.NewContext() + dbName := model.NewCIStr(ctx.GetSessionVars().CurrentDB) + columns, names, err := expression.ColumnInfos2ColumnsAndNames(ctx, dbName, tblInfo.Name, tblInfo.Columns, tblInfo) + if err != nil { + return nil, errors.Trace(err) + } + expr, err := buildConstraintExpression(ctx, constraintInfo.ExprString, columns, names) + if err != nil { + return nil, errors.Trace(err) + } + return &Constraint{ + constraintInfo, + expr, + }, nil +} + +func buildConstraintExpression(ctx sessionctx.Context, exprString string, + columns []*expression.Column, names types.NameSlice) (expression.Expression, error) { + schema := expression.NewSchema(columns...) + exprs, err := expression.ParseSimpleExprsWithNames(ctx, exprString, schema, names) + if err != nil { + // If it got an error here, ddl may hang forever, so this error log is important. + logutil.BgLogger().Error("wrong check constraint expression", zap.String("expression", exprString), zap.Error(err)) + return nil, errors.Trace(err) + } + return exprs[0], nil +} + +// ToInfo get the ConstraintInfo of the Constraint +func (c *Constraint) ToInfo() *model.ConstraintInfo { + return c.ConstraintInfo +} diff --git a/table/table.go b/table/table.go index 813131df90896..39e23c09793fa 100644 --- a/table/table.go +++ b/table/table.go @@ -104,6 +104,8 @@ var ( ErrTempTableFull = dbterror.ClassTable.NewStd(mysql.ErrRecordFileFull) // ErrOptOnCacheTable returns when exec unsupported opt at cache mode ErrOptOnCacheTable = dbterror.ClassDDL.NewStd(mysql.ErrOptOnCacheTable) + // ErrCheckConstraintViolated return when check constraint is violated. + ErrCheckConstraintViolated = dbterror.ClassTable.NewStd(mysql.ErrCheckConstraintViolated) ) // RecordIterFunc is used for low-level record iteration. diff --git a/table/tables/partition.go b/table/tables/partition.go index 6a0b315b856e9..196a78d38b195 100644 --- a/table/tables/partition.go +++ b/table/tables/partition.go @@ -104,7 +104,7 @@ func newPartitionedTable(tbl *TableCommon, tblInfo *model.TableInfo) (table.Tabl partitions := make(map[int64]*partition, len(pi.Definitions)) for _, p := range pi.Definitions { var t partition - err := initTableCommonWithIndices(&t.TableCommon, tblInfo, p.ID, tbl.Columns, tbl.allocs) + err := initTableCommonWithIndices(&t.TableCommon, tblInfo, p.ID, tbl.Columns, tbl.allocs, tbl.Constraints) if err != nil { return nil, errors.Trace(err) } diff --git a/table/tables/tables.go b/table/tables/tables.go index da20b1647fbd8..a4616a35d985c 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -44,6 +44,7 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/generatedexpr" @@ -70,6 +71,8 @@ type TableCommon struct { meta *model.TableInfo allocs autoid.Allocators sequence *sequenceCommon + Constraints []*table.Constraint + WritableConstraints []*table.Constraint // recordPrefix and indexPrefix are generated using physicalTableID. recordPrefix kv.Key @@ -84,8 +87,17 @@ func MockTableFromMeta(tblInfo *model.TableInfo) table.Table { columns = append(columns, col) } + constraints := make([]*table.Constraint, 0, len(tblInfo.Constraints)) + for _, constraintInfo := range tblInfo.Constraints { + constraint, err := table.ToConstraint(constraintInfo, tblInfo) + if err != nil { + return nil + } + constraints = append(constraints, constraint) + } + var t TableCommon - initTableCommon(&t, tblInfo, tblInfo.ID, columns, autoid.NewAllocators(false)) + initTableCommon(&t, tblInfo, tblInfo.ID, columns, autoid.NewAllocators(false), constraints) if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { ret, err := newCachedTable(&t) if err != nil { @@ -148,8 +160,17 @@ func TableFromMeta(allocs autoid.Allocators, tblInfo *model.TableInfo) (table.Ta columns = append(columns, col) } + constraints := make([]*table.Constraint, 0, len(tblInfo.Constraints)) + for _, conInfo := range tblInfo.Constraints { + con, err := table.ToConstraint(conInfo, tblInfo) + if err != nil { + return nil, err + } + constraints = append(constraints, con) + } + var t TableCommon - initTableCommon(&t, tblInfo, tblInfo.ID, columns, allocs) + initTableCommon(&t, tblInfo, tblInfo.ID, columns, allocs, constraints) if tblInfo.GetPartitionInfo() == nil { if err := initTableIndices(&t); err != nil { return nil, err @@ -163,16 +184,18 @@ func TableFromMeta(allocs autoid.Allocators, tblInfo *model.TableInfo) (table.Ta } // initTableCommon initializes a TableCommon struct. -func initTableCommon(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators) { +func initTableCommon(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators, constraints []*table.Constraint) { t.tableID = tblInfo.ID t.physicalTableID = physicalTableID t.allocs = allocs t.meta = tblInfo t.Columns = cols + t.Constraints = constraints t.PublicColumns = t.Cols() t.VisibleColumns = t.VisibleCols() t.HiddenColumns = t.HiddenCols() t.WritableColumns = t.WritableCols() + t.WritableConstraints = t.WritableConstraint() t.FullHiddenColsAndVisibleColumns = t.FullHiddenColsAndVisibleCols() t.recordPrefix = tablecodec.GenTableRecordPrefix(physicalTableID) t.indexPrefix = tablecodec.GenTableIndexPrefix(physicalTableID) @@ -196,8 +219,8 @@ func initTableIndices(t *TableCommon) error { return nil } -func initTableCommonWithIndices(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators) error { - initTableCommon(t, tblInfo, physicalTableID, cols, allocs) +func initTableCommonWithIndices(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators, constraints []*table.Constraint) error { + initTableCommon(t, tblInfo, physicalTableID, cols, allocs, constraints) return initTableIndices(t) } @@ -297,6 +320,24 @@ func (t *TableCommon) WritableCols() []*table.Column { return writableColumns } +// WritableConstraint returns constraints of the table in writable states. +func (t *TableCommon) WritableConstraint() []*table.Constraint { + if len(t.WritableConstraints) > 0 { + return t.WritableConstraints + } + if t.Constraints == nil { + return nil + } + writeableConstraint := make([]*table.Constraint, 0, len(t.Constraints)) + for _, con := range t.Constraints { + if con.State == model.StateDeleteOnly || con.State == model.StateDeleteReorganization { + continue + } + writeableConstraint = append(writeableConstraint, con) + } + return writeableConstraint +} + // DeletableCols implements table DeletableCols interface. func (t *TableCommon) DeletableCols() []*table.Column { return t.Columns @@ -364,6 +405,7 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx sessionctx.Context, binlogOldRow = make([]types.Datum, 0, numColsCap) binlogNewRow = make([]types.Datum, 0, numColsCap) } + rowToCheck := make([]types.Datum, 0, numColsCap) for _, col := range t.Columns { var value types.Datum @@ -402,12 +444,24 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx sessionctx.Context, colIDs = append(colIDs, col.ID) row = append(row, value) } + rowToCheck = append(rowToCheck, value) if shouldWriteBinlog(sctx, t.meta) && !t.canSkipUpdateBinlog(col, value) { binlogColIDs = append(binlogColIDs, col.ID) binlogOldRow = append(binlogOldRow, oldData[col.Offset]) binlogNewRow = append(binlogNewRow, value) } } + + for _, constraint := range t.WritableConstraint() { + ok, isNull, err := constraint.ConstraintExpr.EvalInt(sctx, chunk.MutRowFromDatums(rowToCheck).ToRow()) + if err != nil { + return err + } + if ok == 0 && !isNull { + return table.ErrCheckConstraintViolated.FastGenByArgs(constraint.Name.O) + } + } + sessVars := sctx.GetSessionVars() // rebuild index if !sessVars.InTxn() { @@ -763,6 +817,7 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . var colIDs, binlogColIDs []int64 var row, binlogRow []types.Datum + rowToCheck := make([]types.Datum, 0, len(r)) if recordCtx, ok := sctx.Value(addRecordCtxKey).(*CommonAddRecordCtx); ok { colIDs = recordCtx.colIDs[:0] row = recordCtx.row[:0] @@ -814,12 +869,23 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . } else { value = r[col.Offset] } + rowToCheck = append(rowToCheck, value) if !t.canSkip(col, &value) { colIDs = append(colIDs, col.ID) row = append(row, value) } } + for _, constraint := range t.WritableConstraint() { + ok, isNull, err := constraint.ConstraintExpr.EvalInt(sctx, chunk.MutRowFromDatums(r).ToRow()) + if err != nil { + return nil, err + } + if ok == 0 && !isNull { + return nil, table.ErrCheckConstraintViolated.FastGenByArgs(constraint.Name.O) + } + } + writeBufs := sessVars.GetWriteStmtBufs() adjustRowValuesBuf(writeBufs, len(row)) key := t.RecordKey(recordID) diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 661770b868383..5357078ec2e2a 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -600,6 +600,73 @@ func TestHiddenColumn(t *testing.T) { "f|tinyint(4)|YES|||VIRTUAL GENERATED")) } +func TestCheckConstraintOnInsert(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("DROP DATABASE IF EXISTS test_insert_check_constraint;") + tk.MustExec("CREATE DATABASE test_insert_check_constraint;") + tk.MustExec("USE test_insert_check_constraint;") + tk.MustExec("CREATE TABLE t1 (CHECK (c1 <> c2), c1 INT CHECK (c1 > 10), c2 INT CONSTRAINT c2_positive CHECK (c2 > 0));") + tk.MustGetErrMsg("insert into t1 values (2, 2)", "[table:3819]Check constraint 't1_chk_1' is violated.") + tk.MustGetErrMsg("insert into t1 values (9, 2)", "[table:3819]Check constraint 't1_chk_2' is violated.") + tk.MustGetErrMsg("insert into t1 values (14, -4)", "[table:3819]Check constraint 'c2_positive' is violated.") + tk.MustGetErrMsg("insert into t1(c1) values (9)", "[table:3819]Check constraint 't1_chk_2' is violated.") + tk.MustGetErrMsg("insert into t1(c2) values (-3)", "[table:3819]Check constraint 'c2_positive' is violated.") + tk.MustExec("insert into t1 values (14, 4)") + tk.MustExec("insert into t1 values (null, 4)") + tk.MustExec("insert into t1 values (13, null)") + tk.MustExec("insert into t1 values (null, null)") + tk.MustExec("insert into t1(c1) values (null)") + tk.MustExec("insert into t1(c2) values (null)") + + // Test generated column with check constraint. + tk.MustExec("CREATE TABLE t2 (CHECK (c1 <> c2), c1 INT CHECK (c1 > 10), c2 INT CONSTRAINT c2_positive CHECK (c2 > 0), c3 int as (c1 + c2) check(c3 > 15));") + tk.MustGetErrMsg("insert into t2(c1, c2) values (11, 1)", "[table:3819]Check constraint 't2_chk_3' is violated.") + tk.MustExec("insert into t2(c1, c2) values (12, 7)") +} + +func TestCheckConstraintOnUpdate(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("DROP DATABASE IF EXISTS test_update_check_constraint;") + tk.MustExec("CREATE DATABASE test_update_check_constraint;") + tk.MustExec("USE test_update_check_constraint;") + + tk.MustExec("CREATE TABLE t1 (CHECK (c1 <> c2), c1 INT CHECK (c1 > 10), c2 INT CONSTRAINT c2_positive CHECK (c2 > 0));") + tk.MustExec("insert into t1 values (11, 12), (12, 13), (13, 14), (14, 15), (15, 16);") + tk.MustGetErrMsg("update t1 set c2 = -c2;", "[table:3819]Check constraint 'c2_positive' is violated.") + tk.MustGetErrMsg("update t1 set c2 = c1;", "[table:3819]Check constraint 't1_chk_1' is violated.") + tk.MustGetErrMsg("update t1 set c1 = c1 - 10;", "[table:3819]Check constraint 't1_chk_2' is violated.") + tk.MustGetErrMsg("update t1 set c2 = -10 where c2 = 12;", "[table:3819]Check constraint 'c2_positive' is violated.") + + // Test generated column with check constraint. + tk.MustExec("CREATE TABLE t2 (CHECK (c1 <> c2), c1 INT CHECK (c1 > 10), c2 INT CONSTRAINT c2_positive CHECK (c2 > 0), c3 int as (c1 + c2) check(c3 > 15));") + tk.MustExec("insert into t2(c1, c2) values (11, 12), (12, 13), (13, 14), (14, 15), (15, 16);") + tk.MustGetErrMsg("update t2 set c2 = c2 - 10;", "[table:3819]Check constraint 't2_chk_3' is violated.") + tk.MustExec("update t2 set c2 = c2 - 5;") +} + +func TestCheckConstraintOnUpdateWithPartition(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("DROP DATABASE IF EXISTS test_update_check_constraint_hash;") + tk.MustExec("CREATE DATABASE test_update_check_constraint_hash;") + tk.MustExec("USE test_update_check_constraint_hash;") + + tk.MustExec("CREATE TABLE t1 (CHECK (c1 <> c2), c1 INT CHECK (c1 > 10), c2 INT CONSTRAINT c2_positive CHECK (c2 > 0)) partition by hash(c2) partitions 5;") + tk.MustExec("insert into t1 values (11, 12), (12, 13), (13, 14), (14, 15), (15, 16);") + tk.MustGetErrMsg("update t1 set c2 = -c2;", "[table:3819]Check constraint 'c2_positive' is violated.") + tk.MustGetErrMsg("update t1 set c2 = c1;", "[table:3819]Check constraint 't1_chk_1' is violated.") + tk.MustGetErrMsg("update t1 set c1 = c1 - 10;", "[table:3819]Check constraint 't1_chk_2' is violated.") + tk.MustGetErrMsg("update t1 set c2 = -10 where c2 = 12;", "[table:3819]Check constraint 'c2_positive' is violated.") + + // Test generated column with check constraint. + tk.MustExec("CREATE TABLE t2 (CHECK (c1 <> c2), c1 INT CHECK (c1 > 10), c2 INT CONSTRAINT c2_positive CHECK (c2 > 0), c3 int as (c1 + c2) check(c3 > 15)) partition by hash(c2) partitions 5;") + tk.MustExec("insert into t2(c1, c2) values (11, 12), (12, 13), (13, 14), (14, 15), (15, 16);") + tk.MustGetErrMsg("update t2 set c2 = c2 - 10;", "[table:3819]Check constraint 't2_chk_3' is violated.") + tk.MustExec("update t2 set c2 = c2 - 5;") +} + func TestAddRecordWithCtx(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) diff --git a/util/dbterror/ddl_terror.go b/util/dbterror/ddl_terror.go index 1db0630977945..ecc7a9482018f 100644 --- a/util/dbterror/ddl_terror.go +++ b/util/dbterror/ddl_terror.go @@ -298,6 +298,14 @@ var ( ErrAddColumnWithSequenceAsDefault = ClassDDL.NewStd(mysql.ErrAddColumnWithSequenceAsDefault) // ErrUnsupportedExpressionIndex is returned when create an expression index without allow-expression-index. ErrUnsupportedExpressionIndex = ClassDDL.NewStdErr(mysql.ErrUnsupportedDDLOperation, parser_mysql.Message(fmt.Sprintf(mysql.MySQLErrName[mysql.ErrUnsupportedDDLOperation].Raw, "creating expression index containing unsafe functions without allow-expression-index in config"), nil)) + // ErrColumnCheckConstraintReferOther is returned when create column check constraint referring other column. + ErrColumnCheckConstraintReferOther = ClassDDL.NewStd(mysql.ErrColumnCheckConstraintReferencesOtherColumn) + // ErrTableCheckConstraintReferUnknown is returned when create table check constraint referring non-existing column. + ErrTableCheckConstraintReferUnknown = ClassDDL.NewStd(mysql.ErrTableCheckConstraintReferUnknown) + // ErrConstraintNotFound is returned for dropping a non-existent constraint. + ErrConstraintNotFound = ClassDDL.NewStd(mysql.ErrConstraintNotFound) + // ErrCheckConstraintIsViolated is returned for violating an existent check constraint. + ErrCheckConstraintIsViolated = ClassDDL.NewStd(mysql.ErrCheckConstraintViolated) // ErrPartitionExchangePartTable is returned when exchange table partition with another table is partitioned. ErrPartitionExchangePartTable = ClassDDL.NewStd(mysql.ErrPartitionExchangePartTable) // ErrPartitionExchangeTempTable is returned when exchange table partition with a temporary table