Skip to content
This repository has been archived by the owner on Dec 8, 2021. It is now read-only.

encoder: check string value for tidb encoder #378

Merged
merged 5 commits into from
Aug 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions lightning/backend/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func (row tidbRows) MarshalLogArray(encoder zapcore.ArrayEncoder) error {

// tidbEncoder encodes rows of datums into parenthesized SQL value tuples for
// the TiDB backend.
type tidbEncoder struct {
// mode is the SQL mode of the import session; it decides how string bytes are
// escaped (NO_BACKSLASH_ESCAPES) and whether strict-mode string validation runs.
mode mysql.SQLMode
// tbl is the target table; Encode reads its column list to find the column
// each field is written to.
tbl table.Table
// se is the session handed to table.CastValue when validating string values.
// NewEncoder only creates it when strict mode is enabled; otherwise it is nil.
se *session
}

type tidbBackend struct {
Expand Down Expand Up @@ -104,7 +106,7 @@ func (rows tidbRows) Clear() Rows {
return rows[:0]
}

func (enc tidbEncoder) appendSQLBytes(sb *strings.Builder, value []byte) {
func (enc *tidbEncoder) appendSQLBytes(sb *strings.Builder, value []byte) {
sb.Grow(2 + len(value))
sb.WriteByte('\'')
if enc.mode.HasNoBackslashEscapesMode() {
Expand Down Expand Up @@ -144,7 +146,7 @@ func (enc tidbEncoder) appendSQLBytes(sb *strings.Builder, value []byte) {

// appendSQL appends the SQL representation of the Datum into the string builder.
// Note that we cannot use Datum.ToString since it doesn't perform SQL escaping.
func (enc tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum) error {
func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, col *table.Column) error {
switch datum.Kind() {
case types.KindNull:
sb.WriteString("NULL")
Expand Down Expand Up @@ -172,8 +174,17 @@ func (enc tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum) error
var buffer [32]byte
value := strconv.AppendFloat(buffer[:0], datum.GetFloat64(), 'g', -1, 64)
sb.Write(value)
case types.KindString:
if enc.mode.HasStrictMode() {
d, err := table.CastValue(enc.se, *datum, col.ToInfo(), false, false)
if err != nil {
return errors.Trace(err)
}
datum = &d
}

case types.KindString, types.KindBytes:
enc.appendSQLBytes(sb, datum.GetBytes())
case types.KindBytes:
enc.appendSQLBytes(sb, datum.GetBytes())

case types.KindMysqlJSON:
Expand Down Expand Up @@ -213,17 +224,19 @@ func (enc tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum) error
return nil
}

func (tidbEncoder) Close() {}
// Close is a no-op for tidbEncoder.
func (*tidbEncoder) Close() {}

func (enc *tidbEncoder) Encode(logger log.Logger, row []types.Datum, _ int64, columnPermutation []int) (Row, error) {
cols := enc.tbl.Cols()

func (enc tidbEncoder) Encode(logger log.Logger, row []types.Datum, _ int64, _ []int) (Row, error) {
var encoded strings.Builder
encoded.Grow(8 * len(row))
encoded.WriteByte('(')
for i, field := range row {
if i != 0 {
encoded.WriteByte(',')
}
if err := enc.appendSQL(&encoded, &field); err != nil {
if err := enc.appendSQL(&encoded, &field, cols[columnPermutation[i]]); err != nil {
logger.Error("tidb encode failed",
zap.Array("original", rowArrayMarshaler(row)),
zap.Int("originalCol", i),
Expand Down Expand Up @@ -265,8 +278,15 @@ func (be *tidbBackend) CheckRequirements() error {
return nil
}

func (be *tidbBackend) NewEncoder(_ table.Table, options *SessionOptions) Encoder {
return tidbEncoder{mode: options.SQLMode}
// NewEncoder returns an Encoder that renders rows of tbl as SQL value tuples
// using the SQL mode found in options.
//
// When the SQL mode is strict, the encoder also gets a session with the UTF-8
// and ASCII checks switched on so that string values can be validated against
// their column; in any other mode the session is left nil.
func (be *tidbBackend) NewEncoder(tbl table.Table, options *SessionOptions) Encoder {
	enc := &tidbEncoder{mode: options.SQLMode, tbl: tbl}
	if !enc.mode.HasStrictMode() {
		return enc
	}

	// Strict mode: make sure the charset checks are not skipped.
	enc.se = newSession(options)
	enc.se.vars.SkipUTF8Check = false
	enc.se.vars.SkipASCIICheck = false
	return enc
}

func (be *tidbBackend) OpenEngine(context.Context, uuid.UUID) error {
Expand Down
69 changes: 63 additions & 6 deletions lightning/backend/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@ package backend_test
import (
"context"
"database/sql"
"fmt"

"github.com/pingcap/parser/charset"

"github.com/DATA-DOG/go-sqlmock"
. "github.com/pingcap/check"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/types"

kv "github.com/pingcap/tidb-lightning/lightning/backend"
Expand All @@ -33,15 +40,28 @@ type mysqlSuite struct {
dbHandle *sql.DB
mockDB sqlmock.Sqlmock
backend kv.Backend
tbl table.Table
}

// SetUpTest prepares the fixtures shared by the tests in this suite: a
// sqlmock-backed database handle, a TiDB backend using ReplaceOnDup, and a
// table whose columns c0..c12 have one column per type listed below.
func (s *mysqlSuite) SetUpTest(c *C) {
	db, mock, err := sqlmock.New()
	c.Assert(err, IsNil)

	// One column per entry, in this order.
	columnTypes := []byte{
		mysql.TypeLong, mysql.TypeLong, mysql.TypeTiny, mysql.TypeInt24,
		mysql.TypeFloat, mysql.TypeDouble, mysql.TypeDouble, mysql.TypeDouble,
		mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeBit, mysql.TypeNewDecimal,
		mysql.TypeEnum,
	}
	columns := make([]*model.ColumnInfo, len(columnTypes))
	for offset, tp := range columnTypes {
		columns[offset] = &model.ColumnInfo{
			ID:        int64(offset + 1),
			Name:      model.NewCIStr(fmt.Sprintf("c%d", offset)),
			State:     model.StatePublic,
			Offset:    offset,
			FieldType: *types.NewFieldType(tp),
		}
	}

	meta := &model.TableInfo{
		ID:         1,
		Columns:    columns,
		PKIsHandle: false,
		State:      model.StatePublic,
	}
	tbl, err := tables.TableFromMeta(kv.NewPanickingAllocators(0), meta)
	c.Assert(err, IsNil)

	// Only publish the fixtures once everything above has succeeded.
	s.dbHandle = db
	s.mockDB = mock
	s.backend = kv.NewTiDBBackend(db, config.ReplaceOnDup)
	s.tbl = tbl
}

func (s *mysqlSuite) TearDownTest(c *C) {
Expand All @@ -65,7 +85,12 @@ func (s *mysqlSuite) TestWriteRowsReplaceOnDup(c *C) {
indexRows := s.backend.MakeEmptyRows()
indexChecksum := verification.MakeKVChecksum(0, 0, 0)

encoder := s.backend.NewEncoder(nil, &kv.SessionOptions{SQLMode: 0, Timestamp: 1234567890, RowFormatVersion: "1"})
cols := s.tbl.Cols()
perms := make([]int, 0, len(s.tbl.Cols()))
for i := 0; i < len(cols); i++ {
perms = append(perms, i)
}
encoder := s.backend.NewEncoder(s.tbl, &kv.SessionOptions{SQLMode: 0, Timestamp: 1234567890, RowFormatVersion: "1"})
row, err := encoder.Encode(logger, []types.Datum{
types.NewUintDatum(18446744073709551615),
types.NewIntDatum(-9223372036854775808),
Expand All @@ -80,7 +105,7 @@ func (s *mysqlSuite) TestWriteRowsReplaceOnDup(c *C) {
types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(0x98765432, 4)),
types.NewDecimalDatum(types.NewDecFromFloatForTest(12.5)),
types.NewMysqlEnumDatum(types.Enum{Name: "ENUM_NAME", Value: 51}),
}, 1, nil)
}, 1, perms)
c.Assert(err, IsNil)
row.ClassifyAndAppend(&dataRows, &dataChecksum, &indexRows, &indexChecksum)

Expand All @@ -105,10 +130,10 @@ func (s *mysqlSuite) TestWriteRowsIgnoreOnDup(c *C) {
indexRows := ignoreBackend.MakeEmptyRows()
indexChecksum := verification.MakeKVChecksum(0, 0, 0)

encoder := ignoreBackend.NewEncoder(nil, &kv.SessionOptions{})
encoder := ignoreBackend.NewEncoder(s.tbl, &kv.SessionOptions{})
row, err := encoder.Encode(logger, []types.Datum{
types.NewIntDatum(1),
}, 1, nil)
}, 1, []int{0})
c.Assert(err, IsNil)
row.ClassifyAndAppend(&dataRows, &dataChecksum, &indexRows, &indexChecksum)

Expand All @@ -133,13 +158,45 @@ func (s *mysqlSuite) TestWriteRowsErrorOnDup(c *C) {
indexRows := ignoreBackend.MakeEmptyRows()
indexChecksum := verification.MakeKVChecksum(0, 0, 0)

encoder := ignoreBackend.NewEncoder(nil, &kv.SessionOptions{})
encoder := ignoreBackend.NewEncoder(s.tbl, &kv.SessionOptions{})
row, err := encoder.Encode(logger, []types.Datum{
types.NewIntDatum(1),
}, 1, nil)
}, 1, []int{0})
c.Assert(err, IsNil)

row.ClassifyAndAppend(&dataRows, &dataChecksum, &indexRows, &indexChecksum)

err = engine.WriteRows(ctx, []string{"a"}, dataRows)
c.Assert(err, IsNil)
}

// TestStrictMode checks string validation in the TiDB encoder under
// STRICT_ALL_TABLES: a valid string encodes without error, while a string that
// is invalid for the target column's charset (utf8mb4 for s0, ascii for s1)
// makes Encode return an error naming that column.
func (s *mysqlSuite) TestStrictMode(c *C) {
	utf8Type := *types.NewFieldType(mysql.TypeVarchar)
	utf8Type.Charset = charset.CharsetUTF8MB4
	asciiType := *types.NewFieldType(mysql.TypeString)
	asciiType.Charset = charset.CharsetASCII

	columns := []*model.ColumnInfo{
		{ID: 1, Name: model.NewCIStr("s0"), State: model.StatePublic, Offset: 0, FieldType: utf8Type},
		{ID: 2, Name: model.NewCIStr("s1"), State: model.StatePublic, Offset: 1, FieldType: asciiType},
	}
	meta := &model.TableInfo{ID: 1, Columns: columns, PKIsHandle: false, State: model.StatePublic}
	tbl, err := tables.TableFromMeta(kv.NewPanickingAllocators(0), meta)
	c.Assert(err, IsNil)

	strictBackend := kv.NewTiDBBackend(s.dbHandle, config.ErrorOnDup)
	encoder := strictBackend.NewEncoder(tbl, &kv.SessionOptions{SQLMode: mysql.ModeStrictAllTables})
	logger := log.L()

	scenarios := []struct {
		value       string
		permutation []int
		// errPattern is the expected error regexp; empty means Encode must succeed.
		errPattern string
	}{
		{value: "test", permutation: []int{0}},
		{value: "\xff\xff\xff\xff", permutation: []int{0}, errPattern: `.*incorrect utf8 value .* for column s0`},
		{value: "非 ASCII 字符串", permutation: []int{1}, errPattern: ".*incorrect ascii value .* for column s1"},
	}
	for _, sc := range scenarios {
		_, err := encoder.Encode(logger, []types.Datum{types.NewStringDatum(sc.value)}, 1, sc.permutation)
		if sc.errPattern == "" {
			c.Assert(err, IsNil)
			continue
		}
		c.Assert(err, ErrorMatches, sc.errPattern)
	}
}