diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 92f1bed1c07af..91adaee1c12bb 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -501,7 +501,7 @@ func (s *testStateChangeSuite) TestWriteOnlyWriteNULL(c *C) { addColumnSQL := "alter table t add column c5 int not null default 1 after c4" expectQuery := &expectQuery{"select c4, c5 from t", []string{"8 1"}} // TODO: This case should always fail in write-only state, but it doesn't. We use write-reorganization state here to keep it running stable. It need a double check. - s.runTestInSchemaState(c, model.StateWriteReorganization, "", addColumnSQL, sqls, expectQuery) + s.runTestInSchemaState(c, model.StateWriteReorganization, true, addColumnSQL, sqls, expectQuery) } func (s *testStateChangeSuite) TestWriteOnlyOnDupUpdate(c *C) { @@ -512,7 +512,7 @@ func (s *testStateChangeSuite) TestWriteOnlyOnDupUpdate(c *C) { addColumnSQL := "alter table t add column c5 int not null default 1 after c4" expectQuery := &expectQuery{"select c4, c5 from t", []string{"2 1"}} // TODO: This case should always fail in write-only state, but it doesn't. We use write-reorganization state here to keep it running stable. It need a double check. - s.runTestInSchemaState(c, model.StateWriteReorganization, "", addColumnSQL, sqls, expectQuery) + s.runTestInSchemaState(c, model.StateWriteReorganization, true, addColumnSQL, sqls, expectQuery) } // TestWriteOnly tests whether the correct columns is used in PhysicalIndexScan's ToPB function. @@ -522,7 +522,7 @@ func (s *testStateChangeSuite) TestWriteOnly(c *C) { sqls[1] = sqlWithErr{"update t use index(idx2) set c1 = 'c1_update' where c1 = 'a'", nil} sqls[2] = sqlWithErr{"insert t set c1 = 'c1_insert', c3 = '2018-02-12', c4 = 1", nil} addColumnSQL := "alter table t add column c5 int not null default 1 first" - s.runTestInSchemaState(c, model.StateWriteOnly, "", addColumnSQL, sqls, nil) + s.runTestInSchemaState(c, model.StateWriteOnly, true, addColumnSQL, sqls, nil) } // TestDeletaOnly tests whether the correct columns is used in PhysicalIndexScan's ToPB function. @@ -531,17 +531,37 @@ func (s *testStateChangeSuite) TestDeleteOnly(c *C) { sqls[0] = sqlWithErr{"insert t set c1 = 'c1_insert', c3 = '2018-02-12', c4 = 1", errors.Errorf("Can't find column c1")} dropColumnSQL := "alter table t drop column c1" - s.runTestInSchemaState(c, model.StateDeleteOnly, "", dropColumnSQL, sqls, nil) + s.runTestInSchemaState(c, model.StateDeleteOnly, true, dropColumnSQL, sqls, nil) } -func (s *testStateChangeSuite) runTestInSchemaState(c *C, state model.SchemaState, tableName, alterTableSQL string, +func (s *testStateChangeSuite) TestWriteOnlyForDropColumn(c *C) { + _, err := s.se.Execute(context.Background(), "use test_db_state") + c.Assert(err, IsNil) + _, err = s.se.Execute(context.Background(), `create table tt (c1 int, c4 int)`) + c.Assert(err, IsNil) + _, err = s.se.Execute(context.Background(), "insert into tt (c1, c4) values(8, 8)") + c.Assert(err, IsNil) + defer s.se.Execute(context.Background(), "drop table tt") + + sqls := make([]sqlWithErr, 2) + sqls[0] = sqlWithErr{"update t set c1='5', c3='2020-03-01';", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} + sqls[1] = sqlWithErr{"update t t1, tt t2 set t1.c1='5', t1.c3='2020-03-01', t2.c1='10' where t1.c4=t2.c4", + errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} + // TODO: Fix the case of sqls[2]. + // sqls[2] = sqlWithErr{"update t set c1='5' where c3='2017-07-01';", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} + dropColumnSQL := "alter table t drop column c3" + query := &expectQuery{sql: "select * from t;", rows: []string{"a N 8"}} + s.runTestInSchemaState(c, model.StateWriteOnly, false, dropColumnSQL, sqls, query) +} + +func (s *testStateChangeSuite) runTestInSchemaState(c *C, state model.SchemaState, isOnJobUpdated bool, alterTableSQL string, sqlWithErrs []sqlWithErr, expectQuery *expectQuery) { _, err := s.se.Execute(context.Background(), `create table t ( - c1 varchar(64), - c2 enum('N','Y') not null default 'N', - c3 timestamp on update current_timestamp, - c4 int primary key, - unique key idx2 (c2, c3))`) + c1 varchar(64), + c2 enum('N','Y') not null default 'N', + c3 timestamp on update current_timestamp, + c4 int primary key, + unique key idx2 (c2))`) c.Assert(err, IsNil) defer s.se.Execute(context.Background(), "drop table t") _, err = s.se.Execute(context.Background(), "insert into t values('a', 'N', '2017-07-01', 8)") @@ -558,7 +578,7 @@ func (s *testStateChangeSuite) runTestInSchemaState(c *C, state model.SchemaStat c.Assert(err, IsNil) _, err = se.Execute(context.Background(), "use test_db_state") c.Assert(err, IsNil) - callback.OnJobUpdatedExported = func(job *model.Job) { + cbFunc := func(job *model.Job) { if job.SchemaState == prevState || checkErr != nil || times >= 3 { return } @@ -570,10 +590,18 @@ func (s *testStateChangeSuite) runTestInSchemaState(c *C, state model.SchemaStat _, err = se.Execute(context.Background(), sqlWithErr.sql) if !terror.ErrorEqual(err, sqlWithErr.expectErr) { checkErr = err + if checkErr == nil { + checkErr = errors.New("err can't be nil") + } break } } } + if isOnJobUpdated { + callback.OnJobUpdatedExported = cbFunc + } else { + callback.OnJobRunBeforeExported = cbFunc + } d := s.dom.DDL() originalCallback := d.GetHook() d.(ddl.DDLForTest).SetHook(callback) @@ -664,7 +692,7 @@ func (s *testStateChangeSuite) TestShowIndex(c *C) { c.Assert(err, IsNil) _, err = s.se.Execute(context.Background(), `create table tr( - id int, name varchar(50), + id int, name varchar(50), purchased date ) partition by range( year(purchased) ) ( diff --git a/executor/update.go b/executor/update.go index 11d2b36d9def4..4a63a86e1b493 100644 --- a/executor/update.go +++ b/executor/update.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/sessionctx" + plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -54,7 +54,7 @@ type UpdateExec struct { } func (e *UpdateExec) exec(ctx context.Context, schema *expression.Schema) ([]types.Datum, error) { - assignFlag, err := e.getUpdateColumns(e.ctx, schema.Len()) + assignFlag, err := plannercore.GetUpdateColumns(e.ctx, e.OrderedList, schema.Len()) if err != nil { return nil, err } @@ -267,18 +267,6 @@ func (e *UpdateExec) Open(ctx context.Context) error { return e.SelectExec.Open(ctx) } -func (e *UpdateExec) getUpdateColumns(ctx sessionctx.Context, schemaLen int) ([]bool, error) { - assignFlag := make([]bool, schemaLen) - for _, v := range e.OrderedList { - if !ctx.GetSessionVars().AllowWriteRowID && v.Col.ColName.L == model.ExtraHandleName.L { - return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.") - } - idx := v.Col.Index - assignFlag[idx] = true - } - return assignFlag, nil -} - // setMessage sets info message(ERR_UPDATE_INFO) generated by UPDATE statement func (e *UpdateExec) setMessage() { stmtCtx := e.ctx.GetSessionVars().StmtCtx diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index ef32bb07ac31d..cab17c946c623 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2681,9 +2681,66 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( return nil, err } err = updt.ResolveIndices() + if err != nil { + return nil, err + } + + err = b.checkUpdateList(updt) return updt, err } +// GetUpdateColumns gets the columns of updated lists. +func GetUpdateColumns(ctx sessionctx.Context, orderedList []*expression.Assignment, schemaLen int) ([]bool, error) { + assignFlag := make([]bool, schemaLen) + for _, v := range orderedList { + if !ctx.GetSessionVars().AllowWriteRowID && v.Col.ID == model.ExtraHandleID { + return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.") + } + idx := v.Col.Index + assignFlag[idx] = true + } + return assignFlag, nil +} + +func getTableOffset(schema *expression.Schema, handleCol *expression.Column) (int, error) { + for i, col := range schema.Columns { + if col.DBName.L == handleCol.DBName.L && col.TblName.L == handleCol.TblName.L { + return i, nil + } + } + return -1, errors.Errorf("Couldn't get column information when do update") +} + +func (b *PlanBuilder) checkUpdateList(updt *Update) error { + tblID2table := make(map[int64]table.Table) + for id := range updt.SelectPlan.Schema().TblID2Handle { + tblID2table[id], _ = b.is.TableByID(id) + } + + assignFlags, err := GetUpdateColumns(b.ctx, updt.OrderedList, updt.SelectPlan.Schema().Len()) + if err != nil { + return err + } + schema := updt.SelectPlan.Schema() + for id, cols := range schema.TblID2Handle { + tbl := tblID2table[id] + for _, col := range cols { + offset, err := getTableOffset(schema, col) + if err != nil { + return err + } + end := offset + len(tbl.WritableCols()) + flags := assignFlags[offset:end] + for i, col := range tbl.WritableCols() { + if flags[i] && col.State != model.StatePublic { + return ErrUnknownColumn.GenWithStackByArgs(col.Name, clauseMsg[fieldList]) + } + } + } + } + return nil +} + func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) ([]*expression.Assignment, LogicalPlan, error) { b.curClause = fieldList // modifyColumns indicates which columns are in set list,