diff --git a/executor/executor_test.go b/executor/executor_test.go index 29d306a06280e..5c6a53005713b 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -77,6 +77,7 @@ func TestT(t *testing.T) { var _ = Suite(&testSuite{}) var _ = Suite(&testContextOptionSuite{}) var _ = Suite(&testBypassSuite{}) +var _ = Suite(&testUpdateSuite{}) type testSuite struct { cluster *mocktikv.Cluster diff --git a/executor/update_test.go b/executor/update_test.go new file mode 100644 index 0000000000000..fda7bc0c35c07 --- /dev/null +++ b/executor/update_test.go @@ -0,0 +1,92 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + "flag" + "fmt" + + . "github.com/pingcap/check" + "github.com/pingcap/parser" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/store/mockstore/mocktikv" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +type testUpdateSuite struct { + cluster *mocktikv.Cluster + mvccStore mocktikv.MVCCStore + store kv.Storage + domain *domain.Domain + *parser.Parser + ctx *mock.Context +} + +func (s *testUpdateSuite) SetUpSuite(c *C) { + testleak.BeforeTest() + s.Parser = parser.New() + flag.Lookup("mockTikv") + useMockTikv := *mockTikv + if useMockTikv { + s.cluster = mocktikv.NewCluster() + mocktikv.BootstrapWithSingleStore(s.cluster) + s.mvccStore = mocktikv.MustNewMVCCStore() + store, err := mockstore.NewMockTikvStore( + mockstore.WithCluster(s.cluster), + mockstore.WithMVCCStore(s.mvccStore), + ) + c.Assert(err, IsNil) + s.store = store + session.SetSchemaLease(0) + session.SetStatsLease(0) + } + d, err := session.BootstrapSession(s.store) + c.Assert(err, IsNil) + d.SetStatsUpdating(true) + s.domain = d +} + +func (s *testUpdateSuite) TearDownSuite(c *C) { + s.domain.Close() + s.store.Close() + testleak.AfterTest(c, TestLeakCheckCnt)() +} + +func (s *testUpdateSuite) TearDownTest(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + r := tk.MustQuery("show tables") + for _, tb := range r.Rows() { + tableName := tb[0] + tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + } +} + +func (s *testUpdateSuite) TestUpdateGenColInTxn(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`create table t(a bigint, b bigint as (a+1));`) + tk.MustExec(`begin;`) + tk.MustExec(`insert into t(a) values(1);`) + err := tk.ExecToErr(`update t set b=6 where b=2;`) + c.Assert(err.Error(), Equals, "[planner:3105]The value specified for generated column 'b' in table 't' is not allowed.") + tk.MustExec(`commit;`) + tk.MustQuery(`select * from t;`).Check(testkit.Rows( + `1 2`)) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index e04f60c86204b..c47f6380c0fc7 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2214,15 +2214,27 @@ func extractTableAsNameForUpdate(p LogicalPlan, asNames map[*model.TableInfo][]* asNames[x.tableInfo] = append(asNames[x.tableInfo], alias) } case *LogicalProjection: - if x.calculateGenCols { - ds := x.Children()[0].(*DataSource) - alias := extractTableAlias(x) - if alias != nil { - if _, ok := asNames[ds.tableInfo]; !ok { - asNames[ds.tableInfo] = make([]*model.CIStr, 0, 1) - } - asNames[ds.tableInfo] = append(asNames[ds.tableInfo], alias) + if !x.calculateGenCols { + return + } + + ds, isDS := x.Children()[0].(*DataSource) + if !isDS { + // try to extract the DataSource below a LogicalUnionScan. + if us, isUS := x.Children()[0].(*LogicalUnionScan); isUS { + ds, isDS = us.Children()[0].(*DataSource) + } + } + if !isDS { + return + } + + alias := extractTableAlias(x) + if alias != nil { + if _, ok := asNames[ds.tableInfo]; !ok { + asNames[ds.tableInfo] = make([]*model.CIStr, 0, 1) } + asNames[ds.tableInfo] = append(asNames[ds.tableInfo], alias) } default: for _, child := range p.Children() {