Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddl: check column and partition value have same type for range column partition (#12664) #12792

Merged
merged 6 commits into from
Oct 18, 2019
17 changes: 17 additions & 0 deletions ddl/db_partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,13 @@ create table log_message_1 (
"partition p1 values less than (1, 'a'))",
ddl.ErrRangeNotIncreasing,
},
{
"create table t (col datetime not null default '2000-01-01')" +
"partition by range columns (col) (" +
"PARTITION p0 VALUES LESS THAN (20190905)," +
"PARTITION p1 VALUES LESS THAN (20190906));",
ddl.ErrWrongTypeColumnValue,
},
}
for i, t := range cases {
_, err := tk.Exec(t.sql)
Expand Down Expand Up @@ -545,6 +552,16 @@ func (s *testIntegrationSuite5) TestAlterTableAddPartition(c *C) {
sql := "alter table t add partition ( partition p3 values less than ('2019-07-01'));"
assertErrorCode(c, tk, sql, tmysql.ErrRangeNotIncreasing)
tk.MustExec("alter table t add partition ( partition p3 values less than ('2019-08-01'));")

// Add partition value's type should be the same with the column's type.
tk.MustExec("drop table if exists t;")
tk.MustExec(`create table t (
col date not null default '2000-01-01')
partition by range columns (col) (
PARTITION p0 VALUES LESS THAN ('20190905'),
PARTITION p1 VALUES LESS THAN ('20190906'));`)
sql = "alter table t add partition (partition p2 values less than (20190907));"
assertErrorCode(c, tk, sql, tmysql.ErrWrongTypeColumnValue)
}

func (s *testIntegrationSuite6) TestAlterTableDropPartition(c *C) {
Expand Down
4 changes: 4 additions & 0 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ var (
ErrTableCantHandleFt = terror.ClassDDL.New(codeErrTableCantHandleFt, mysql.MySQLErrName[mysql.ErrTableCantHandleFt])
// ErrFieldNotFoundPart returns an error when 'partition by columns' are not found in table columns.
ErrFieldNotFoundPart = terror.ClassDDL.New(codeFieldNotFoundPart, mysql.MySQLErrName[mysql.ErrFieldNotFoundPart])
// ErrWrongTypeColumnValue returns 'Partition column values of incorrect type'
ErrWrongTypeColumnValue = terror.ClassDDL.New(codeWrongTypeColumnValue, mysql.MySQLErrName[mysql.ErrWrongTypeColumnValue])
)

// DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache.
Expand Down Expand Up @@ -751,6 +753,7 @@ const (
codeSubpartition = terror.ErrCode(mysql.ErrSubpartition)
codeSystemVersioningWrongPartitions = terror.ErrCode(mysql.ErrSystemVersioningWrongPartitions)
codeWrongPartitionTypeExpectedSystemTime = terror.ErrCode(mysql.ErrWrongPartitionTypeExpectedSystemTime)
codeWrongTypeColumnValue = terror.ErrCode(mysql.ErrWrongTypeColumnValue)
)

func init() {
Expand Down Expand Up @@ -822,6 +825,7 @@ func init() {
codeSubpartition: mysql.ErrSubpartition,
codeSystemVersioningWrongPartitions: mysql.ErrSystemVersioningWrongPartitions,
codeWrongPartitionTypeExpectedSystemTime: mysql.ErrWrongPartitionTypeExpectedSystemTime,
codeWrongTypeColumnValue: mysql.ErrWrongTypeColumnValue,
}
terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes
}
53 changes: 51 additions & 2 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,15 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *
return err
}

if s != nil {
for _, def := range s.Partition.Definitions {
exprs := def.Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs
if err := checkRangeColumnsTypeAndValuesMatch(ctx, tbInfo, pi.Columns, exprs); err != nil {
return err
}
}
}

return checkRangeColumnsPartitionValue(ctx, tbInfo, pi)
}

Expand Down Expand Up @@ -2178,7 +2187,7 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *
return errors.Trace(ErrPartitionMgmtOnNonpartitioned)
}

partInfo, err := buildPartitionInfo(meta, d, spec)
partInfo, err := buildPartitionInfo(ctx, meta, d, spec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -3328,7 +3337,7 @@ func validateCommentLength(vars *variable.SessionVars, comment string, maxLen in
return comment, nil
}

func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
func buildPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
if meta.Partition.Type == model.PartitionTypeRange {
if len(spec.PartDefinitions) == 0 {
return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type)
Expand All @@ -3355,6 +3364,11 @@ func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec)
}
// For RANGE partition only VALUES LESS THAN should be possible.
clause := def.Clause.(*ast.PartitionDefinitionClauseLessThan)
if len(part.Columns) > 0 {
if err := checkRangeColumnsTypeAndValuesMatch(ctx, meta, part.Columns, clause.Exprs); err != nil {
return nil, err
}
}

comment, _ := def.Comment()
piDef := model.PartitionDefinition{
Expand All @@ -3374,6 +3388,41 @@ func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec)
return part, nil
}

func checkRangeColumnsTypeAndValuesMatch(ctx sessionctx.Context, meta *model.TableInfo, colNames []model.CIStr, exprs []ast.ExprNode) error {
// Validate() has already checked len(colNames) = len(exprs)
// create table ... partition by range columns (cols)
// partition p0 values less than (expr)
// check the type of cols[i] and expr is consistent.
for i, colExpr := range exprs {
if _, ok := colExpr.(*ast.MaxValueExpr); ok {
continue
}

colName := colNames[i]
colInfo := getColumnInfoByName(meta, colName.L)
if colInfo == nil {
return errors.Trace(ErrFieldNotFoundPart)
}
colType := &colInfo.FieldType

val, err := expression.EvalAstExpr(ctx, colExpr)
if err != nil {
return err
}

// Check val.ConvertTo(colType) doesn't work, so we need this case by case check.
switch colType.Tp {
case mysql.TypeDate, mysql.TypeDatetime:
switch val.Kind() {
case types.KindString, types.KindBytes:
default:
return ErrWrongTypeColumnValue.GenWithStackByArgs()
}
}
}
return nil
}

// extractCollateFromOption take collates(may multiple) in option into consideration
// when handle charset and collate of a column, rather than handling it separately.
func extractCollateFromOption(def *ast.ColumnDef) []string {
Expand Down