Skip to content

Commit

Permalink
ddl: check set default value is string type
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala committed Sep 18, 2019
1 parent 3c0ed92 commit 89ac601
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 14 deletions.
21 changes: 17 additions & 4 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1839,16 +1839,29 @@ func (s *testDBSuite2) TestCreateTableWithSetCol(c *C) {
" `a` int(11) DEFAULT NULL,\n" +
" `b` set('e') DEFAULT ''\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('a', 'b', 'c', 'd') default 'a,C,c');")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('a','b','c','d') DEFAULT 'a,c'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

// The type of default value is int.
// for failure cases
// It's for failure cases.
// The type of default value is string.
s.tk.MustExec("drop table t_set")
failedSQL := "create table t_set (a set('1', '4', '10') default 0);"
failedSQL := "create table t_set (a set('1', '4', '10') default '3');"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default '1,4,11');"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default '1 ,4');"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
// The type of default value is int.
failedSQL = "create table t_set (a set('1', '4', '10') default 0);"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default 8);"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)

// for successful cases
// The type of default value is int.
// It's for successful cases
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1'\n" +
Expand Down
58 changes: 48 additions & 10 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,26 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOpt
}
return strconv.FormatUint(value, 10), nil
}
if tp == mysql.TypeSet && v.Kind() == types.KindInt64 {

switch tp {
case mysql.TypeSet:
return setSetDefaultValue(v, col)
case mysql.TypeDuration:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
return "", errors.Trace(err)
}
case mysql.TypeBit:
if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 {
// For BIT fields, convert int into BinaryLiteral.
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil
}
}

return v.ToString()
}

func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) {
if v.Kind() == types.KindInt64 {
setCnt := len(col.Elems)
maxLimit := int64(1<<uint(setCnt) - 1)
val := v.GetInt64()
Expand All @@ -692,21 +711,40 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOpt
return "", errors.Trace(err)
}
v.SetMysqlSet(setVal)
return v.ToString()
}

if tp == mysql.TypeDuration {
var err error
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
return "", errors.Trace(err)
}
str, err := v.ToString()
if err != nil {
return "", errors.Trace(err)
}
if str == "" {
return str, nil
}

if tp == mysql.TypeBit {
if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 {
// For BIT fields, convert int into BinaryLiteral.
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil
valMap := make(map[string]struct{}, len(col.Elems))
dVals := strings.SplitN(strings.ToLower(str), ",", -1)
for _, dv := range dVals {
valMap[dv] = struct{}{}
}
var existCnt int
for dv := range valMap {
for i := range col.Elems {
e := strings.ToLower(col.Elems[i])
if e == dv {
existCnt++
break
}
}
}
if existCnt != len(valMap) {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
setVal, err := types.ParseSetName(col.Elems, str)
if err != nil {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
v.SetMysqlSet(setVal)

return v.ToString()
}
Expand Down

0 comments on commit 89ac601

Please sign in to comment.