diff --git a/ddl/column.go b/ddl/column.go index e5b342b0236e9..078fa8415b419 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -588,9 +588,9 @@ func generateOriginDefaultValue(col *model.ColumnInfo) (interface{}, error) { return odValue, nil } -func findColumnInIndexCols(c *model.ColumnInfo, cols []*ast.IndexColName) bool { +func findColumnInIndexCols(c string, cols []*ast.IndexColName) bool { for _, c1 := range cols { - if c.Name.L == c1.Column.Name.L { + if c == c1.Column.Name.L { return true } } diff --git a/ddl/db_partition_test.go b/ddl/db_partition_test.go index e2810a620cfcb..08a9cb823f156 100644 --- a/ddl/db_partition_test.go +++ b/ddl/db_partition_test.go @@ -1080,6 +1080,38 @@ func (s *testIntegrationSuite5) TestPartitionUniqueKeyNeedAllFieldsInPf(c *C) { partition p2 values less than (15) )` assertErrorCode(c, tk, sql9, tmysql.ErrUniqueKeyNeedAllFieldsInPf) + + sql10 := `create table part8 ( + a int not null, + b int not null, + c int default null, + d int default null, + e int default null, + primary key (a, b), + unique key (c, d) + ) + partition by range columns (b) ( + partition p0 values less than (4), + partition p1 values less than (7), + partition p2 values less than (11) + )` + assertErrorCode(c, tk, sql10, tmysql.ErrUniqueKeyNeedAllFieldsInPf) + + sql11 := `create table part9 ( + a int not null, + b int not null, + c int default null, + d int default null, + e int default null, + primary key (a, b), + unique key (b, c, d) + ) + partition by range columns (b, c) ( + partition p0 values less than (4, 5), + partition p1 values less than (7, 9), + partition p2 values less than (11, 22) + )` + assertErrorCode(c, tk, sql11, tmysql.ErrUniqueKeyNeedAllFieldsInPf) } func (s *testIntegrationSuite3) TestPartitionDropIndex(c *C) { diff --git a/ddl/partition.go b/ddl/partition.go index 16bbb8fb05b6d..cea935091c350 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -493,17 +493,22 @@ func getPartitionIDs(table *model.TableInfo) []int64 { // checkRangePartitioningKeysConstraints checks that the range partitioning key is included in the table constraint. func checkRangePartitioningKeysConstraints(sctx sessionctx.Context, s *ast.CreateTableStmt, tblInfo *model.TableInfo, constraints []*ast.Constraint) error { // Returns directly if there is no constraint in the partition table. - // TODO: Remove the test 's.Partition.Expr == nil' when we support 'PARTITION BY RANGE COLUMNS' - if len(constraints) == 0 || s.Partition.Expr == nil { + if len(constraints) == 0 { return nil } - // Parse partitioning key, extract the column names in the partitioning key to slice. - buf := new(bytes.Buffer) - s.Partition.Expr.Format(buf) - partCols, err := extractPartitionColumns(buf.String(), tblInfo) - if err != nil { - return err + var partCols stringSlice + if s.Partition.Expr != nil { + // Parse partitioning key, extract the column names in the partitioning key to slice. + buf := new(bytes.Buffer) + s.Partition.Expr.Format(buf) + partColumns, err := extractPartitionColumns(buf.String(), tblInfo) + if err != nil { + return err + } + partCols = columnInfoSlice(partColumns) + } else if len(s.Partition.ColumnNames) > 0 { + partCols = columnNameSlice(s.Partition.ColumnNames) } // Checks that the partitioning key is included in the constraint. @@ -549,7 +554,7 @@ func checkPartitionKeysConstraint(pi *model.PartitionInfo, idxColNames []*ast.In // Every unique key on the table must use every column in the table's partitioning expression. // See https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations-partitioning-keys-unique-keys.html - if !checkUniqueKeyIncludePartKey(partCols, idxColNames) { + if !checkUniqueKeyIncludePartKey(columnInfoSlice(partCols), idxColNames) { return ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("UNIQUE INDEX") } return nil @@ -596,9 +601,17 @@ func extractPartitionColumns(partExpr string, tblInfo *model.TableInfo) ([]*mode return extractor.extractedColumns, nil } +// stringSlice is defined for checkUniqueKeyIncludePartKey. +// if Go supports covariance, the code shouldn't be so complex. +type stringSlice interface { + Len() int + At(i int) string +} + // checkUniqueKeyIncludePartKey checks that the partitioning key is included in the constraint. -func checkUniqueKeyIncludePartKey(partCols []*model.ColumnInfo, idxCols []*ast.IndexColName) bool { - for _, partCol := range partCols { +func checkUniqueKeyIncludePartKey(partCols stringSlice, idxCols []*ast.IndexColName) bool { + for i := 0; i < partCols.Len(); i++ { + partCol := partCols.At(i) if !findColumnInIndexCols(partCol, idxCols) { return false } @@ -606,6 +619,28 @@ func checkUniqueKeyIncludePartKey(partCols []*model.ColumnInfo, idxCols []*ast.I return true } +// columnInfoSlice implements the stringSlice interface. +type columnInfoSlice []*model.ColumnInfo + +func (cis columnInfoSlice) Len() int { + return len(cis) +} + +func (cis columnInfoSlice) At(i int) string { + return cis[i].Name.L +} + +// columnNameSlice implements the stringSlice interface. +type columnNameSlice []*ast.ColumnName + +func (cns columnNameSlice) Len() int { + return len(cns) +} + +func (cns columnNameSlice) At(i int) string { + return cns[i].Name.L +} + // isRangePartitionColUnsignedBigint returns true if the partitioning key column type is unsigned bigint type. func isRangePartitionColUnsignedBigint(cols []*table.Column, pi *model.PartitionInfo) bool { for _, col := range cols {