Skip to content

Commit

Permalink
ddl: fix the issue when the set type default value is int type
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala committed Sep 18, 2019
1 parent c5cad51 commit 3c0ed92
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
40 changes: 40 additions & 0 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,46 @@ func (s *testDBSuite1) TestCreateTable(c *C) {
c.Assert(err.Error(), Equals, "[types:1291]Column 'a' has duplicated value 'B' in ENUM")
}

func (s *testDBSuite2) TestCreateTableWithSetCol(c *C) {
s.tk = testkit.NewTestKitWithInit(c, s.store)
s.tk.MustExec("create table t_set (a int, b set('e') default '');")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` set('e') DEFAULT ''\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

// The type of default value is int.
// for failure cases
s.tk.MustExec("drop table t_set")
failedSQL := "create table t_set (a set('1', '4', '10') default 0);"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default 8);"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)

// for successful cases
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 2);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '4'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 3);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1,4'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 15);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1,4,10,21'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("insert into t_set value()")
s.tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21"))
}

func (s *testDBSuite2) TestTableForeignKey(c *C) {
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use test")
Expand Down
25 changes: 19 additions & 6 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,8 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
return col, constraints, nil
}

func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption, t *types.FieldType) (interface{}, error) {
tp, fsp := t.Tp, t.Decimal
func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) {
tp, fsp := col.FieldType.Tp, col.FieldType.Decimal
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
switch x := c.Expr.(type) {
case *ast.FuncCallExpr:
Expand All @@ -633,14 +633,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption
}
}
if defaultFsp != fsp {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName)
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
}
}
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, int8(fsp))
value := vd.GetValue()
if err != nil {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName)
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}

// Value is nil means `default null`.
Expand Down Expand Up @@ -680,10 +680,23 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption
}
return strconv.FormatUint(value, 10), nil
}
if tp == mysql.TypeSet && v.Kind() == types.KindInt64 {
setCnt := len(col.Elems)
maxLimit := int64(1<<uint(setCnt) - 1)
val := v.GetInt64()
if val < 1 || val > maxLimit {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
setVal, err := types.ParseSetValue(col.Elems, uint64(val))
if err != nil {
return "", errors.Trace(err)
}
v.SetMysqlSet(setVal)
}

if tp == mysql.TypeDuration {
var err error
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, t); err != nil {
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
return "", errors.Trace(err)
}
}
Expand Down Expand Up @@ -2491,7 +2504,7 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error {

func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) {
hasDefaultValue := false
value, err := getDefaultValue(ctx, col.Name.L, option, &col.FieldType)
value, err := getDefaultValue(ctx, col, option)
if err != nil {
return hasDefaultValue, errors.Trace(err)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ require (
golang.org/x/text v0.3.2
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect
golang.org/x/tools v0.0.0-20190911022129-16c5e0f7d110
google.golang.org/appengine v1.4.0 // indirect
google.golang.org/genproto v0.0.0-20190905072037-92dd089d5514 // indirect
google.golang.org/grpc v1.23.0
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
Expand Down

0 comments on commit 3c0ed92

Please sign in to comment.