From 866c5514afeb456c687ae9dc9444df4eb08cb841 Mon Sep 17 00:00:00 2001
From: Morgan Tocker
Date: Sun, 19 Dec 2021 08:21:45 -0600
Subject: [PATCH 1/9] session: fix bootstrap to only persist global variables (#30593)

close pingcap/tidb#28667
---
 session/bootstrap.go             | 4 ++--
 session/bootstrap_serial_test.go | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/session/bootstrap.go b/session/bootstrap.go
index 2244e3eaa5484..e5c588ec5af01 100644
--- a/session/bootstrap.go
+++ b/session/bootstrap.go
@@ -1764,8 +1764,8 @@ func doDMLWorks(s Session) {
 	// Init global system variables table.
 	values := make([]string, 0, len(variable.GetSysVars()))
 	for k, v := range variable.GetSysVars() {
-		// Session only variable should not be inserted.
-		if v.Scope != variable.ScopeSession {
+		// Only global variables should be inserted.
+		if v.HasGlobalScope() {
 			vVal := v.Value
 			if v.Name == variable.TiDBTxnMode && config.GetGlobalConfig().Store == "tikv" {
 				vVal = "pessimistic"
diff --git a/session/bootstrap_serial_test.go b/session/bootstrap_serial_test.go
index 1de04a5ca91f3..6caf7e702c10b 100644
--- a/session/bootstrap_serial_test.go
+++ b/session/bootstrap_serial_test.go
@@ -113,7 +113,7 @@ func TestBootstrap(t *testing.T) {
 func globalVarsCount() int64 {
 	var count int64
 	for _, v := range variable.GetSysVars() {
-		if v.Scope != variable.ScopeSession {
+		if v.HasGlobalScope() {
 			count++
 		}
 	}

From 1721706b2386b1600743da4c845c95319990b1c3 Mon Sep 17 00:00:00 2001
From: Lynn
Date: Sun, 19 Dec 2021 22:35:45 +0800
Subject: [PATCH 2/9] docs/design: update collation compatibility issues in charsets doc (#30806)

---
 docs/design/2021-08-18-charsets.md | 41 +++++++------------------------
 1 file changed, 9 insertions(+), 32 deletions(-)

diff --git a/docs/design/2021-08-18-charsets.md b/docs/design/2021-08-18-charsets.md
index 16cad2fd044ed..441f5b0917d6b 100644
--- a/docs/design/2021-08-18-charsets.md
+++ b/docs/design/2021-08-18-charsets.md
@@ -98,8 +98,10 @@ After receiving the non-utf-8 character set request, this solution will convert
 ### Collation
 
 Add gbk_chinese_ci and gbk_bin collations. In addition, considering the performance, we can add the collation of utf8mb4 (gbk_utf8mb4_bin).
+- To support gbk_chinese_ci and gbk_bin collations, it needs to turn on the `new_collations_enabled_on_first_bootstrap` switch.
+  - If `new_collations_enabled_on_first_bootstrap` is off, it only supports gbk_utf8mb4_bin which does not need to be converted to gbk charset before processing.
 - Implement the Collator and WildcardPattern interface functions for each collation.
-  - gbk_chinese_ci and gbk_bin need to convert utf-8 to gbk encoding and then generate a sort key. gbk_utf8mb4_bin does not need to be converted to gbk code for processing.
+  - gbk_chinese_ci and gbk_bin need to convert utf-8 to gbk encoding and then generate a sort key.
 - Implement the corresponding functions in the Coprocessor.
 
 ### DDL
@@ -119,43 +121,18 @@ Other behaviors that need to be dealt with:
 
 #### Compatibility between TiDB versions
 
 - Upgrade compatibility:
-  - Upgrades from versions below 4.0 do not support gbk or any character sets other than the original five (binary, ascii, latin1, utf8, utf8mb4).
-  - Upgrade from version 4.0 or higher
-    - There may be compatibility issues when performing non-utf-8-related operations during the rolling upgrade.
-    - The new version of the cluster is expected to have no compatibility issues when reading old data.
+  - There may be compatibility issues when performing operations during the rolling upgrade.
+  - The new version of the cluster is expected to have no compatibility issues when reading old data.
 - Downgrade compatibility:
   - Downgrade is not compatible. The index key uses the table of gbk_bin/gbk_chinese_ci. The lower version of TiDB will have problems when decoding, and it needs to be transcoded before downgrading.
 
 #### Compatibility with MySQL
 
-Illegal character related issue:
+- Illegal character related issue:
+  - Due to the internal conversion of non-utf-8-related encoding to utf8 for processing, it is not fully compatible with MySQL in some cases in terms of illegal character processing. TiDB controls its behavior through sql_mode.
 
-```sql
-create table t3(a char(10) charset gbk);
-insert into t3 values ('a');
-
-// 0xcee5 is a valid gbk hex literal but invalid utf8mb4 hex literal.
-select hex(concat(a, 0xcee5)) from t3;
--- mysql 61cee5
-
-// 0xe4b880 is an invalid gbk hex literal but valid utf8mb4 hex literal.
-select hex(concat(a, 0xe4b880)) from t3;
--- mysql 61e4b880 (test on mysql 5.7 and 8.0.22)
--- mysql returns "Cannot convert string '\x80' from binary to gbk" (test on mysql 8.0.25 and 8.0.26). TiDB will be compatible with this behavior.
-
-// 0x80 is a hex literal that invalid for neither gbk nor utf8mb4.
-select hex(concat(a, 0x80)) from t3;
--- mysql 6180 (test on mysql 5.7 and 8.0.22)
--- mysql returns "Cannot convert string '\x80' from binary to gbk" (test on mysql 8.0.25 and 8.0.26). TiDB will be compatible with this behavior.
-
-set @@sql_mode = '';
-insert into t3 values (0x80);
--- mysql gets a warning and insert null values (warning: "Incorrect string value: '\x80' for column 'a' at row 1")
-
-set @@sql_mode = 'STRICT_TRANS_TABLES';
-insert into t3 values (0x80);
--- mysql returns "Incorrect string value: '\x80' for column 'a' at row 1"
-```
+- Collation
+  - Fully support `gbk_bin` and `gbk_chinese_ci` only when the config `new_collations_enabled_on_first_bootstrap` is enabled. Otherwise, it only supports gbk_utf8mb4_bin.
 
 #### Compatibility with other components

From 24d970fc46070d278bc9495199880839d35e07b7 Mon Sep 17 00:00:00 2001
From: Morgan Tocker
Date: Sun, 19 Dec 2021 17:57:45 -0600
Subject: [PATCH 3/9] executor: improve SET sysvar=DEFAULT handling (#29680)

close pingcap/tidb#29670
---
 executor/set_test.go            | 28 +++++++++++++++++++++++++++-
 expression/helper_test.go       |  2 +-
 sessionctx/variable/sysvar.go   |  6 ------
 sessionctx/variable/variable.go |  5 -----
 4 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/executor/set_test.go b/executor/set_test.go
index 6b166059e6921..da121e77b2422 100644
--- a/executor/set_test.go
+++ b/executor/set_test.go
@@ -111,7 +111,7 @@ func (s *testSerialSuite1) TestSetVar(c *C) {
 	tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("0"))
 	tk.MustExec(`set @@global.low_priority_updates="ON";`)
 	tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("1"))
-	tk.MustExec(`set @@global.low_priority_updates=DEFAULT;`) // It will be set to compiled-in default value.
+	tk.MustExec(`set @@global.low_priority_updates=DEFAULT;`) // It will be set to default var value.
 	tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("0"))
 	// For session
 	tk.MustQuery(`select @@session.low_priority_updates;`).Check(testkit.Rows("0"))
@@ -1387,6 +1387,32 @@ func (s *testSuite5) TestEnableNoopFunctionsVar(c *C) {
 
 }
 
+// https://github.com/pingcap/tidb/issues/29670
+func (s *testSuite5) TestDefaultBehavior(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+
+	tk.MustQuery("SELECT @@default_storage_engine").Check(testkit.Rows("InnoDB"))
+	tk.MustExec("SET GLOBAL default_storage_engine = 'somethingweird'")
+	tk.MustExec("SET default_storage_engine = 'MyISAM'")
+	tk.MustQuery("SELECT @@default_storage_engine").Check(testkit.Rows("MyISAM"))
+	tk.MustExec("SET default_storage_engine = DEFAULT") // reads from global value
+	tk.MustQuery("SELECT @@default_storage_engine").Check(testkit.Rows("somethingweird"))
+	tk.MustExec("SET @@SESSION.default_storage_engine = @@GLOBAL.default_storage_engine") // example from MySQL manual
+	tk.MustQuery("SELECT @@default_storage_engine").Check(testkit.Rows("somethingweird"))
+	tk.MustExec("SET GLOBAL default_storage_engine = 'somethingweird2'")
+	tk.MustExec("SET default_storage_engine = @@GLOBAL.default_storage_engine") // variation of example
+	tk.MustQuery("SELECT @@default_storage_engine").Check(testkit.Rows("somethingweird2"))
+	tk.MustExec("SET default_storage_engine = DEFAULT")        // restore default again for session global
+	tk.MustExec("SET GLOBAL default_storage_engine = DEFAULT") // restore default for global
+	tk.MustQuery("SELECT @@SESSION.default_storage_engine, @@GLOBAL.default_storage_engine").Check(testkit.Rows("somethingweird2 InnoDB"))
+
+	// Try sql_mode option which has validation
+	err := tk.ExecToErr("SET GLOBAL sql_mode = 'DEFAULT'") // illegal now
+	c.Assert(err, NotNil)
+	c.Assert(err.Error(), Equals, `ERROR 1231 (42000): Variable 'sql_mode' can't be set to the value of 'DEFAULT'`)
+	tk.MustExec("SET GLOBAL sql_mode = DEFAULT")
+}
+
 func (s *testSuite5) TestRemovedSysVars(c *C) {
 	tk := testkit.NewTestKit(c, s.store)
 
diff --git a/expression/helper_test.go b/expression/helper_test.go
index b7e00c221e141..63a9ca4137ed4 100644
--- a/expression/helper_test.go
+++ b/expression/helper_test.go
@@ -42,7 +42,7 @@ func TestGetTimeValue(t *testing.T) {
 	require.Equal(t, "2012-12-12 00:00:00", timeValue.String())
 
 	sessionVars := ctx.GetSessionVars()
-	err = variable.SetSessionSystemVar(sessionVars, "timestamp", "default")
+	err = variable.SetSessionSystemVar(sessionVars, "timestamp", "0")
 	require.NoError(t, err)
 	v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, types.MinFsp)
 	require.NoError(t, err)
diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go
index fc7ce09cae6a7..3491f28bc73dc 100644
--- a/sessionctx/variable/sysvar.go
+++ b/sessionctx/variable/sysvar.go
@@ -120,12 +120,6 @@ var defaultSysVars = []*SysVar{
 		}
 		timestamp := s.StmtCtx.GetOrStoreStmtCache(stmtctx.StmtNowTsCacheKey, time.Now()).(time.Time)
 		return types.ToString(float64(timestamp.UnixNano()) / float64(time.Second))
-	}, GetGlobal: func(s *SessionVars) (string, error) {
-		// The Timestamp sysvar will have GetGlobal func even though it does not have global scope.
-		// It's GetGlobal func will only be called when "set timestamp = default".
-		// Setting timestamp to DEFAULT causes its value to be the current date and time as of the time it is accessed.
-		// See https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_timestamp
-		return DefTimestamp, nil
 	}},
 	{Scope: ScopeGlobal | ScopeSession, Name: CollationDatabase, Value: mysql.DefaultCollationName, skipInit: true, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) {
 		return checkCollation(vars, normalizedValue, originalValue, scope)
diff --git a/sessionctx/variable/variable.go b/sessionctx/variable/variable.go
index 675ca3bdc0887..aeb65a8257d9f 100644
--- a/sessionctx/variable/variable.go
+++ b/sessionctx/variable/variable.go
@@ -261,11 +261,6 @@ func (sv *SysVar) Validate(vars *SessionVars, value string, scope ScopeFlag) (st
 
 // ValidateFromType provides automatic validation based on the SysVar's type
 func (sv *SysVar) ValidateFromType(vars *SessionVars, value string, scope ScopeFlag) (string, error) {
-	// The string "DEFAULT" is a special keyword in MySQL, which restores
-	// the compiled sysvar value. In which case we can skip further validation.
-	if strings.EqualFold(value, "DEFAULT") {
-		return sv.Value, nil
-	}
 	// Some sysvars in TiDB have a special behavior where the empty string means
 	// "use the config file value". This needs to be cleaned up once the behavior
 	// for instance variables is determined.

From e1fb2f541454a31f6693d9c34692eec6b4b30c93 Mon Sep 17 00:00:00 2001
From: fengou1 <85682690+fengou1@users.noreply.github.com>
Date: Mon, 20 Dec 2021 08:09:45 +0800
Subject: [PATCH 4/9] br: add error handling for group context cancel when restore file is corrupted (#30190)

close pingcap/tidb#30135
---
 br/pkg/restore/pipeline_items.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/br/pkg/restore/pipeline_items.go b/br/pkg/restore/pipeline_items.go
index 1bd7502f30642..ce476b1963fa5 100644
--- a/br/pkg/restore/pipeline_items.go
+++ b/br/pkg/restore/pipeline_items.go
@@ -360,6 +360,7 @@ func (b *tikvSender) restoreWorker(ctx context.Context, ranges <-chan drainResul
 		eg.Go(func() error {
 			e := b.client.RestoreFiles(ectx, files, r.result.RewriteRules, b.updateCh)
 			if e != nil {
+				r.done()
 				return e
 			}
 			log.Info("restore batch done", rtree.ZapRanges(r.result.Ranges))

From e3c56b75eaebcb6fdfbbd9316de60655a73c21dd Mon Sep 17 00:00:00 2001
From: Zhuhe Fang
Date: Mon, 20 Dec 2021 11:11:46 +0800
Subject: [PATCH 5/9] executor: buildWindow cannot call typeInfer twice (#30773)

close pingcap/tidb#30402
---
 executor/aggregate_test.go           | 1 +
 executor/builder.go                  | 2 +-
 executor/tiflash_test.go             | 1 +
 expression/aggregation/descriptor.go | 9 +++++++++
 4 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go
index a99d7e71a69f4..976775ce5dc4c 100644
--- a/executor/aggregate_test.go
+++ b/executor/aggregate_test.go
@@ -1483,6 +1483,7 @@ func TestAvgDecimal(t *testing.T) {
 	tk.MustExec("insert into td values (0,29815);")
 	tk.MustExec("insert into td values (10017,-32661);")
 	tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
+	tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000"))
 	tk.MustExec("drop table td;")
 }
 
diff --git a/executor/builder.go b/executor/builder.go
index 73d20b93eb9bb..78c94fd02f1aa 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -4192,7 +4192,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) Executor {
 	partialResults :=
make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) for _, desc := range v.WindowFuncDescs { - aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(b.ctx, desc, false) if err != nil { b.err = err return nil diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 10d51bb1a27d7..14e7c13e9c17e 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -829,6 +829,7 @@ func (s *tiflashTestSuite) TestAvgOverflow(c *C) { tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") tk.MustExec("set @@session.tidb_enforce_mpp=ON") tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260")) + tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000")) tk.MustExec("drop table if exists td;") } diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 30f020e7dfdf2..8882b93ec05d7 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -43,6 +43,7 @@ type AggFuncDesc struct { } // NewAggFuncDesc creates an aggregation function signature descriptor. +// this func cannot be called twice as the TypeInfer has changed the type of args in the first time. func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { b, err := newBaseFuncDesc(ctx, name, args) if err != nil { @@ -51,6 +52,14 @@ func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expre return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct}, nil } +// NewAggFuncDescForWindowFunc creates an aggregation function from window functions, where baseFuncDesc may be ready. +func NewAggFuncDescForWindowFunc(ctx sessionctx.Context, Desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) { + if Desc.RetTp == nil { // safety check + return NewAggFuncDesc(ctx, Desc.Name, Desc.Args, hasDistinct) + } + return &AggFuncDesc{baseFuncDesc: baseFuncDesc{Desc.Name, Desc.Args, Desc.RetTp}, HasDistinct: hasDistinct}, nil +} + // String implements the fmt.Stringer interface. 
func (a *AggFuncDesc) String() string { buffer := bytes.NewBufferString(a.Name) From ab35db14a6fcfe0de43b7bca273d3839ed688176 Mon Sep 17 00:00:00 2001 From: tangenta Date: Mon, 20 Dec 2021 15:27:46 +0800 Subject: [PATCH 6/9] *: refactor encoding and uniform usages (#30288) --- .../r/new_character_set_builtin.result | 9 +- .../t/new_character_set_builtin.test | 1 + expression/builtin_convert_charset.go | 152 ++++++--- expression/builtin_encryption_test.go | 46 +-- expression/builtin_string.go | 115 +++---- expression/builtin_string_test.go | 67 ++-- expression/builtin_string_vec.go | 128 ++++---- expression/collation.go | 25 +- expression/integration_test.go | 20 +- parser/charset/encoding.go | 303 +++++++----------- parser/charset/encoding_ascii.go | 71 ++++ parser/charset/encoding_base.go | 117 +++++++ parser/charset/encoding_bin.go | 61 ++++ parser/charset/encoding_gbk.go | 93 ++++++ parser/charset/encoding_latin1.go | 51 +++ parser/charset/encoding_table.go | 197 +----------- parser/charset/encoding_test.go | 155 +++------ parser/charset/encoding_utf8.go | 114 +++++++ parser/charset/gbk.go | 29 -- parser/charset/latin.go | 48 --- parser/charset/special_case_tables.go | 104 ------ parser/charset/utf.go | 34 -- parser/lexer.go | 12 +- parser/parser.go | 6 +- parser/parser.y | 6 +- parser/yy_parser.go | 4 +- planner/core/integration_test.go | 2 +- server/util.go | 34 +- table/column.go | 75 +++-- 29 files changed, 1051 insertions(+), 1028 deletions(-) create mode 100644 parser/charset/encoding_ascii.go create mode 100644 parser/charset/encoding_base.go create mode 100644 parser/charset/encoding_bin.go create mode 100644 parser/charset/encoding_gbk.go create mode 100644 parser/charset/encoding_latin1.go create mode 100644 parser/charset/encoding_utf8.go delete mode 100644 parser/charset/gbk.go delete mode 100644 parser/charset/latin.go delete mode 100644 parser/charset/special_case_tables.go delete mode 100644 parser/charset/utf.go diff --git a/cmd/explaintest/r/new_character_set_builtin.result b/cmd/explaintest/r/new_character_set_builtin.result index fa733d990ef1e..f587a5ac1370e 100644 --- a/cmd/explaintest/r/new_character_set_builtin.result +++ b/cmd/explaintest/r/new_character_set_builtin.result @@ -1,3 +1,4 @@ +set @@sql_mode = ''; drop table if exists t; create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); insert into t values ('一二三', '一二三', '一二三'); @@ -244,8 +245,8 @@ insert into t values ('65'), ('123456'), ('123456789'); select char(a using gbk), char(a using utf8), char(a) from t; char(a using gbk) char(a using utf8) char(a) A A A -釦 @ @ -NULL [ [ +釦  @ +[ [ [ select char(12345678 using gbk); char(12345678 using gbk) 糰N @@ -253,8 +254,8 @@ set @@tidb_enable_vectorized_expression = true; select char(a using gbk), char(a using utf8), char(a) from t; char(a using gbk) char(a using utf8) char(a) A A A -釦 @ @ -NULL [ [ +釦  @ +[ [ [ select char(12345678 using gbk); char(12345678 using gbk) 糰N diff --git a/cmd/explaintest/t/new_character_set_builtin.test b/cmd/explaintest/t/new_character_set_builtin.test index 09b823cdcfaa9..bb0a6321e8a53 100644 --- a/cmd/explaintest/t/new_character_set_builtin.test +++ b/cmd/explaintest/t/new_character_set_builtin.test @@ -1,3 +1,4 @@ +set @@sql_mode = ''; -- test for builtin function hex(), length(), ascii(), octet_length() drop table if exists t; create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); diff --git a/expression/builtin_convert_charset.go b/expression/builtin_convert_charset.go 
index fc709cc7c61f0..a24ed138eb578 100644 --- a/expression/builtin_convert_charset.go +++ b/expression/builtin_convert_charset.go @@ -16,7 +16,6 @@ package expression import ( "fmt" - "unicode/utf8" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/ast" @@ -27,6 +26,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/dbterror" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tipb/go-tipb" ) @@ -92,9 +92,9 @@ func (b *builtinInternalToBinarySig) evalString(row chunk.Row) (res string, isNu return res, isNull, err } tp := b.args[0].GetType() - enc := charset.NewEncoding(tp.Charset) - res, err = enc.EncodeString(val) - return res, false, err + enc := charset.FindEncoding(tp.Charset) + ret, err := enc.Transform(nil, hack.Slice(val), charset.OpEncode) + return string(ret), false, err } func (b *builtinInternalToBinarySig) vectorized() bool { @@ -111,7 +111,7 @@ func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *c if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } - enc := charset.NewEncoding(b.args[0].GetType().Charset) + enc := charset.FindEncoding(b.args[0].GetType().Charset) result.ReserveString(n) var encodedBuf []byte for i := 0; i < n; i++ { @@ -119,11 +119,11 @@ func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - strBytes, err := enc.Encode(encodedBuf, buf.GetBytes(i)) + encodedBuf, err = enc.Transform(encodedBuf, buf.GetBytes(i), charset.OpEncode) if err != nil { return err } - result.AppendBytes(strBytes) + result.AppendBytes(encodedBuf) } return nil } @@ -170,9 +170,13 @@ func (b *builtinInternalFromBinarySig) evalString(row chunk.Row) (res string, is if isNull || err != nil { return val, isNull, err } - transferString := b.getTransferFunc() - tBytes, err := transferString([]byte(val)) - return string(tBytes), false, err + enc := charset.FindEncoding(b.tp.Charset) + ret, err := enc.Transform(nil, hack.Slice(val), charset.OpDecode) + if err != nil { + strHex := fmt.Sprintf("%X", val) + err = errCannotConvertString.GenWithStackByArgs(strHex, charset.CharsetBin, b.tp.Charset) + } + return string(ret), false, err } func (b *builtinInternalFromBinarySig) vectorized() bool { @@ -189,45 +193,25 @@ func (b *builtinInternalFromBinarySig) vecEvalString(input *chunk.Chunk, result if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } - transferString := b.getTransferFunc() + enc := charset.FindEncoding(b.tp.Charset) + var encBuf []byte result.ReserveString(n) for i := 0; i < n; i++ { if buf.IsNull(i) { result.AppendNull() continue } - str, err := transferString(buf.GetBytes(i)) + str := buf.GetBytes(i) + encBuf, err = enc.Transform(encBuf, str, charset.OpDecode) if err != nil { - return err + strHex := fmt.Sprintf("%X", str) + return errCannotConvertString.GenWithStackByArgs(strHex, charset.CharsetBin, b.tp.Charset) } - result.AppendBytes(str) + result.AppendBytes(encBuf) } return nil } -func (b *builtinInternalFromBinarySig) getTransferFunc() func([]byte) ([]byte, error) { - var transferString func([]byte) ([]byte, error) - if b.tp.Charset == charset.CharsetUTF8MB4 || b.tp.Charset == charset.CharsetUTF8 { - transferString = func(s []byte) ([]byte, error) { - if !utf8.Valid(s) { - return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset) - } - return s, nil - } - } else { - enc := charset.NewEncoding(b.tp.Charset) - var 
buf []byte - transferString = func(s []byte) ([]byte, error) { - str, err := enc.Decode(buf, s) - if err != nil { - return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset) - } - return str, nil - } - } - return transferString -} - // BuildToBinaryFunction builds to_binary function. func BuildToBinaryFunction(ctx sessionctx.Context, expr Expression) (res Expression) { fc := &tidbToBinaryFunctionClass{baseFunctionClass{InternalFuncToBinary, 1, 1}} @@ -258,26 +242,94 @@ func BuildFromBinaryFunction(ctx sessionctx.Context, expr Expression, tp *types. return FoldConstant(res) } +type funcProp int8 + +const ( + funcPropNone funcProp = iota + // The arguments of these functions are wrapped with to_binary(). + // For compatibility reason, legacy charsets arguments are not wrapped. + // Legacy charsets: utf8mb4, utf8, latin1, ascii, binary. + funcPropBinAware + // The arguments of these functions are wrapped with to_binary() or from_binary() according to + // the evaluated result charset and the argument charset. + // For binary argument && string result, wrap it with from_binary(). + // For string argument && binary result, wrap it with to_binary(). + funcPropAuto +) + +// convertActionMap collects from https://dev.mysql.com/doc/refman/8.0/en/string-functions.html. +var convertActionMap = map[funcProp][]string{ + funcPropNone: { + /* args != strings */ + ast.Bin, ast.CharFunc, ast.DateFormat, ast.Oct, ast.Space, + /* only 1 string arg, no implicit conversion */ + ast.CharLength, ast.CharacterLength, ast.FromBase64, ast.Lcase, ast.Left, ast.LoadFile, + ast.Lower, ast.LTrim, ast.Mid, ast.Ord, ast.Quote, ast.Repeat, ast.Reverse, ast.Right, + ast.RTrim, ast.Soundex, ast.Substr, ast.Substring, ast.Ucase, ast.Unhex, ast.Upper, ast.WeightString, + /* args are independent, no implicit conversion */ + ast.Elt, + }, + funcPropBinAware: { + /* result is binary-aware */ + ast.ASCII, ast.BitLength, ast.Hex, ast.Length, ast.OctetLength, ast.ToBase64, + /* encrypt functions */ + ast.AesDecrypt, ast.Decode, ast.Encode, ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1, + ast.SHA2, ast.Compress, ast.AesEncrypt, + }, + funcPropAuto: { + /* string functions */ ast.Concat, ast.ConcatWS, ast.ExportSet, ast.Field, ast.FindInSet, + ast.InsertFunc, ast.Instr, ast.Lpad, ast.Locate, ast.Lpad, ast.MakeSet, ast.Position, + ast.Replace, ast.Rpad, ast.SubstringIndex, ast.Trim, + /* operators */ + ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ, ast.If, ast.Ifnull, ast.In, + ast.Case, + /* string comparing */ + ast.Like, ast.Strcmp, + /* regex */ + ast.Regexp, + }, +} + +var convertFuncsMap = map[string]funcProp{} + +func init() { + for k, fns := range convertActionMap { + for _, f := range fns { + convertFuncsMap[f] = k + } + } +} + // HandleBinaryLiteral wraps `expr` with to_binary or from_binary sig. 
func HandleBinaryLiteral(ctx sessionctx.Context, expr Expression, ec *ExprCollation, funcName string) Expression { - switch funcName { - case ast.Concat, ast.ConcatWS, ast.Lower, ast.Lcase, ast.Reverse, ast.Upper, ast.Ucase, ast.Quote, ast.Coalesce, - ast.Left, ast.Right, ast.Repeat, ast.Trim, ast.LTrim, ast.RTrim, ast.Substr, ast.SubstringIndex, ast.Replace, - ast.Substring, ast.Mid, ast.Translate, ast.InsertFunc, ast.Lpad, ast.Rpad, ast.Elt, ast.ExportSet, ast.MakeSet, - ast.FindInSet, ast.Regexp, ast.Field, ast.Locate, ast.Instr, ast.Position, ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, - ast.NE, ast.NullEQ, ast.Strcmp, ast.If, ast.Ifnull, ast.Like, ast.In, ast.DateFormat, ast.TimeFormat: - if ec.Charset == charset.CharsetBin && expr.GetType().Charset != charset.CharsetBin { + argChs, dstChs := expr.GetType().Charset, ec.Charset + switch convertFuncsMap[funcName] { + case funcPropNone: + return expr + case funcPropBinAware: + if isLegacyCharset(argChs) { + return expr + } + return BuildToBinaryFunction(ctx, expr) + case funcPropAuto: + if argChs != charset.CharsetBin && dstChs == charset.CharsetBin { + if isLegacyCharset(argChs) { + return expr + } return BuildToBinaryFunction(ctx, expr) - } else if ec.Charset != charset.CharsetBin && expr.GetType().Charset == charset.CharsetBin { + } else if argChs == charset.CharsetBin && dstChs != charset.CharsetBin { ft := expr.GetType().Clone() ft.Charset, ft.Collate = ec.Charset, ec.Collation return BuildFromBinaryFunction(ctx, expr, ft) } - case ast.Hex, ast.Length, ast.OctetLength, ast.ASCII, ast.ToBase64, ast.AesEncrypt, ast.AesDecrypt, ast.Decode, ast.Encode, - ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1, ast.SHA2, ast.Compress: - if _, err := charset.GetDefaultCollationLegacy(expr.GetType().Charset); err != nil { - return BuildToBinaryFunction(ctx, expr) - } } return expr } + +func isLegacyCharset(chs string) bool { + switch chs { + case charset.CharsetUTF8, charset.CharsetUTF8MB4, charset.CharsetASCII, charset.CharsetLatin1, charset.CharsetBin: + return true + } + return false +} diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 2dec9405f03db..bda56e7bc4bdb 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -16,6 +16,7 @@ package expression import ( "encoding/hex" + "fmt" "strings" "testing" @@ -91,9 +92,10 @@ func TestSQLEncode(t *testing.T) { d, err := f.Eval(chunk.Row{}) require.NoError(t, err) if test.origin != nil { - result, err := charset.NewEncoding(test.chs).EncodeString(test.origin.(string)) + enc := charset.FindEncoding(test.chs) + result, err := enc.Transform(nil, []byte(test.origin.(string)), charset.OpEncode) require.NoError(t, err) - require.Equal(t, types.NewCollationStringDatum(result, test.chs), d) + require.Equal(t, types.NewCollationStringDatum(string(result), test.chs), d) } else { result := types.NewDatum(test.origin) require.Equal(t, result.GetBytes(), d.GetBytes()) @@ -163,7 +165,8 @@ func TestAESEncrypt(t *testing.T) { testAmbiguousInput(t, ctx, ast.AesEncrypt) // Test GBK String - gbkStr, _ := charset.NewEncoding("gbk").EncodeString("你好") + enc := charset.FindEncoding("gbk") + gbkStr, _ := enc.Transform(nil, []byte("你好"), charset.OpEncode) gbkTests := []struct { mode string chs string @@ -188,19 +191,20 @@ func TestAESEncrypt(t *testing.T) { } for _, tt := range gbkTests { + msg := fmt.Sprintf("%v", tt) err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs) - require.NoError(t, err) + 
require.NoError(t, err, msg) err = variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode) - require.NoError(t, err) + require.NoError(t, err, msg) - args := datumsToConstants([]types.Datum{types.NewDatum(tt.origin)}) + args := primitiveValsToConstants(ctx, []interface{}{tt.origin}) args = append(args, primitiveValsToConstants(ctx, tt.params)...) f, err := fc.getFunction(ctx, args) - require.NoError(t, err) + require.NoError(t, err, msg) crypt, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) - require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt)) + require.NoError(t, err, msg) + require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt), msg) } } @@ -209,21 +213,22 @@ func TestAESDecrypt(t *testing.T) { fc := funcs[ast.AesDecrypt] for _, tt := range aesTests { + msg := fmt.Sprintf("%v", tt) err := variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode) - require.NoError(t, err) + require.NoError(t, err, msg) args := []types.Datum{fromHex(tt.crypt)} for _, param := range tt.params { args = append(args, types.NewDatum(param)) } f, err := fc.getFunction(ctx, datumsToConstants(args)) - require.NoError(t, err) + require.NoError(t, err, msg) str, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) + require.NoError(t, err, msg) if tt.origin == nil { require.True(t, str.IsNull()) continue } - require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str) + require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str, msg) } err := variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, "aes-128-ecb") require.NoError(t, err) @@ -231,7 +236,9 @@ func TestAESDecrypt(t *testing.T) { testAmbiguousInput(t, ctx, ast.AesDecrypt) // Test GBK String - gbkStr, _ := charset.NewEncoding("gbk").EncodeString("你好") + enc := charset.FindEncoding("gbk") + r, _ := enc.Transform(nil, []byte("你好"), charset.OpEncode) + gbkStr := string(r) gbkTests := []struct { mode string chs string @@ -256,18 +263,19 @@ func TestAESDecrypt(t *testing.T) { } for _, tt := range gbkTests { + msg := fmt.Sprintf("%v", tt) err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs) - require.NoError(t, err) + require.NoError(t, err, msg) err = variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode) - require.NoError(t, err) + require.NoError(t, err, msg) // Set charset and collate except first argument args := datumsToConstants([]types.Datum{fromHex(tt.crypt)}) args = append(args, primitiveValsToConstants(ctx, tt.params)...) 
f, err := fc.getFunction(ctx, args) - require.NoError(t, err) + require.NoError(t, err, msg) str, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) - require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str) + require.NoError(t, err, msg) + require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str, msg) } } diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 9ac2eb370d380..925aecb3dc33a 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -41,7 +41,6 @@ import ( "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" - "golang.org/x/text/transform" ) var ( @@ -706,7 +705,7 @@ func (c *lowerFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi sig = &builtinLowerSig{bf} sig.setPbCode(tipb.ScalarFuncSig_Lower) } else { - sig = &builtinLowerUTF8Sig{bf, charset.NewEncoding(argTp.Charset)} + sig = &builtinLowerUTF8Sig{bf} sig.setPbCode(tipb.ScalarFuncSig_LowerUTF8) } return sig, nil @@ -714,15 +713,11 @@ func (c *lowerFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi type builtinLowerUTF8Sig struct { baseBuiltinFunc - encoding *charset.Encoding } func (b *builtinLowerUTF8Sig) Clone() builtinFunc { newSig := &builtinLowerUTF8Sig{} newSig.cloneFrom(&b.baseBuiltinFunc) - if b.encoding != nil { - newSig.encoding = charset.NewEncoding(b.encoding.Name()) - } return newSig } @@ -733,8 +728,8 @@ func (b *builtinLowerUTF8Sig) evalString(row chunk.Row) (d string, isNull bool, if isNull || err != nil { return d, isNull, err } - - return b.encoding.ToLower(d), false, nil + enc := charset.FindEncoding(b.args[0].GetType().Charset) + return enc.ToLower(d), false, nil } type builtinLowerSig struct { @@ -905,7 +900,7 @@ func (c *upperFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi sig = &builtinUpperSig{bf} sig.setPbCode(tipb.ScalarFuncSig_Upper) } else { - sig = &builtinUpperUTF8Sig{bf, charset.NewEncoding(argTp.Charset)} + sig = &builtinUpperUTF8Sig{bf} sig.setPbCode(tipb.ScalarFuncSig_UpperUTF8) } return sig, nil @@ -913,15 +908,11 @@ func (c *upperFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi type builtinUpperUTF8Sig struct { baseBuiltinFunc - encoding *charset.Encoding } func (b *builtinUpperUTF8Sig) Clone() builtinFunc { newSig := &builtinUpperUTF8Sig{} newSig.cloneFrom(&b.baseBuiltinFunc) - if b.encoding != nil { - newSig.encoding = charset.NewEncoding(b.encoding.Name()) - } return newSig } @@ -932,8 +923,8 @@ func (b *builtinUpperUTF8Sig) evalString(row chunk.Row) (d string, isNull bool, if isNull || err != nil { return d, isNull, err } - - return b.encoding.ToUpper(d), false, nil + enc := charset.FindEncoding(b.args[0].GetType().Charset) + return enc.ToUpper(d), false, nil } type builtinUpperSig struct { @@ -1141,27 +1132,27 @@ func (b *builtinConvertSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, err } - - // Since charset is already validated and set from getFunction(), there's no - // need to get charset from args again. - encoding, _ := charset.Lookup(b.tp.Charset) - // However, if `b.tp.Charset` is abnormally set to a wrong charset, we still - // return with error. 
- if encoding == nil { - return "", true, errUnknownCharacterSet.GenWithStackByArgs(b.tp.Charset) + argTp, resultTp := b.args[0].GetType(), b.tp + if !charset.IsSupportedEncoding(resultTp.Charset) { + return "", false, errUnknownCharacterSet.GenWithStackByArgs(resultTp.Charset) } - // if expr is binary string and convert meet error, we should return NULL. - if types.IsBinaryStr(b.args[0].GetType()) { - exprInternal, _, err := transform.String(encoding.NewDecoder(), expr) - return exprInternal, err != nil, nil + if types.IsBinaryStr(argTp) { + // Convert charset binary -> utf8. If it meets error, NULL is returned. + enc := charset.FindEncoding(resultTp.Charset) + ret, err := enc.Transform(nil, hack.Slice(expr), charset.OpDecodeReplace) + return string(ret), err != nil, nil + } else if types.IsBinaryStr(resultTp) { + // Convert charset utf8 -> binary. + enc := charset.FindEncoding(argTp.Charset) + ret, err := enc.Transform(nil, hack.Slice(expr), charset.OpEncode) + return string(ret), false, err } - if types.IsBinaryStr(b.tp) { - enc := charset.NewEncoding(b.args[0].GetType().Charset) - expr, err = enc.EncodeString(expr) - return expr, false, err + enc := charset.FindEncoding(resultTp.Charset) + if !charset.IsValidString(enc, expr) { + replace, _ := enc.Transform(nil, hack.Slice(expr), charset.OpReplace) + return string(replace), false, nil } - enc := charset.NewEncoding(b.tp.Charset) - return string(enc.EncodeInternal(nil, []byte(expr))), false, nil + return expr, false, nil } type substringFunctionClass struct { @@ -2327,12 +2318,7 @@ func (b *builtinBitLengthSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull || err != nil { return 0, isNull, err } - argTp := b.args[0].GetType() - dBytes, err := charset.NewEncoding(argTp.Charset).Encode(nil, hack.Slice(val)) - if err != nil { - return 0, isNull, err - } - return int64(len(dBytes) * 8), false, nil + return int64(len(val) * 8), false, nil } type charFunctionClass struct { @@ -2421,12 +2407,15 @@ func (b *builtinCharSig) evalString(row chunk.Row) (string, bool, error) { } dBytes := b.convertToBytes(bigints) - resultBytes, err := charset.NewEncoding(b.tp.Charset).Decode(nil, dBytes) + enc := charset.FindEncoding(b.tp.Charset) + res, err := enc.Transform(nil, dBytes, charset.OpDecode) if err != nil { b.ctx.GetSessionVars().StmtCtx.AppendWarning(err) - return "", true, nil + if b.ctx.GetSessionVars().StrictSQLMode { + return "", true, nil + } } - return string(resultBytes), false, nil + return string(res), false, nil } type charLengthFunctionClass struct { @@ -2893,43 +2882,19 @@ func (b *builtinOrdSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull, err } - charSet := b.args[0].GetType().Charset - ord, err := chooseOrdFunc(charSet) + strBytes := hack.Slice(str) + enc := charset.FindEncoding(b.args[0].GetType().Charset) + w := len(charset.EncodingUTF8Impl.Peek(strBytes)) + res, err := enc.Transform(nil, strBytes[:w], charset.OpEncode) if err != nil { - return 0, false, err - } - - enc := charset.NewEncoding(charSet) - leftMost, err := enc.EncodeFirstChar(nil, hack.Slice(str)) - if err != nil { - return 0, false, err - } - return ord(leftMost), false, nil -} - -func chooseOrdFunc(charSet string) (func([]byte) int64, error) { - // use utf8 by default - if charSet == "" { - charSet = charset.CharsetUTF8 - } - desc, err := charset.GetCharsetInfo(charSet) - if err != nil { - return nil, err - } - if desc.Maxlen == 1 { - return ordSingleByte, nil - } - return ordOthers, nil -} - -func ordSingleByte(src []byte) int64 { - if 
len(src) == 0 { - return 0 + // Fallback to the first byte. + return calcOrd(strBytes[:1]), false, nil } - return int64(src[0]) + // Only the first character is considered. + return calcOrd(res[:len(enc.Peek(res))]), false, nil } -func ordOthers(leftMost []byte) int64 { +func calcOrd(leftMost []byte) int64 { var result int64 var factor int64 = 1 for i := len(leftMost) - 1; i >= 0; i-- { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 8aeea63add9e8..5cd3dbbb2757c 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -920,6 +920,7 @@ func TestConvert(t *testing.T) { wrongFunction := f.(*builtinConvertSig) wrongFunction.tp.Charset = "wrongcharset" _, err = evalBuiltinFunc(wrongFunction, chunk.Row{}) + require.Error(t, err) require.Equal(t, "[expression:1115]Unknown character set: 'wrongcharset'", err.Error()) } @@ -1402,14 +1403,10 @@ func TestBitLength(t *testing.T) { } func TestChar(t *testing.T) { + collate.SetCharsetFeatEnabledForTest(true) + defer collate.SetCharsetFeatEnabledForTest(false) ctx := createContext(t) - stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate - stmtCtx.IgnoreTruncate = true - defer func() { - stmtCtx.IgnoreTruncate = origin - }() - + ctx.GetSessionVars().StmtCtx.IgnoreTruncate = true tbl := []struct { str string iNum int64 @@ -1418,30 +1415,36 @@ func TestChar(t *testing.T) { result interface{} warnings int }{ - {"65", 66, 67.5, "utf8", "ABD", 0}, // float - {"65", 16740, 67.5, "utf8", "AAdD", 0}, // large num - {"65", -1, 67.5, nil, "A\xff\xff\xff\xffD", 0}, // nagtive int - {"a", -1, 67.5, nil, "\x00\xff\xff\xff\xffD", 0}, // invalid 'a' - // TODO: Uncomment it when issue #29685 be closed - // {"65", -1, 67.5, "utf8", nil, 1}, // with utf8, return nil - // {"a", -1, 67.5, "utf8", nil, 2}, // with utf8, return nil - // TODO: Uncomment it when gbk be added into charsetInfos - // {"1234567", 1234567, 1234567, "gbk", "謬謬謬", 0}, // test char for gbk - // {"123456789", 123456789, 123456789, "gbk", nil, 3}, // invalid 123456789 in gbk - } - for _, v := range tbl { + {"65", 66, 67.5, "utf8", "ABD", 0}, // float + {"65", 16740, 67.5, "utf8", "AAdD", 0}, // large num + {"65", -1, 67.5, nil, "A\xff\xff\xff\xffD", 0}, // negative int + {"a", -1, 67.5, nil, "\x00\xff\xff\xff\xffD", 0}, // invalid 'a' + {"65", -1, 67.5, "utf8", nil, 1}, // with utf8, return nil + {"a", -1, 67.5, "utf8", nil, 1}, // with utf8, return nil + {"1234567", 1234567, 1234567, "gbk", "\u0012謬\u0012謬\u0012謬", 0}, // test char for gbk + {"123456789", 123456789, 123456789, "gbk", nil, 1}, // invalid 123456789 in gbk + } + run := func(i int, result interface{}, warnCnt int, dts ...interface{}) { fc := funcs[ast.CharFunc] - f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(v.str, v.iNum, v.fNum, v.charset))) - require.NoError(t, err) - require.NotNil(t, f) + f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(dts...))) + require.NoError(t, err, i) + require.NotNil(t, f, i) r, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) - trequire.DatumEqual(t, types.NewDatum(v.result), r) - if v.warnings != 0 { - warnings := ctx.GetSessionVars().StmtCtx.GetWarnings() - require.Equal(t, v.warnings, len(warnings)) + require.NoError(t, err, i) + trequire.DatumEqual(t, types.NewDatum(result), r, i) + if warnCnt != 0 { + warnings := ctx.GetSessionVars().StmtCtx.TruncateWarnings(0) + require.Equal(t, warnCnt, len(warnings), fmt.Sprintf("%d: %v", i, warnings)) } } + for i, v := 
range tbl { + run(i, v.result, v.warnings, v.str, v.iNum, v.fNum, v.charset) + } + // char() returns null only when the sql_mode is strict. + ctx.GetSessionVars().StrictSQLMode = true + run(-1, nil, 1, 123456, "utf8") + ctx.GetSessionVars().StrictSQLMode = false + run(-2, string([]byte{1}), 1, 123456, "utf8") } func TestCharLength(t *testing.T) { @@ -2205,11 +2208,11 @@ func TestOrd(t *testing.T) { {2.3, 50, "", false, false}, {nil, 0, "", true, false}, {"", 0, "", false, false}, - {"你好", 14990752, "", false, false}, - {"にほん", 14909867, "", false, false}, - {"한국", 15570332, "", false, false}, - {"👍", 4036989325, "", false, false}, - {"א", 55184, "", false, false}, + {"你好", 14990752, "utf8mb4", false, false}, + {"にほん", 14909867, "utf8mb4", false, false}, + {"한국", 15570332, "utf8mb4", false, false}, + {"👍", 4036989325, "utf8mb4", false, false}, + {"א", 55184, "utf8mb4", false, false}, {"abc", 97, "gbk", false, false}, {"一二三", 53947, "gbk", false, false}, {"àáèé", 43172, "gbk", false, false}, diff --git a/expression/builtin_string_vec.go b/expression/builtin_string_vec.go index 62c123faf07bd..3da555f9319ed 100644 --- a/expression/builtin_string_vec.go +++ b/expression/builtin_string_vec.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" - "golang.org/x/text/transform" ) func (b *builtinLowerSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { @@ -46,11 +45,10 @@ func (b *builtinLowerUTF8Sig) vecEvalString(input *chunk.Chunk, result *chunk.Co if err := b.args[0].VecEvalString(b.ctx, input, result); err != nil { return err } - + enc := charset.FindEncoding(b.args[0].GetType().Charset) for i := 0; i < input.NumRows(); i++ { - result.SetRaw(i, []byte(b.encoding.ToLower(result.GetString(i)))) + result.SetRaw(i, []byte(enc.ToLower(result.GetString(i)))) } - return nil } @@ -146,9 +144,9 @@ func (b *builtinUpperUTF8Sig) vecEvalString(input *chunk.Chunk, result *chunk.Co if err := b.args[0].VecEvalString(b.ctx, input, result); err != nil { return err } - + enc := charset.FindEncoding(b.args[0].GetType().Charset) for i := 0; i < input.NumRows(); i++ { - result.SetRaw(i, []byte(b.encoding.ToUpper(result.GetString(i)))) + result.SetRaw(i, []byte(enc.ToUpper(result.GetString(i)))) } return nil } @@ -677,49 +675,59 @@ func (b *builtinConvertSig) vecEvalString(input *chunk.Chunk, result *chunk.Colu if err := b.args[0].VecEvalString(b.ctx, input, expr); err != nil { return err } - // Since charset is already validated and set from getFunction(), there's no - // need to get charset from args again. - encoding, _ := charset.Lookup(b.tp.Charset) - // However, if `b.tp.Charset` is abnormally set to a wrong charset, we still - // return with error. 
- if encoding == nil { - return errUnknownCharacterSet.GenWithStackByArgs(b.tp.Charset) + argTp, resultTp := b.args[0].GetType(), b.tp + result.ReserveString(n) + done := vecEvalStringConvertBinary(result, n, expr, argTp, resultTp) + if done { + return nil } - decoder := encoding.NewDecoder() - isBinaryStr := types.IsBinaryStr(b.args[0].GetType()) - isRetBinary := types.IsBinaryStr(b.tp) - enc := charset.NewEncoding(b.tp.Charset) - if isRetBinary { - enc = charset.NewEncoding(b.args[0].GetType().Charset) + enc := charset.FindEncoding(resultTp.Charset) + var encBuf []byte + for i := 0; i < n; i++ { + if expr.IsNull(i) { + result.AppendNull() + continue + } + exprI := expr.GetBytes(i) + if !charset.IsValid(enc, exprI) { + encBuf, _ = enc.Transform(encBuf, exprI, charset.OpReplace) + result.AppendBytes(encBuf) + } else { + result.AppendBytes(exprI) + } } + return nil +} - result.ReserveString(n) +func vecEvalStringConvertBinary(result *chunk.Column, n int, expr *chunk.Column, + argTp, resultTp *types.FieldType) (done bool) { + var chs string + var op charset.Op + if types.IsBinaryStr(argTp) { + chs = resultTp.Charset + op = charset.OpDecode + } else if types.IsBinaryStr(resultTp) { + chs = argTp.Charset + op = charset.OpEncode + } else { + return false + } + enc := charset.FindEncoding(chs) + var encBuf []byte for i := 0; i < n; i++ { if expr.IsNull(i) { result.AppendNull() continue } - exprI := expr.GetString(i) - if isBinaryStr { - target, _, err := transform.String(decoder, exprI) - if err != nil { - result.AppendNull() - continue - } - result.AppendString(target) + encBuf, err := enc.Transform(encBuf, expr.GetBytes(i), op) + if err != nil { + result.AppendNull() } else { - if isRetBinary { - str, err := enc.EncodeString(exprI) - if err != nil { - return err - } - result.AppendString(str) - continue - } - result.AppendString(string(enc.EncodeInternal(nil, []byte(exprI)))) + result.AppendBytes(encBuf) } + continue } - return nil + return true } func (b *builtinSubstringIndexSig) vectorized() bool { @@ -2068,15 +2076,9 @@ func (b *builtinOrdSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) err return err } - charSet := b.args[0].GetType().Charset - ord, err := chooseOrdFunc(charSet) - if err != nil { - return err - } - - enc := charset.NewEncoding(charSet) - var encodedBuf []byte - + enc := charset.FindEncoding(b.args[0].GetType().Charset) + var x [4]byte + encBuf := x[:] result.ResizeInt64(n, false) result.MergeNulls(buf) i64s := result.Int64s() @@ -2084,12 +2086,15 @@ func (b *builtinOrdSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) err if result.IsNull(i) { continue } - str := buf.GetBytes(i) - encoded, err := enc.EncodeFirstChar(encodedBuf, str) + strBytes := buf.GetBytes(i) + w := len(charset.EncodingUTF8Impl.Peek(strBytes)) + encBuf, err = enc.Transform(encBuf, strBytes[:w], charset.OpEncode) if err != nil { - return err + i64s[i] = calcOrd(strBytes[:1]) + continue } - i64s[i] = ord(encoded) + // Only the first character is considered. 
+ i64s[i] = calcOrd(encBuf[:len(enc.Peek(encBuf))]) } return nil } @@ -2231,9 +2236,6 @@ func (b *builtinBitLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum return err } - argTp := b.args[0].GetType() - enc := charset.NewEncoding(argTp.Charset) - result.ResizeInt64(n, false) result.MergeNulls(buf) i64s := result.Int64s() @@ -2242,11 +2244,7 @@ func (b *builtinBitLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum continue } str := buf.GetBytes(i) - dBytes, err := enc.Encode(nil, str) - if err != nil { - return err - } - i64s[i] = int64(len(dBytes) * 8) + i64s[i] = int64(len(str) * 8) } return nil } @@ -2282,7 +2280,8 @@ func (b *builtinCharSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) bufint[i] = buf[i].Int64s() } var resultBytes []byte - enc := charset.NewEncoding(b.tp.Charset) + enc := charset.FindEncoding(b.tp.Charset) + hasStrictMode := b.ctx.GetSessionVars().StrictSQLMode for i := 0; i < n; i++ { bigints = bigints[0:0] for j := 0; j < l-1; j++ { @@ -2292,12 +2291,13 @@ func (b *builtinCharSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) bigints = append(bigints, bufint[j][i]) } dBytes := b.convertToBytes(bigints) - - resultBytes, err := enc.Decode(resultBytes, dBytes) + resultBytes, err = enc.Transform(resultBytes, dBytes, charset.OpDecode) if err != nil { b.ctx.GetSessionVars().StmtCtx.AppendWarning(err) - result.AppendNull() - continue + if hasStrictMode { + result.AppendNull() + continue + } } result.AppendString(string(resultBytes)) } diff --git a/expression/collation.go b/expression/collation.go index 80a2720c8cfe4..8dc5df02e55e0 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -296,6 +296,7 @@ func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, e } func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) bool { + enc := charset.FindEncoding(ec.Charset) for _, arg := range args { if arg.GetType().Charset == ec.Charset { continue @@ -311,7 +312,10 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) if err != nil { return false } - if !isNull && !isValidString(str, ec.Charset) { + if isNull { + continue + } + if !charset.IsValidString(enc, str) { return false } } else { @@ -324,25 +328,6 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) return true } -// isValidString check if str can convert to dstChs charset without data loss. -func isValidString(str string, dstChs string) bool { - switch dstChs { - case charset.CharsetASCII: - return charset.StringValidatorASCII{}.Validate(str) == -1 - case charset.CharsetLatin1: - // For backward compatibility, we do not block SQL like select '啊' = convert('a' using latin1) collate latin1_bin; - return true - case charset.CharsetUTF8, charset.CharsetUTF8MB4: - // String in tidb is actually use utf8mb4 encoding. - return true - case charset.CharsetBinary: - // Convert to binary is always safe. - return true - default: - return charset.StringValidatorOther{Charset: dstChs}.Validate(str) == -1 - } -} - // inferCollation infers collation, charset, coercibility and check the legitimacy. 
func inferCollation(exprs ...Expression) *ExprCollation { if len(exprs) == 0 { diff --git a/expression/integration_test.go b/expression/integration_test.go index 79cb81daf90f6..95dc3157507ad 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -378,7 +378,6 @@ func TestConvertToBit(t *testing.T) { } func TestStringBuiltin(t *testing.T) { - t.Skip("it has been broken. Please fix it as soon as possible.") store, clean := testkit.CreateMockStore(t) defer clean() @@ -813,6 +812,25 @@ func TestStringBuiltin(t *testing.T) { "-38.04620119 38.04620115 -38.04620119,38.04620115")) } +func TestInvalidStrings(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // Test convert invalid string. + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a binary(5));") + tk.MustExec("insert into t values (0x1e240), ('ABCDE');") + tk.MustExec("set tidb_enable_vectorized_expression = on;") + tk.MustQuery("select convert(t.a using utf8) from t;").Check(testkit.Rows("", "ABCDE")) + tk.MustQuery("select convert(0x1e240 using utf8);").Check(testkit.Rows("")) + tk.MustExec("set tidb_enable_vectorized_expression = off;") + tk.MustQuery("select convert(t.a using utf8) from t;").Check(testkit.Rows("", "ABCDE")) + tk.MustQuery("select convert(0x1e240 using utf8);").Check(testkit.Rows("")) +} + func TestEncryptionBuiltin(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/parser/charset/encoding.go b/parser/charset/encoding.go index 8bd1b92c9bcf6..25257c44e440b 100644 --- a/parser/charset/encoding.go +++ b/parser/charset/encoding.go @@ -13,212 +13,131 @@ package charset -import ( - "bytes" - "fmt" - "reflect" - "strings" - "unicode" - "unsafe" - - "github.com/cznic/mathutil" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/parser/terror" - "golang.org/x/text/encoding" - "golang.org/x/text/transform" +// Make sure all of them implement Encoding interface. +var ( + _ Encoding = &encodingUTF8{} + _ Encoding = &encodingUTF8MB3Strict{} + _ Encoding = &encodingASCII{} + _ Encoding = &encodingLatin1{} + _ Encoding = &encodingBin{} + _ Encoding = &encodingGBK{} ) -var errInvalidCharacterString = terror.ClassParser.NewStd(mysql.ErrInvalidCharacterString) - -type EncodingLabel string - -// Format trim and change the label to lowercase. -func Format(label string) EncodingLabel { - return EncodingLabel(strings.ToLower(strings.Trim(label, "\t\n\r\f "))) -} - -// Formatted is used when the label is already trimmed and it is lowercase. -func Formatted(label string) EncodingLabel { - return EncodingLabel(label) -} - -// Encoding provide a interface to encode/decode a string with specific encoding. -type Encoding struct { - enc encoding.Encoding - name string - charLength func([]byte) int - specialCase unicode.SpecialCase -} - -// enabled indicates whether the non-utf8 encoding is used. -func (e *Encoding) enabled() bool { - return e != UTF8Encoding -} - -// Name returns the name of the current encoding. -func (e *Encoding) Name() string { - return e.name -} - -// CharLength returns the next character length in bytes. -func (e *Encoding) CharLength(bs []byte) int { - return e.charLength(bs) +// IsSupportedEncoding checks if the charset is fully supported. +func IsSupportedEncoding(charset string) bool { + _, ok := encodingMap[charset] + return ok } -// NewEncoding creates a new Encoding. 
-func NewEncoding(label string) *Encoding { - if len(label) == 0 { - return UTF8Encoding +// FindEncoding finds the encoding according to charset. +func FindEncoding(charset string) Encoding { + if len(charset) == 0 { + return EncodingBinImpl } - - if e, exist := encodingMap[Format(label)]; exist { + if e, exist := encodingMap[charset]; exist { return e } - return UTF8Encoding -} - -// Encode convert bytes from utf-8 charset to a specific charset. -func (e *Encoding) Encode(dest, src []byte) ([]byte, error) { - if !e.enabled() { - return src, nil - } - return e.transform(e.enc.NewEncoder(), dest, src, false) -} - -// EncodeString convert a string from utf-8 charset to a specific charset. -func (e *Encoding) EncodeString(src string) (string, error) { - if !e.enabled() { - return src, nil - } - bs, err := e.transform(e.enc.NewEncoder(), nil, Slice(src), false) - return string(bs), err -} - -// EncodeFirstChar convert first code point of bytes from utf-8 charset to a specific charset. -func (e *Encoding) EncodeFirstChar(dest, src []byte) ([]byte, error) { - srcNextLen := e.nextCharLenInSrc(src, false) - srcEnd := mathutil.Min(srcNextLen, len(src)) - if !e.enabled() { - return src[:srcEnd], nil - } - return e.transform(e.enc.NewEncoder(), dest, src[:srcEnd], false) -} - -// EncodeInternal convert bytes from utf-8 charset to a specific charset, we actually do not do the real convert, just find the inconvertible character and use ? replace. -// The code below is equivalent to -// expr, _ := e.Encode(dest, src) -// ret, _ := e.Decode(nil, expr) -// return ret -func (e *Encoding) EncodeInternal(dest, src []byte) []byte { - if !e.enabled() { - return src - } - if dest == nil { - dest = make([]byte, 0, len(src)) - } - var srcOffset int + return EncodingBinImpl +} + +var encodingMap = map[string]Encoding{ + CharsetUTF8MB4: EncodingUTF8Impl, + CharsetUTF8: EncodingUTF8Impl, + CharsetGBK: EncodingGBKImpl, + CharsetLatin1: EncodingLatin1Impl, + CharsetBin: EncodingBinImpl, + CharsetASCII: EncodingASCIIImpl, +} + +// Encoding provide encode/decode functions for a string with a specific charset. +type Encoding interface { + // Name is the name of the encoding. + Name() string + // Tp is the type of the encoding. + Tp() EncodingTp + // Peek returns the next char. + Peek(src []byte) []byte + // Foreach iterates the characters in in current encoding. + Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) + // Transform map the bytes in src to dest according to Op. + Transform(dest, src []byte, op Op) ([]byte, error) + // ToUpper change a string to uppercase. + ToUpper(src string) string + // ToLower change a string to lowercase. + ToLower(src string) string +} + +type EncodingTp int8 + +const ( + EncodingTpNone EncodingTp = iota + EncodingTpUTF8 + EncodingTpUTF8MB3Strict + EncodingTpASCII + EncodingTpLatin1 + EncodingTpBin + EncodingTpGBK +) - var buf [4]byte - transformer := e.enc.NewEncoder() - for srcOffset < len(src) { - length := UTF8Encoding.CharLength(src[srcOffset:]) - _, _, err := transformer.Transform(buf[:], src[srcOffset:srcOffset+length], true) - if err != nil { - dest = append(dest, byte('?')) - } else { - dest = append(dest, src[srcOffset:srcOffset+length]...) - } - srcOffset += length - } +// Op is used by Encoding.Transform. 
+type Op int16 + +const ( + opFromUTF8 Op = 1 << iota + opToUTF8 + opTruncateTrim + opTruncateReplace + opCollectFrom + opCollectTo + opSkipError +) - return dest -} +const ( + OpReplace = opFromUTF8 | opTruncateReplace | opCollectFrom | opSkipError + OpEncode = opFromUTF8 | opTruncateTrim | opCollectTo + OpEncodeNoErr = OpEncode | opSkipError + OpEncodeReplace = opFromUTF8 | opTruncateReplace | opCollectTo + OpDecode = opToUTF8 | opTruncateTrim | opCollectTo + OpDecodeReplace = opToUTF8 | opTruncateReplace | opCollectTo +) -// Decode convert bytes from a specific charset to utf-8 charset. -func (e *Encoding) Decode(dest, src []byte) ([]byte, error) { - if !e.enabled() { - return src, nil - } - return e.transform(e.enc.NewDecoder(), dest, src, true) +// IsValid checks whether the bytes is valid in current encoding. +func IsValid(e Encoding, src []byte) bool { + isValid := true + e.Foreach(src, opFromUTF8, func(from, to []byte, ok bool) bool { + isValid = ok + return ok + }) + return isValid } -// DecodeString convert a string from a specific charset to utf-8 charset. -func (e *Encoding) DecodeString(src string) (string, error) { - if !e.enabled() { - return src, nil - } - bs, err := e.transform(e.enc.NewDecoder(), nil, Slice(src), true) - return string(bs), err +// IsValidString is a string version of IsValid. +func IsValidString(e Encoding, str string) bool { + return IsValid(e, Slice(str)) } -func (e *Encoding) transform(transformer transform.Transformer, dest, src []byte, isDecoding bool) ([]byte, error) { - if len(dest) < len(src) { - dest = make([]byte, len(src)*2) - } - if len(src) == 0 { - return src, nil - } - var destOffset, srcOffset int - var encodingErr error - for { - srcNextLen := e.nextCharLenInSrc(src[srcOffset:], isDecoding) - srcEnd := mathutil.Min(srcOffset+srcNextLen, len(src)) - nDest, nSrc, err := transformer.Transform(dest[destOffset:], src[srcOffset:srcEnd], false) - if err == transform.ErrShortDst { - dest = enlargeCapacity(dest) - } else if err != nil || isDecoding && beginWithReplacementChar(dest[destOffset:destOffset+nDest]) { - if encodingErr == nil { - encodingErr = e.generateErr(src[srcOffset:], srcNextLen) - } - dest[destOffset] = byte('?') - nDest, nSrc = 1, srcNextLen // skip the source bytes that cannot be decoded normally. +// CountValidBytes counts the first valid bytes in src that +// can be encode to the current encoding. +func CountValidBytes(e Encoding, src []byte) int { + nSrc := 0 + e.Foreach(src, opFromUTF8, func(from, to []byte, ok bool) bool { + if ok { + nSrc += len(from) } - destOffset += nDest - srcOffset += nSrc - // The source bytes are exhausted. - if srcOffset >= len(src) { - return dest[:destOffset], encodingErr + return ok + }) + return nSrc +} + +// CountValidBytesDecode counts the first valid bytes in src that +// can be decode to utf-8. 
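Putting the flag combinations together, a GBK round trip under the new API looks roughly like the sketch below, mirroring the rewritten TestEncoding further down in this series. The output comments are expectations derived from the ops defined above, not guaranteed results.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/parser/charset"
)

func main() {
	gbk := charset.FindEncoding(charset.CharsetGBK)

	// OpEncode: UTF-8 -> GBK, trim at the first invalid character and report it.
	gbkBytes, err := gbk.Transform(nil, []byte("一二三四"), charset.OpEncode)
	fmt.Println(len(gbkBytes), err) // 8 <nil> (two GBK bytes per CJK character)

	// OpDecode: GBK -> UTF-8, the reverse direction.
	utf8Bytes, err := gbk.Transform(nil, gbkBytes, charset.OpDecode)
	fmt.Println(string(utf8Bytes), err) // 一二三四 <nil>

	// OpDecodeReplace: undecodable bytes become '?' and an error is still
	// returned, which is how the parser and server report bad input.
	replaced, err := gbk.Transform(nil, []byte{0xff, 0xff}, charset.OpDecodeReplace)
	fmt.Println(string(replaced), err != nil) // ? true
}
```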
+func CountValidBytesDecode(e Encoding, src []byte) int { + nSrc := 0 + e.Foreach(src, opToUTF8, func(from, to []byte, ok bool) bool { + if ok { + nSrc += len(from) } - } -} - -func (e *Encoding) nextCharLenInSrc(srcRest []byte, isDecoding bool) int { - if isDecoding { - if e.charLength != nil { - return e.charLength(srcRest) - } - return len(srcRest) - } - return UTF8Encoding.CharLength(srcRest) -} - -func enlargeCapacity(dest []byte) []byte { - newDest := make([]byte, len(dest)*2) - copy(newDest, dest) - return newDest -} - -func (e *Encoding) generateErr(srcRest []byte, srcNextLen int) error { - cutEnd := mathutil.Min(srcNextLen, len(srcRest)) - invalidBytes := fmt.Sprintf("%X", string(srcRest[:cutEnd])) - return errInvalidCharacterString.GenWithStackByArgs(e.name, invalidBytes) -} - -// replacementBytes are bytes for the replacement rune 0xfffd. -var replacementBytes = []byte{0xEF, 0xBF, 0xBD} - -// beginWithReplacementChar check if dst has the prefix '0xEFBFBD'. -func beginWithReplacementChar(dst []byte) bool { - return bytes.HasPrefix(dst, replacementBytes) -} - -// Slice converts string to slice without copy. -// Use at your own risk. -func Slice(s string) (b []byte) { - pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) - pBytes.Data = pString.Data - pBytes.Len = pString.Len - pBytes.Cap = pString.Len - return + return ok + }) + return nSrc } diff --git a/parser/charset/encoding_ascii.go b/parser/charset/encoding_ascii.go new file mode 100644 index 0000000000000..df5fed9c3bce2 --- /dev/null +++ b/parser/charset/encoding_ascii.go @@ -0,0 +1,71 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + go_unicode "unicode" + + "golang.org/x/text/encoding" +) + +// EncodingASCIIImpl is the instance of encodingASCII +var EncodingASCIIImpl = &encodingASCII{encodingBase{enc: encoding.Nop}} + +func init() { + EncodingASCIIImpl.self = EncodingASCIIImpl +} + +// encodingASCII is the ASCII encoding. +type encodingASCII struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingASCII) Name() string { + return CharsetASCII +} + +// Tp implements Encoding interface. +func (e *encodingASCII) Tp() EncodingTp { + return EncodingTpASCII +} + +// Peek implements Encoding interface. 
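The two counters are what later hunks in this patch (see table/column.go) use to locate the error position for truncation warnings. In isolation they behave roughly like this sketch; the counts in comments assume UTF-8 source text for the encode direction.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/parser/charset"
)

func main() {
	ascii := charset.FindEncoding(charset.CharsetASCII)
	// Counting stops at the first character that cannot be encoded:
	// 'a' and 'b' are ASCII, the following CJK character is not.
	fmt.Println(charset.CountValidBytes(ascii, []byte("ab中文"))) // 2

	gbk := charset.FindEncoding(charset.CharsetGBK)
	gbkBytes, _ := gbk.Transform(nil, []byte("中文"), charset.OpEncode)
	// All four GBK bytes decode back to UTF-8.
	fmt.Println(charset.CountValidBytesDecode(gbk, gbkBytes)) // 4
	// A trailing stray byte is not decodable, so the count is unchanged.
	fmt.Println(charset.CountValidBytesDecode(gbk, append(gbkBytes, 0xff))) // 4
}
```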
+func (e *encodingASCII) Peek(src []byte) []byte { + if len(src) == 0 { + return src + } + return src[:1] +} + +func (e *encodingASCII) Transform(dest, src []byte, op Op) ([]byte, error) { + if IsValid(e, src) { + return src, nil + } + return e.encodingBase.Transform(dest, src, op) +} + +func (e *encodingASCII) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + for i, w := 0, 0; i < len(src); i += w { + w = 1 + ok := true + if src[i] > go_unicode.MaxASCII { + w = len(EncodingUTF8Impl.Peek(src[i:])) + ok = false + } + if !fn(src[i:i+w], src[i:i+w], ok) { + return + } + } +} diff --git a/parser/charset/encoding_base.go b/parser/charset/encoding_base.go new file mode 100644 index 0000000000000..275db24c5a3d6 --- /dev/null +++ b/parser/charset/encoding_base.go @@ -0,0 +1,117 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "bytes" + "fmt" + "reflect" + "strings" + "unsafe" + + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" + "golang.org/x/text/encoding" + "golang.org/x/text/transform" +) + +var errInvalidCharacterString = terror.ClassParser.NewStd(mysql.ErrInvalidCharacterString) + +// encodingBase defines some generic functions. +type encodingBase struct { + enc encoding.Encoding + self Encoding +} + +func (b encodingBase) ToUpper(src string) string { + return strings.ToUpper(src) +} + +func (b encodingBase) ToLower(src string) string { + return strings.ToLower(src) +} + +func (b encodingBase) Transform(dest, src []byte, op Op) (result []byte, err error) { + if dest == nil { + dest = make([]byte, len(src)) + } + dest = dest[:0] + b.self.Foreach(src, op, func(from, to []byte, ok bool) bool { + if !ok { + if err == nil && (op&opSkipError == 0) { + err = generateEncodingErr(b.self.Name(), from) + } + if op&opTruncateTrim != 0 { + return false + } + if op&opTruncateReplace != 0 { + dest = append(dest, '?') + return true + } + } + if op&opCollectFrom != 0 { + dest = append(dest, from...) + } else if op&opCollectTo != 0 { + dest = append(dest, to...) + } + return true + }) + return dest, err +} + +func (b encodingBase) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + var tfm transform.Transformer + var peek func([]byte) []byte + if op&opFromUTF8 != 0 { + tfm = b.enc.NewEncoder() + peek = EncodingUTF8Impl.Peek + } else { + tfm = b.enc.NewDecoder() + peek = b.self.Peek + } + var buf [4]byte + for i, w := 0, 0; i < len(src); i += w { + w = len(peek(src[i:])) + nDst, _, err := tfm.Transform(buf[:], src[i:i+w], false) + meetErr := err != nil || (op&opToUTF8 != 0 && beginWithReplacementChar(buf[:nDst])) + if !fn(src[i:i+w], buf[:nDst], !meetErr) { + return + } + } +} + +// replacementBytes are bytes for the replacement rune 0xfffd. +var replacementBytes = []byte{0xEF, 0xBF, 0xBD} + +// beginWithReplacementChar check if dst has the prefix '0xEFBFBD'. +func beginWithReplacementChar(dst []byte) bool { + return bytes.HasPrefix(dst, replacementBytes) +} + +// generateEncodingErr generates an invalid string in charset error. 
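Because every concrete encoding routes through Foreach, small ad-hoc checks can be layered on top of the interface without adding new methods. The helper below is hypothetical (not part of this patch) and only illustrates the callback contract: it reports the byte offset of the first character that cannot be encoded from UTF-8.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/parser/charset"
)

// firstInvalidOffset returns the byte offset of the first character in src
// that cannot be encoded from UTF-8 into enc, or -1 if src is fully valid.
func firstInvalidOffset(enc charset.Encoding, src []byte) int {
	offset, invalid := 0, -1
	enc.Foreach(src, charset.OpEncode, func(from, to []byte, ok bool) bool {
		if !ok {
			invalid = offset
			return false // stop at the first failure
		}
		offset += len(from)
		return true
	})
	return invalid
}

func main() {
	ascii := charset.FindEncoding(charset.CharsetASCII)
	fmt.Println(firstInvalidOffset(ascii, []byte("qwerty"))) // -1
	fmt.Println(firstInvalidOffset(ascii, []byte("qw中ty"))) // 2
}
```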
+func generateEncodingErr(name string, invalidBytes []byte) error { + arg := fmt.Sprintf("%X", invalidBytes) + return errInvalidCharacterString.FastGenByArgs(name, arg) +} + +// Slice converts string to slice without copy. +// Use at your own risk. +func Slice(s string) (b []byte) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pBytes.Data = pString.Data + pBytes.Len = pString.Len + pBytes.Cap = pString.Len + return +} diff --git a/parser/charset/encoding_bin.go b/parser/charset/encoding_bin.go new file mode 100644 index 0000000000000..30fd87644c571 --- /dev/null +++ b/parser/charset/encoding_bin.go @@ -0,0 +1,61 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "golang.org/x/text/encoding" +) + +// EncodingBinImpl is the instance of encodingBin. +var EncodingBinImpl = &encodingBin{encodingBase{enc: encoding.Nop}} + +func init() { + EncodingBinImpl.self = EncodingBinImpl +} + +// encodingBin is the binary encoding. +type encodingBin struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingBin) Name() string { + return CharsetBin +} + +// Tp implements Encoding interface. +func (e *encodingBin) Tp() EncodingTp { + return EncodingTpBin +} + +// Peek implements Encoding interface. +func (e *encodingBin) Peek(src []byte) []byte { + if len(src) == 0 { + return src + } + return src[:1] +} + +// Foreach implements Encoding interface. +func (e *encodingBin) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + for i := 0; i < len(src); i++ { + if !fn(src[i:i+1], src[i:i+1], true) { + return + } + } +} + +func (e *encodingBin) Transform(dest, src []byte, op Op) ([]byte, error) { + return src, nil +} diff --git a/parser/charset/encoding_gbk.go b/parser/charset/encoding_gbk.go new file mode 100644 index 0000000000000..3dc3fe14fed6c --- /dev/null +++ b/parser/charset/encoding_gbk.go @@ -0,0 +1,93 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "strings" + "unicode" + + "golang.org/x/text/encoding/simplifiedchinese" +) + +// EncodingGBKImpl is the instance of encodingGBK +var EncodingGBKImpl = &encodingGBK{encodingBase{enc: simplifiedchinese.GBK}} + +func init() { + EncodingGBKImpl.self = EncodingGBKImpl +} + +// encodingGBK is GBK encoding. +type encodingGBK struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingGBK) Name() string { + return CharsetGBK +} + +// Tp implements Encoding interface. 
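For completeness, the binary encoding defined above is a pure pass-through: Transform hands the source bytes back untouched in either direction, which is what lets raw binary data skip charset conversion entirely. A minimal sketch of that behavior, with expectations in comments:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/pingcap/tidb/parser/charset"
)

func main() {
	bin := charset.FindEncoding(charset.CharsetBin)
	src := []byte{0x00, 0xff, 0x80, 'a'}

	// Neither direction changes the bytes and no error is reported,
	// even though 0xff and 0x80 are not valid UTF-8.
	out, err := bin.Transform(nil, src, charset.OpDecode)
	fmt.Println(bytes.Equal(out, src), err) // true <nil>

	out, err = bin.Transform(nil, src, charset.OpEncode)
	fmt.Println(bytes.Equal(out, src), err) // true <nil>
}
```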
+func (e *encodingGBK) Tp() EncodingTp { + return EncodingTpGBK +} + +// Peek implements Encoding interface. +func (e *encodingGBK) Peek(src []byte) []byte { + charLen := 2 + if len(src) == 0 || src[0] < 0x80 { + // A byte in the range 00–7F is a single byte that means the same thing as it does in ASCII. + charLen = 1 + } + if charLen < len(src) { + return src[:charLen] + } + return src +} + +// ToUpper implements Encoding interface. +func (e *encodingGBK) ToUpper(d string) string { + return strings.ToUpperSpecial(GBKCase, d) +} + +// ToLower implements Encoding interface. +func (e *encodingGBK) ToLower(d string) string { + return strings.ToLowerSpecial(GBKCase, d) +} + +// GBKCase follows https://dev.mysql.com/worklog/task/?id=4583. +var GBKCase = unicode.SpecialCase{ + unicode.CaseRange{Lo: 0x00E0, Hi: 0x00E1, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00E8, Hi: 0x00EA, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00EC, Hi: 0x00ED, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00F2, Hi: 0x00F3, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00F9, Hi: 0x00FA, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00FC, Hi: 0x00FC, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0101, Hi: 0x0101, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0113, Hi: 0x0113, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x011B, Hi: 0x011B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x012B, Hi: 0x012B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0144, Hi: 0x0144, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0148, Hi: 0x0148, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x014D, Hi: 0x014D, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x016B, Hi: 0x016B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01CE, Hi: 0x01CE, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D0, Hi: 0x01D0, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D2, Hi: 0x01D2, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D4, Hi: 0x01D4, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D6, Hi: 0x01D6, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D8, Hi: 0x01D8, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01DA, Hi: 0x01DA, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01DC, Hi: 0x01DC, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x216A, Hi: 0x216B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, +} diff --git a/parser/charset/encoding_latin1.go b/parser/charset/encoding_latin1.go new file mode 100644 index 0000000000000..1d2992b87642d --- /dev/null +++ b/parser/charset/encoding_latin1.go @@ -0,0 +1,51 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
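The GBKCase table above pins those code points in place (every delta is zero), so case conversion under GBK deliberately diverges from plain Unicode casing exactly there. A quick sketch of the visible effect; the outputs are expectations based on the special-case table, not verified test output.

```go
package main

import (
	"fmt"
	"strings"

	"github.com/pingcap/tidb/parser/charset"
)

func main() {
	gbk := charset.FindEncoding(charset.CharsetGBK)

	// 'à' (U+00E0) is pinned by GBKCase, so it survives upper-casing
	// under GBK, unlike the default Unicode mapping.
	fmt.Println(gbk.ToUpper("àbc"))     // àBC
	fmt.Println(strings.ToUpper("àbc")) // ÀBC

	// Runes outside the table still follow the normal mapping.
	fmt.Println(gbk.ToLower("ABC")) // abc
}
```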
+ +package charset + +import "golang.org/x/text/encoding" + +// EncodingLatin1Impl is the instance of encodingLatin1. +// TiDB uses utf8 implementation for latin1 charset because of the backward compatibility. +var EncodingLatin1Impl = &encodingLatin1{encodingUTF8{encodingBase{enc: encoding.Nop}}} + +func init() { + EncodingLatin1Impl.self = EncodingLatin1Impl +} + +// encodingLatin1 compatibles with latin1 in old version TiDB. +type encodingLatin1 struct { + encodingUTF8 +} + +// Name implements Encoding interface. +func (e *encodingLatin1) Name() string { + return CharsetLatin1 +} + +// Peek implements Encoding interface. +func (e *encodingLatin1) Peek(src []byte) []byte { + if len(src) == 0 { + return src + } + return src[:1] +} + +// Tp implements Encoding interface. +func (e *encodingLatin1) Tp() EncodingTp { + return EncodingTpLatin1 +} + +func (e *encodingLatin1) Transform(dest, src []byte, op Op) ([]byte, error) { + return src, nil +} diff --git a/parser/charset/encoding_table.go b/parser/charset/encoding_table.go index 2de9d957d923a..2780272296acb 100644 --- a/parser/charset/encoding_table.go +++ b/parser/charset/encoding_table.go @@ -14,11 +14,6 @@ package charset import ( - "strings" - go_unicode "unicode" - "unicode/utf8" - - "github.com/cznic/mathutil" "golang.org/x/text/encoding" "golang.org/x/text/encoding/charmap" "golang.org/x/text/encoding/japanese" @@ -26,28 +21,20 @@ import ( "golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/encoding/traditionalchinese" "golang.org/x/text/encoding/unicode" + "strings" ) -var encodingMap = map[EncodingLabel]*Encoding{ - CharsetUTF8MB4: UTF8Encoding, - CharsetUTF8: UTF8Encoding, - CharsetGBK: GBKEncoding, - CharsetLatin1: LatinEncoding, - CharsetBin: BinaryEncoding, - CharsetASCII: ASCIIEncoding, -} - // Lookup returns the encoding with the specified label, and its canonical // name. It returns nil and the empty string if label is not one of the // standard encodings for HTML. Matching is case-insensitive and ignores // leading and trailing whitespace. func Lookup(label string) (e encoding.Encoding, name string) { label = strings.ToLower(strings.Trim(label, "\t\n\r\f ")) - return lookup(Formatted(label)) + return lookup(label) } -func lookup(label EncodingLabel) (e encoding.Encoding, name string) { - enc := encodings[string(label)] +func lookup(label string) (e encoding.Encoding, name string) { + enc := encodings[label] return enc.e, enc.name } @@ -274,179 +261,3 @@ var encodings = map[string]struct { "utf-16le": {unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), "utf-16le"}, "x-user-defined": {charmap.XUserDefined, "x-user-defined"}, } - -// TruncateStrategy indicates the way to handle the invalid strings in specific charset. -// - TruncateStrategyEmpty: returns an empty string. -// - TruncateStrategyTrim: returns the valid prefix part of string. -// - TruncateStrategyReplace: returns the whole string, but the invalid characters are replaced with '?'. -type TruncateStrategy int8 - -const ( - TruncateStrategyEmpty TruncateStrategy = iota - TruncateStrategyTrim - TruncateStrategyReplace -) - -var _ StringValidator = StringValidatorASCII{} -var _ StringValidator = StringValidatorUTF8{} -var _ StringValidator = StringValidatorOther{} - -// StringValidator is used to check if a string is valid in the specific charset. 
-type StringValidator interface { - Validate(str string) (invalidPos int) - Truncate(str string, strategy TruncateStrategy) (result string, invalidPos int) -} - -// StringValidatorASCII checks whether a string is valid ASCII string. -type StringValidatorASCII struct{} - -// Validate checks whether the string is valid in the given charset. -func (s StringValidatorASCII) Validate(str string) int { - _, invalidPos := s.Truncate(str, TruncateStrategyEmpty) - return invalidPos -} - -// Truncate implement the interface StringValidator. -func (s StringValidatorASCII) Truncate(str string, strategy TruncateStrategy) (string, int) { - invalidPos := -1 - for i := 0; i < len(str); i++ { - if str[i] > go_unicode.MaxASCII { - invalidPos = i - break - } - } - if invalidPos == -1 { - // Quick check passed. - return str, -1 - } - switch strategy { - case TruncateStrategyEmpty: - return "", invalidPos - case TruncateStrategyTrim: - return str[:invalidPos], invalidPos - case TruncateStrategyReplace: - result := make([]byte, 0, len(str)) - for i, w := 0, 0; i < len(str); i += w { - w = 1 - if str[i] > go_unicode.MaxASCII { - w = UTF8Encoding.CharLength(Slice(str)[i:]) - w = mathutil.Min(w, len(str)-i) - result = append(result, '?') - continue - } - result = append(result, str[i:i+w]...) - } - return string(result), invalidPos - } - return str, -1 -} - -// StringValidatorUTF8 checks whether a string is valid UTF8 string. -type StringValidatorUTF8 struct { - IsUTF8MB4 bool // Distinguish between "utf8" and "utf8mb4" - CheckMB4ValueInUTF8 bool -} - -// Validate checks whether the string is valid in the given charset. -func (s StringValidatorUTF8) Validate(str string) int { - _, invalidPos := s.Truncate(str, TruncateStrategyEmpty) - return invalidPos -} - -// Truncate implement the interface StringValidator. -func (s StringValidatorUTF8) Truncate(str string, strategy TruncateStrategy) (string, int) { - if str == "" { - return str, -1 - } - if s.IsUTF8MB4 && utf8.ValidString(str) { - // Quick check passed. - return str, -1 - } - doMB4CharCheck := !s.IsUTF8MB4 && s.CheckMB4ValueInUTF8 - var result []byte - if strategy == TruncateStrategyReplace { - result = make([]byte, 0, len(str)) - } - invalidPos := -1 - for i, w := 0, 0; i < len(str); i += w { - var rv rune - rv, w = utf8.DecodeRuneInString(str[i:]) - if (rv == utf8.RuneError && w == 1) || (w > 3 && doMB4CharCheck) { - if invalidPos == -1 { - invalidPos = i - } - switch strategy { - case TruncateStrategyEmpty: - return "", invalidPos - case TruncateStrategyTrim: - return str[:i], invalidPos - case TruncateStrategyReplace: - result = append(result, '?') - continue - } - } - if strategy == TruncateStrategyReplace { - result = append(result, str[i:i+w]...) - } - } - if strategy == TruncateStrategyReplace { - return string(result), invalidPos - } - return str, -1 -} - -// StringValidatorOther checks whether a string is valid string in given charset. -type StringValidatorOther struct { - Charset string -} - -// Validate checks whether the string is valid in the given charset. -func (s StringValidatorOther) Validate(str string) int { - _, invalidPos := s.Truncate(str, TruncateStrategyEmpty) - return invalidPos -} - -// Truncate implement the interface StringValidator. 
-func (s StringValidatorOther) Truncate(str string, strategy TruncateStrategy) (string, int) { - if str == "" { - return str, -1 - } - enc := NewEncoding(s.Charset) - if !enc.enabled() { - return str, -1 - } - var result []byte - if strategy == TruncateStrategyReplace { - result = make([]byte, 0, len(str)) - } - var buf [4]byte - strBytes := Slice(str) - transformer := enc.enc.NewEncoder() - invalidPos := -1 - for i, w := 0, 0; i < len(str); i += w { - w = UTF8Encoding.CharLength(strBytes[i:]) - w = mathutil.Min(w, len(str)-i) - _, _, err := transformer.Transform(buf[:], strBytes[i:i+w], true) - if err != nil { - if invalidPos == -1 { - invalidPos = i - } - switch strategy { - case TruncateStrategyEmpty: - return "", invalidPos - case TruncateStrategyTrim: - return str[:i], invalidPos - case TruncateStrategyReplace: - result = append(result, '?') - continue - } - } - if strategy == TruncateStrategyReplace { - result = append(result, strBytes[i:i+w]...) - } - } - if strategy == TruncateStrategyReplace { - return string(result), invalidPos - } - return str, -1 -} diff --git a/parser/charset/encoding_test.go b/parser/charset/encoding_test.go index 51f5b53b3e2fd..a78aa640d8be5 100644 --- a/parser/charset/encoding_test.go +++ b/parser/charset/encoding_test.go @@ -24,21 +24,21 @@ import ( ) func TestEncoding(t *testing.T) { - enc := charset.NewEncoding(charset.CharsetGBK) + enc := charset.FindEncoding(charset.CharsetGBK) require.Equal(t, charset.CharsetGBK, enc.Name()) txt := []byte("一二三四") e, _ := charset.Lookup("gbk") gbkEncodedTxt, _, err := transform.Bytes(e.NewEncoder(), txt) require.NoError(t, err) - result, err := enc.Decode(nil, gbkEncodedTxt) + result, err := enc.Transform(nil, gbkEncodedTxt, charset.OpDecode) require.NoError(t, err) require.Equal(t, txt, result) - gbkEncodedTxt2, err := enc.Encode(nil, txt) + gbkEncodedTxt2, err := enc.Transform(nil, txt, charset.OpEncode) require.NoError(t, err) require.Equal(t, gbkEncodedTxt2, gbkEncodedTxt) - result, err = enc.Decode(nil, gbkEncodedTxt2) + result, err = enc.Transform(nil, gbkEncodedTxt2, charset.OpDecode) require.NoError(t, err) require.Equal(t, txt, result) @@ -58,7 +58,7 @@ func TestEncoding(t *testing.T) { } for _, tc := range GBKCases { cmt := fmt.Sprintf("%v", tc) - result, err = enc.Decode(nil, []byte(tc.utf8Str)) + result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpDecodeReplace) if tc.isValid { require.NoError(t, err, cmt) } else { @@ -78,7 +78,7 @@ func TestEncoding(t *testing.T) { } for _, tc := range utf8Cases { cmt := fmt.Sprintf("%v", tc) - result, err = enc.Encode(nil, []byte(tc.utf8Str)) + result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpEncodeReplace) if tc.isValid { require.NoError(t, err, cmt) } else { @@ -88,111 +88,54 @@ func TestEncoding(t *testing.T) { } } -func TestStringValidatorASCII(t *testing.T) { - v := charset.StringValidatorASCII{} - testCases := []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int - }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"qwerty", charset.TruncateStrategyEmpty, "qwerty", -1}, - {"qwÊrty", charset.TruncateStrategyEmpty, "", 2}, - {"qwÊrty", charset.TruncateStrategyTrim, "qw", 2}, - {"qwÊrty", charset.TruncateStrategyReplace, "qw?rty", 2}, - {"中文", charset.TruncateStrategyEmpty, "", 0}, - {"中文?qwert", charset.TruncateStrategyTrim, "", 0}, - {"中文?qwert", charset.TruncateStrategyReplace, "???qwert", 0}, - } - for _, tc := range testCases { - msg := fmt.Sprintf("%v", tc) - actual, invalidPos := 
v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) - } - require.Equal(t, -1, v.Validate("qwerty")) - require.Equal(t, 2, v.Validate("qwÊrty")) - require.Equal(t, 0, v.Validate("中文")) -} - -func TestStringValidatorUTF8(t *testing.T) { - // Test charset "utf8mb4". - v := charset.StringValidatorUTF8{IsUTF8MB4: true} +func TestEncodingValidate(t *testing.T) { oxfffefd := string([]byte{0xff, 0xfe, 0xfd}) testCases := []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int - }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"qwerty", charset.TruncateStrategyEmpty, "qwerty", -1}, - {"qwÊrty", charset.TruncateStrategyEmpty, "qwÊrty", -1}, - {"qwÊ合法字符串", charset.TruncateStrategyEmpty, "qwÊ合法字符串", -1}, - {"😂", charset.TruncateStrategyEmpty, "😂", -1}, - {oxfffefd, charset.TruncateStrategyEmpty, "", 0}, - {oxfffefd, charset.TruncateStrategyReplace, "???", 0}, - {"中文" + oxfffefd, charset.TruncateStrategyTrim, "中文", 6}, - {"中文" + oxfffefd, charset.TruncateStrategyReplace, "中文???", 6}, - {string(utf8.RuneError), charset.TruncateStrategyEmpty, "�", -1}, - } - for _, tc := range testCases { - msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) - } - // Test charset "utf8" with checking mb4 value. - v = charset.StringValidatorUTF8{IsUTF8MB4: false, CheckMB4ValueInUTF8: true} - testCases = []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int - }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"qwerty", charset.TruncateStrategyEmpty, "qwerty", -1}, - {"qwÊrty", charset.TruncateStrategyEmpty, "qwÊrty", -1}, - {"qwÊ合法字符串", charset.TruncateStrategyEmpty, "qwÊ合法字符串", -1}, - {"😂", charset.TruncateStrategyEmpty, "", 0}, - {"😂", charset.TruncateStrategyReplace, "?", 0}, - {"valid_str😂", charset.TruncateStrategyReplace, "valid_str?", 9}, - {oxfffefd, charset.TruncateStrategyEmpty, "", 0}, - {oxfffefd, charset.TruncateStrategyReplace, "???", 0}, - {"中文" + oxfffefd, charset.TruncateStrategyTrim, "中文", 6}, - {"中文" + oxfffefd, charset.TruncateStrategyReplace, "中文???", 6}, - {string(utf8.RuneError), charset.TruncateStrategyEmpty, "�", -1}, - } - for _, tc := range testCases { - msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) - } -} - -func TestStringValidatorGBK(t *testing.T) { - v := charset.StringValidatorOther{Charset: "gbk"} - testCases := []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int + chs string + str string + expected string + nSrc int + ok bool }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"asdf", charset.TruncateStrategyEmpty, "asdf", -1}, - {"中文", charset.TruncateStrategyEmpty, "中文", -1}, - {"À", charset.TruncateStrategyEmpty, "", 0}, - {"À", charset.TruncateStrategyReplace, "?", 0}, - {"中文À中文", charset.TruncateStrategyTrim, "中文", 6}, - {"中文À中文", charset.TruncateStrategyReplace, "中文?中文", 6}, - {"asdfÀ", charset.TruncateStrategyReplace, "asdf?", 4}, + {charset.CharsetASCII, "", "", 0, true}, + {charset.CharsetASCII, "qwerty", "qwerty", 6, true}, + {charset.CharsetASCII, "qwÊrty", "qw?rty", 2, false}, + {charset.CharsetASCII, "中文", "??", 0, false}, + {charset.CharsetASCII, "中文?qwert", "???qwert", 0, false}, + {charset.CharsetUTF8MB4, "", 
"", 0, true}, + {charset.CharsetUTF8MB4, "qwerty", "qwerty", 6, true}, + {charset.CharsetUTF8MB4, "qwÊrty", "qwÊrty", 7, true}, + {charset.CharsetUTF8MB4, "qwÊ合法字符串", "qwÊ合法字符串", 19, true}, + {charset.CharsetUTF8MB4, "😂", "😂", 4, true}, + {charset.CharsetUTF8MB4, oxfffefd, "???", 0, false}, + {charset.CharsetUTF8MB4, "中文" + oxfffefd, "中文???", 6, false}, + {charset.CharsetUTF8MB4, string(utf8.RuneError), "�", 3, true}, + {charset.CharsetUTF8, "", "", 0, true}, + {charset.CharsetUTF8, "qwerty", "qwerty", 6, true}, + {charset.CharsetUTF8, "qwÊrty", "qwÊrty", 7, true}, + {charset.CharsetUTF8, "qwÊ合法字符串", "qwÊ合法字符串", 19, true}, + {charset.CharsetUTF8, "😂", "?", 0, false}, + {charset.CharsetUTF8, "valid_str😂", "valid_str?", 9, false}, + {charset.CharsetUTF8, oxfffefd, "???", 0, false}, + {charset.CharsetUTF8, "中文" + oxfffefd, "中文???", 6, false}, + {charset.CharsetUTF8, string(utf8.RuneError), "�", 3, true}, + {charset.CharsetGBK, "", "", 0, true}, + {charset.CharsetGBK, "asdf", "asdf", 4, true}, + {charset.CharsetGBK, "中文", "中文", 6, true}, + {charset.CharsetGBK, "À", "?", 0, false}, + {charset.CharsetGBK, "中文À中文", "中文?中文", 6, false}, + {charset.CharsetGBK, "asdfÀ", "asdf?", 4, false}, } for _, tc := range testCases { msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) + enc := charset.FindEncoding(tc.chs) + if tc.chs == charset.CharsetUTF8 { + enc = charset.EncodingUTF8MB3StrictImpl + } + strBytes := []byte(tc.str) + ok := charset.IsValid(enc, strBytes) + require.Equal(t, tc.ok, ok, msg) + replace, _ := enc.Transform(nil, strBytes, charset.OpReplace) + require.Equal(t, tc.expected, string(replace), msg) } } diff --git a/parser/charset/encoding_utf8.go b/parser/charset/encoding_utf8.go new file mode 100644 index 0000000000000..871a5e5ec33c1 --- /dev/null +++ b/parser/charset/encoding_utf8.go @@ -0,0 +1,114 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "unicode/utf8" + + "golang.org/x/text/encoding" +) + +// EncodingUTF8Impl is the instance of encodingUTF8. +var EncodingUTF8Impl = &encodingUTF8{encodingBase{enc: encoding.Nop}} + +// EncodingUTF8MB3StrictImpl is the instance of encodingUTF8MB3Strict. +var EncodingUTF8MB3StrictImpl = &encodingUTF8MB3Strict{ + encodingUTF8{ + encodingBase{ + enc: encoding.Nop, + }, + }, +} + +func init() { + EncodingUTF8Impl.self = EncodingUTF8Impl + EncodingUTF8MB3StrictImpl.self = EncodingUTF8MB3StrictImpl +} + +// encodingUTF8 is TiDB's default encoding. +type encodingUTF8 struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingUTF8) Name() string { + return CharsetUTF8MB4 +} + +// Tp implements Encoding interface. +func (e *encodingUTF8) Tp() EncodingTp { + return EncodingTpUTF8 +} + +// Peek implements Encoding interface. 
+func (e *encodingUTF8) Peek(src []byte) []byte { + nextLen := 4 + if len(src) == 0 || src[0] < 0x80 { + nextLen = 1 + } else if src[0] < 0xe0 { + nextLen = 2 + } else if src[0] < 0xf0 { + nextLen = 3 + } + if len(src) < nextLen { + return src + } + return src[:nextLen] +} + +// Transform implements Encoding interface. +func (e *encodingUTF8) Transform(dest, src []byte, op Op) ([]byte, error) { + if IsValid(e, src) { + return src, nil + } + return e.encodingBase.Transform(dest, src, op) +} + +// Foreach implements Encoding interface. +func (e *encodingUTF8) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + var rv rune + for i, w := 0, 0; i < len(src); i += w { + rv, w = utf8.DecodeRune(src[i:]) + meetErr := rv == utf8.RuneError && w == 1 + if !fn(src[i:i+w], src[i:i+w], !meetErr) { + return + } + } +} + +// encodingUTF8MB3Strict is the strict mode of EncodingUTF8MB3. +// MB4 characters are considered invalid. +type encodingUTF8MB3Strict struct { + encodingUTF8 +} + +// Foreach implements Encoding interface. +func (e *encodingUTF8MB3Strict) Foreach(src []byte, op Op, fn func(srcCh, dstCh []byte, ok bool) bool) { + for i, w := 0, 0; i < len(src); i += w { + var rv rune + rv, w = utf8.DecodeRune(src[i:]) + meetErr := (rv == utf8.RuneError && w == 1) || w > 3 + if !fn(src[i:i+w], src[i:i+w], !meetErr) { + return + } + } +} + +// Transform implements Encoding interface. +func (e *encodingUTF8MB3Strict) Transform(dest, src []byte, op Op) ([]byte, error) { + if IsValid(e, src) { + return src, nil + } + return e.encodingBase.Transform(dest, src, op) +} diff --git a/parser/charset/gbk.go b/parser/charset/gbk.go deleted file mode 100644 index 5686c6e1b50f0..0000000000000 --- a/parser/charset/gbk.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import "golang.org/x/text/encoding/simplifiedchinese" - -var GBKEncoding = &Encoding{ - enc: simplifiedchinese.GBK, - name: CharsetGBK, - charLength: func(bs []byte) int { - if len(bs) == 0 || bs[0] < 0x80 { - // A byte in the range 00–7F is a single byte that means the same thing as it does in ASCII. - return 1 - } - return 2 - }, - specialCase: GBKCase, -} diff --git a/parser/charset/latin.go b/parser/charset/latin.go deleted file mode 100644 index 04de80d250aef..0000000000000 --- a/parser/charset/latin.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. 
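Peek reports the bytes of the next character under the encoding (falling back to whatever remains when the input is truncated), which is how the base Foreach advances. The sketch below walks a string with Peek and then shows the difference between the default utf8mb4-style encoding and the strict utf8mb3 variant on a 4-byte code point, matching the `utf8` rows in the test table earlier in this patch.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/parser/charset"
)

func main() {
	src := []byte("a中😂")

	// Peek gives the byte width of the next character.
	for i := 0; i < len(src); {
		ch := charset.EncodingUTF8Impl.Peek(src[i:])
		fmt.Printf("%q -> %d byte(s)\n", string(ch), len(ch)) // 1, 3 and 4 bytes
		i += len(ch)
	}

	// The strict utf8mb3 variant rejects the 4-byte emoji that utf8mb4 accepts.
	emoji := []byte("😂")
	fmt.Println(charset.IsValid(charset.EncodingUTF8Impl, emoji))          // true
	fmt.Println(charset.IsValid(charset.EncodingUTF8MB3StrictImpl, emoji)) // false

	// With OpReplace the invalid character becomes '?'.
	out, _ := charset.EncodingUTF8MB3StrictImpl.Transform(nil, emoji, charset.OpReplace)
	fmt.Println(string(out)) // ?
}
```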
- -package charset - -import ( - "golang.org/x/text/encoding" - "golang.org/x/text/encoding/charmap" -) - -var ( - LatinEncoding = &Encoding{ - enc: charmap.Windows1252, - name: CharsetLatin1, - charLength: func(bytes []byte) int { - return 1 - }, - specialCase: nil, - } - - BinaryEncoding = &Encoding{ - enc: encoding.Nop, - name: CharsetBin, - charLength: func(bytes []byte) int { - return 1 - }, - specialCase: nil, - } - - ASCIIEncoding = &Encoding{ - enc: encoding.Nop, - name: CharsetASCII, - charLength: func(bytes []byte) int { - return 1 - }, - specialCase: nil, - } -) diff --git a/parser/charset/special_case_tables.go b/parser/charset/special_case_tables.go deleted file mode 100644 index 8a92ee717c566..0000000000000 --- a/parser/charset/special_case_tables.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import ( - "strings" - "unicode" -) - -func (e *Encoding) ToUpper(d string) string { - return strings.ToUpperSpecial(e.specialCase, d) -} - -func (e *Encoding) ToLower(d string) string { - return strings.ToLowerSpecial(e.specialCase, d) -} - -func LookupSpecialCase(label string) unicode.SpecialCase { - label = strings.ToLower(strings.Trim(label, "\t\n\r\f ")) - return specailCases[label].c -} - -var specailCases = map[string]struct { - c unicode.SpecialCase -}{ - "utf-8": {nil}, - "ibm866": {nil}, - "iso-8859-2": {nil}, - "iso-8859-3": {nil}, - "iso-8859-4": {nil}, - "iso-8859-5": {nil}, - "iso-8859-6": {nil}, - "iso-8859-7": {nil}, - "iso-8859-8": {nil}, - "iso-8859-8-i": {nil}, - "iso-8859-10": {nil}, - "iso-8859-13": {nil}, - "iso-8859-14": {nil}, - "iso-8859-15": {nil}, - "iso-8859-16": {nil}, - "koi8-r": {nil}, - "macintosh": {nil}, - "windows-874": {nil}, - "windows-1250": {nil}, - "windows-1251": {nil}, - "windows-1252": {nil}, - "windows-1253": {nil}, - "windows-1254": {nil}, - "windows-1255": {nil}, - "windows-1256": {nil}, - "windows-1257": {nil}, - "windows-1258": {nil}, - "x-mac-cyrillic": {nil}, - "gbk": {GBKCase}, - "gb18030": {nil}, - "hz-gb-2312": {nil}, - "big5": {nil}, - "euc-jp": {nil}, - "iso-2022-jp": {nil}, - "shift_jis": {nil}, - "euc-kr": {nil}, - "replacement": {nil}, - "utf-16be": {nil}, - "utf-16le": {nil}, - "x-user-defined": {nil}, -} - -// follow https://dev.mysql.com/worklog/task/?id=4583 for GBK -var GBKCase = unicode.SpecialCase{ - unicode.CaseRange{0x00E0, 0x00E1, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00E8, 0x00EA, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00EC, 0x00ED, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00F2, 0x00F3, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00F9, 0x00FA, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00FC, 0x00FC, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0101, 0x0101, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0113, 0x0113, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x011B, 0x011B, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x012B, 0x012B, [unicode.MaxCase]rune{0, 0, 0}}, - 
unicode.CaseRange{0x0144, 0x0144, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0148, 0x0148, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x014D, 0x014D, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x016B, 0x016B, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01CE, 0x01CE, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D0, 0x01D0, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D2, 0x01D2, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D4, 0x01D4, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D6, 0x01D6, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D8, 0x01D8, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01DA, 0x01DA, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01DC, 0x01DC, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x216A, 0x216B, [unicode.MaxCase]rune{0, 0, 0}}, -} diff --git a/parser/charset/utf.go b/parser/charset/utf.go deleted file mode 100644 index 301aaba49d19a..0000000000000 --- a/parser/charset/utf.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import ( - "golang.org/x/text/encoding" -) - -var UTF8Encoding = &Encoding{ - enc: encoding.Nop, - name: CharsetUTF8MB4, - charLength: func(bs []byte) int { - if len(bs) == 0 || bs[0] < 0x80 { - return 1 - } else if bs[0] < 0xe0 { - return 2 - } else if bs[0] < 0xf0 { - return 3 - } - return 4 - }, - specialCase: nil, -} diff --git a/parser/lexer.go b/parser/lexer.go index 94358fe51a962..c274a53f9f049 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -40,7 +40,7 @@ type Scanner struct { r reader buf bytes.Buffer - encoding *charset.Encoding + encoding charset.Encoding errs []error warns []error @@ -146,12 +146,18 @@ func (s *Scanner) AppendWarn(err error) { } func (s *Scanner) tryDecodeToUTF8String(sql string) string { - utf8Lit, err := s.encoding.DecodeString(sql) + if mysql.IsUTF8Charset(s.encoding.Name()) { + // Skip utf8 encoding because `ToUTF8` validates the whole SQL. + // This can cause failure when the SQL contains BLOB values. + // TODO: Convert charset on every token and use 'binary' encoding to decode token. 
+ return sql + } + utf8Lit, err := s.encoding.Transform(nil, charset.Slice(sql), charset.OpDecodeReplace) if err != nil { s.AppendError(err) s.lastErrorAsWarn() } - return utf8Lit + return string(utf8Lit) } func (s *Scanner) getNextToken() int { diff --git a/parser/parser.go b/parser/parser.go index 5f167d88ee7e8..a4fa6c451eed0 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -14989,7 +14989,7 @@ yynewstate: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", yyS[yypt-1].ident)) return 1 } - expr := ast.NewValueExpr(yyS[yypt-0].ident, parser.charset, parser.collation) + expr := ast.NewValueExpr(yyS[yypt-0].ident, yyS[yypt-1].ident, co) tp := expr.GetType() tp.Charset = yyS[yypt-1].ident tp.Collate = co @@ -15013,7 +15013,7 @@ yynewstate: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", yyS[yypt-1].ident)) return 1 } - expr := ast.NewValueExpr(yyS[yypt-0].item, parser.charset, parser.collation) + expr := ast.NewValueExpr(yyS[yypt-0].item, yyS[yypt-1].ident, co) tp := expr.GetType() tp.Charset = yyS[yypt-1].ident tp.Collate = co @@ -15029,7 +15029,7 @@ yynewstate: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", yyS[yypt-1].ident)) return 1 } - expr := ast.NewValueExpr(yyS[yypt-0].item, parser.charset, parser.collation) + expr := ast.NewValueExpr(yyS[yypt-0].item, yyS[yypt-1].ident, co) tp := expr.GetType() tp.Charset = yyS[yypt-1].ident tp.Collate = co diff --git a/parser/parser.y b/parser/parser.y index 4c2dc969450bd..91e3b05918fc9 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -6456,7 +6456,7 @@ Literal: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", $1)) return 1 } - expr := ast.NewValueExpr($2, parser.charset, parser.collation) + expr := ast.NewValueExpr($2, $1, co) tp := expr.GetType() tp.Charset = $1 tp.Collate = co @@ -6480,7 +6480,7 @@ Literal: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", $1)) return 1 } - expr := ast.NewValueExpr($2, parser.charset, parser.collation) + expr := ast.NewValueExpr($2, $1, co) tp := expr.GetType() tp.Charset = $1 tp.Collate = co @@ -6496,7 +6496,7 @@ Literal: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", $1)) return 1 } - expr := ast.NewValueExpr($2, parser.charset, parser.collation) + expr := ast.NewValueExpr($2, $1, co) tp := expr.GetType() tp.Charset = $1 tp.Collate = co diff --git a/parser/yy_parser.go b/parser/yy_parser.go index df3f416fad2e7..58e18083b28cb 100644 --- a/parser/yy_parser.go +++ b/parser/yy_parser.go @@ -396,7 +396,7 @@ var ( func resetParams(p *Parser) { p.charset = mysql.DefaultCharset p.collation = mysql.DefaultCollationName - p.lexer.encoding = charset.UTF8Encoding + p.lexer.encoding = charset.EncodingUTF8Impl } // ParseParam represents the parameter of parsing. @@ -436,6 +436,6 @@ type CharsetClient string // ApplyOn implements ParseParam interface. 
func (c CharsetClient) ApplyOn(p *Parser) error { - p.lexer.encoding = charset.NewEncoding(string(c)) + p.lexer.encoding = charset.FindEncoding(string(c)) return nil } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index c3e09b82c2409..40b5da424a919 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -4987,7 +4987,7 @@ func (s *testIntegrationSuite) TestIssue30094(c *C) { )) tk.MustQuery(`explain format = 'brief' select * from t30094 where concat(a,'1') = _binary 0xe59388e59388e59388 collate binary and concat(a,'1') = _binary 0xe598bfe598bfe598bf collate binary;`).Check(testkit.Rows( "TableReader 8000.00 root data:Selection", - "└─Selection 8000.00 cop[tikv] eq(to_binary(concat(test.t30094.a, \"1\")), \"0xe59388e59388e59388\"), eq(to_binary(concat(test.t30094.a, \"1\")), \"0xe598bfe598bfe598bf\")", + "└─Selection 8000.00 cop[tikv] eq(concat(test.t30094.a, \"1\"), \"0xe59388e59388e59388\"), eq(concat(test.t30094.a, \"1\"), \"0xe598bfe598bfe598bf\")", " └─TableFullScan 10000.00 cop[tikv] table:t30094 keep order:false, stats:pseudo", )) } diff --git a/server/util.go b/server/util.go index d84b4cd566519..3bf27665e4d5c 100644 --- a/server/util.go +++ b/server/util.go @@ -291,14 +291,14 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resul } type inputDecoder struct { - encoding *charset.Encoding + encoding charset.Encoding buffer []byte } func newInputDecoder(chs string) *inputDecoder { return &inputDecoder{ - encoding: charset.NewEncoding(chs), + encoding: charset.FindEncoding(chs), buffer: nil, } } @@ -309,7 +309,7 @@ func (i *inputDecoder) clean() { } func (i *inputDecoder) decodeInput(src []byte) []byte { - result, err := i.encoding.Decode(i.buffer, src) + result, err := i.encoding.Transform(i.buffer, src, charset.OpDecode) if err != nil { return src } @@ -320,22 +320,23 @@ type resultEncoder struct { // chsName and encoding are unchanged after the initialization from // session variable @@character_set_results. chsName string - encoding *charset.Encoding + encoding charset.Encoding // dataEncoding can be updated to match the column data charset. - dataEncoding *charset.Encoding + dataEncoding charset.Encoding buffer []byte - isBinary bool - isNull bool + isBinary bool + isNull bool + dataIsBinary bool } // newResultEncoder creates a new resultEncoder. func newResultEncoder(chs string) *resultEncoder { return &resultEncoder{ chsName: chs, - encoding: charset.NewEncoding(chs), + encoding: charset.FindEncoding(chs), buffer: nil, isBinary: chs == charset.CharsetBinary, isNull: len(chs) == 0, @@ -352,7 +353,8 @@ func (d *resultEncoder) updateDataEncoding(chsID uint16) { if err != nil { logutil.BgLogger().Warn("unknown charset ID", zap.Error(err)) } - d.dataEncoding = charset.NewEncoding(chs) + d.dataEncoding = charset.FindEncoding(chs) + d.dataIsBinary = chsID == mysql.BinaryDefaultCollationID } func (d *resultEncoder) columnTypeInfoCharsetID(info *ColumnInfo) uint16 { @@ -367,24 +369,30 @@ func (d *resultEncoder) columnTypeInfoCharsetID(info *ColumnInfo) uint16 { return uint16(mysql.CharsetNameToID(d.chsName)) } +// encodeMeta encodes bytes for meta info like column names. +// Note that the result should be consumed immediately. func (d *resultEncoder) encodeMeta(src []byte) []byte { return d.encodeWith(src, d.encoding) } +// encodeData encodes bytes for row data. +// Note that the result should be consumed immediately. 
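The "consumed immediately" notes above matter because encodeWith now returns the encoder's reusable scratch buffer rather than a fresh slice. The hazard can be illustrated with a simplified stand-in type; this is not the server's actual resultEncoder, only a sketch of the reuse pattern.

```go
package main

import "fmt"

// reusingEncoder mimics the pattern in resultEncoder: it returns its own
// scratch buffer, so each call may overwrite the previous result.
type reusingEncoder struct {
	buffer []byte
}

func (e *reusingEncoder) encode(src []byte) []byte {
	e.buffer = append(e.buffer[:0], src...) // stands in for a charset conversion
	return e.buffer
}

func main() {
	enc := &reusingEncoder{}
	packet := make([]byte, 0, 64)

	// Correct: copy the returned bytes into the output packet right away.
	packet = append(packet, enc.encode([]byte("col1"))...)
	packet = append(packet, enc.encode([]byte("col2"))...)
	fmt.Println(string(packet)) // col1col2

	// Incorrect: keeping the slice means the next call clobbers it.
	kept := enc.encode([]byte("first"))
	_ = enc.encode([]byte("later"))
	fmt.Println(string(kept)) // "later", not "first"
}
```

This is why dumpTextRow and dumpBinaryRow append the encoded bytes into the packet buffer immediately instead of retaining the returned slice.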
func (d *resultEncoder) encodeData(src []byte) []byte { - if d.isNull || d.isBinary { + if d.isNull || d.isBinary || d.dataIsBinary { // Use the column charset to encode. return d.encodeWith(src, d.dataEncoding) } return d.encodeWith(src, d.encoding) } -func (d *resultEncoder) encodeWith(src []byte, enc *charset.Encoding) []byte { - result, err := enc.Encode(d.buffer, src) +func (d *resultEncoder) encodeWith(src []byte, enc charset.Encoding) []byte { + var err error + d.buffer, err = enc.Transform(d.buffer, src, charset.OpEncode) if err != nil { logutil.BgLogger().Debug("encode error", zap.Error(err)) } - return result + // The buffer will be reused. + return d.buffer } func dumpTextRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resultEncoder) ([]byte, error) { diff --git a/table/column.go b/table/column.go index 0225a88556c0b..d7e9a9ec5dadb 100644 --- a/table/column.go +++ b/table/column.go @@ -170,7 +170,7 @@ func truncateTrailingSpaces(v *types.Datum) { v.SetString(str, v.Collation()) } -func handleWrongCharsetValue(ctx sessionctx.Context, col *model.ColumnInfo, str string, i int) error { +func handleWrongCharsetValue(ctx sessionctx.Context, col *model.ColumnInfo, str []byte, i int) error { sc := ctx.GetSessionVars().StmtCtx var strval strings.Builder for j := 0; j < 6; j++ { @@ -328,46 +328,57 @@ func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, r truncateTrailingSpaces(&casted) } - if v := makeStringValidator(ctx, col); v != nil { - str := casted.GetString() - strategy := charset.TruncateStrategyReplace - if val.Collation() == charset.CollationBin { - strategy = charset.TruncateStrategyTrim - } - if newStr, invalidPos := v.Truncate(str, strategy); invalidPos >= 0 { - casted = types.NewStringDatum(newStr) - err = handleWrongCharsetValue(ctx, col, str, invalidPos) - } - } + err = validateStringDatum(ctx, &val, &casted, col) if forceIgnoreTruncate { err = nil } return casted, err } -func makeStringValidator(ctx sessionctx.Context, col *model.ColumnInfo) charset.StringValidator { - switch col.Charset { - case charset.CharsetASCII: - if ctx.GetSessionVars().SkipASCIICheck { - return nil - } - return charset.StringValidatorASCII{} - case charset.CharsetUTF8: - if ctx.GetSessionVars().SkipUTF8Check { - return nil - } - needCheckMB4 := config.GetGlobalConfig().CheckMb4ValueInUTF8 - return charset.StringValidatorUTF8{IsUTF8MB4: false, CheckMB4ValueInUTF8: needCheckMB4} - case charset.CharsetUTF8MB4: - if ctx.GetSessionVars().SkipUTF8Check { - return nil +func validateStringDatum(ctx sessionctx.Context, origin, casted *types.Datum, col *model.ColumnInfo) error { + // Only strings are need to validate. + if !types.IsString(col.Tp) { + return nil + } + fromBinary := origin.Kind() == types.KindBinaryLiteral || + (origin.Kind() == types.KindString && origin.Collation() == charset.CollationBin) + toBinary := types.IsTypeBlob(col.Tp) || col.Charset == charset.CharsetBin + if fromBinary && toBinary { + return nil + } + enc := charset.FindEncoding(col.Charset) + // Skip utf8 check if possible. + if enc.Tp() == charset.EncodingTpUTF8 && ctx.GetSessionVars().SkipUTF8Check { + return nil + } + // Skip ascii check if possible. + if enc.Tp() == charset.EncodingTpASCII && ctx.GetSessionVars().SkipASCIICheck { + return nil + } + if col.Charset == charset.CharsetUTF8 && config.GetGlobalConfig().CheckMb4ValueInUTF8 { + // Use a strict mode implementation. 4 bytes characters are invalid. 
+ enc = charset.EncodingUTF8MB3StrictImpl + } + if fromBinary { + src := casted.GetBytes() + encBytes, err := enc.Transform(nil, src, charset.OpDecode) + if err != nil { + casted.SetBytesAsString(encBytes, charset.CollationUTF8MB4, 0) + nSrc := charset.CountValidBytesDecode(enc, src) + return handleWrongCharsetValue(ctx, col, src, nSrc) } - return charset.StringValidatorUTF8{IsUTF8MB4: true} - case charset.CharsetLatin1, charset.CharsetBinary: + casted.SetBytesAsString(encBytes, charset.CollationUTF8MB4, 0) return nil - default: - return charset.StringValidatorOther{Charset: col.Charset} } + // Check if the string is valid in the given column charset. + str := casted.GetBytes() + if !charset.IsValid(enc, str) { + replace, _ := enc.Transform(nil, str, charset.OpReplace) + casted.SetBytesAsString(replace, charset.CollationUTF8MB4, 0) + nSrc := charset.CountValidBytes(enc, str) + return handleWrongCharsetValue(ctx, col, str, nSrc) + } + return nil } // ColDesc describes column information like MySQL desc and show columns do. From 92207005ec8e37b545d810ec35da477844cca2f4 Mon Sep 17 00:00:00 2001 From: glorv Date: Mon, 20 Dec 2021 15:45:46 +0800 Subject: [PATCH 7/9] lightning: optimize region split check logic (#30428) close pingcap/tidb#30018 --- br/pkg/lightning/backend/local/local.go | 10 ++++++++-- br/pkg/lightning/restore/table_restore.go | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index f4921da0fa55e..eb7ab37802e4e 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -82,7 +82,8 @@ const ( gRPCBackOffMaxDelay = 10 * time.Minute // See: https://github.com/tikv/tikv/blob/e030a0aae9622f3774df89c62f21b2171a72a69e/etc/config-template.toml#L360 - regionMaxKeyCount = 1_440_000 + // lower the max-key-count to avoid tikv trigger region auto split + regionMaxKeyCount = 1_280_000 defaultRegionSplitSize = 96 * units.MiB propRangeIndex = "tikv.range_index" @@ -782,7 +783,12 @@ func (local *local) WriteToTiKV( size := int64(0) totalCount := int64(0) firstLoop := true - regionMaxSize := regionSplitSize * 4 / 3 + // if region-split-size <= 96MiB, we bump the threshold a bit to avoid too many retry split + // because the range-properties is not 100% accurate + regionMaxSize := regionSplitSize + if regionSplitSize <= defaultRegionSplitSize { + regionMaxSize = regionSplitSize * 4 / 3 + } for iter.First(); iter.Valid(); iter.Next() { size += int64(len(iter.Key()) + len(iter.Value())) diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go index 8664943e75199..a60d34dcfa20c 100644 --- a/br/pkg/lightning/restore/table_restore.go +++ b/br/pkg/lightning/restore/table_restore.go @@ -998,8 +998,8 @@ func estimateCompactionThreshold(cp *checkpoints.TableCheckpoint, factor int64) threshold := totalRawFileSize / 512 threshold = utils.NextPowerOfTwo(threshold) if threshold < compactionLowerThreshold { - // disable compaction if threshold is smaller than lower bound - threshold = 0 + // too may small SST files will cause inaccuracy of region range estimation, + threshold = compactionLowerThreshold } else if threshold > compactionUpperThreshold { threshold = compactionUpperThreshold } From c44630e137cd268060b9228af0e68d3d5799588d Mon Sep 17 00:00:00 2001 From: Zak Zhao <57036248+joccau@users.noreply.github.com> Date: Mon, 20 Dec 2021 16:23:46 +0800 Subject: [PATCH 8/9] br: ignore mock directory when gcov in br (#30586) 
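Restated outside lightning, the new split-size guard from the local backend hunk earlier in this series is just the arithmetic below; a standalone sketch using the constant shown in that diff, with the results in MiB noted in comments.

```go
package main

import "fmt"

// defaultRegionSplitSize matches the 96 MiB constant in the local backend diff.
const defaultRegionSplitSize = int64(96) << 20

// regionMaxSize mirrors the WriteToTiKV change: the split threshold is only
// bumped by 4/3 when the configured region-split-size is at or below the
// default, to tolerate inaccuracy in the range properties.
func regionMaxSize(regionSplitSize int64) int64 {
	if regionSplitSize <= defaultRegionSplitSize {
		return regionSplitSize * 4 / 3
	}
	return regionSplitSize
}

func main() {
	fmt.Println(regionMaxSize(96<<20) >> 20)  // 128 (MiB): bumped by 4/3
	fmt.Println(regionMaxSize(256<<20) >> 20) // 256 (MiB): used as-is
}
```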
--- .codecov.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.codecov.yml b/.codecov.yml index f2482097c10a9..07178456a6804 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -43,4 +43,5 @@ ignore: - "executor/seqtest/.*" - "metrics/.*" - "expression/generator/.*" + - "br/pkg/mock/.*" From b9d9f19bd1f15cc7d28d572b3d2433c811a91b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Mon, 20 Dec 2021 16:47:46 +0800 Subject: [PATCH 9/9] *: forbid set tiflash replica count for a placement table (#30844) close pingcap/tidb#30741 --- ddl/error.go | 2 + ddl/placement_policy.go | 16 ++++++ ddl/placement_sql_test.go | 105 ++++++++++++++++++++++++++++++++++++++ ddl/table.go | 12 +++++ 4 files changed, 135 insertions(+) diff --git a/ddl/error.go b/ddl/error.go index f95b5baf0e4a9..6819920860e39 100644 --- a/ddl/error.go +++ b/ddl/error.go @@ -310,4 +310,6 @@ var ( errDependentByFunctionalIndex = dbterror.ClassDDL.NewStd(mysql.ErrDependentByFunctionalIndex) // errFunctionalIndexOnBlob when the expression of expression index returns blob or text. errFunctionalIndexOnBlob = dbterror.ClassDDL.NewStd(mysql.ErrFunctionalIndexOnBlob) + // ErrIncompatibleTiFlashAndPlacement when placement and tiflash replica options are set at the same time + ErrIncompatibleTiFlashAndPlacement = dbterror.ClassDDL.NewStdErr(mysql.ErrUnsupportedDDLOperation, parser_mysql.Message("Placement and tiflash replica options cannot be set at the same time", nil)) ) diff --git a/ddl/placement_policy.go b/ddl/placement_policy.go index 3ccba4ef346f6..80faeead391a5 100644 --- a/ddl/placement_policy.go +++ b/ddl/placement_policy.go @@ -381,3 +381,19 @@ func checkPlacementPolicyNotUsedByTable(tblInfo *model.TableInfo, policy *model. return nil } + +func tableHasPlacementSettings(tblInfo *model.TableInfo) bool { + if tblInfo.DirectPlacementOpts != nil || tblInfo.PlacementPolicyRef != nil { + return true + } + + if tblInfo.Partition != nil { + for _, def := range tblInfo.Partition.Definitions { + if def.DirectPlacementOpts != nil || def.PlacementPolicyRef != nil { + return true + } + } + } + + return false +} diff --git a/ddl/placement_sql_test.go b/ddl/placement_sql_test.go index 62d6f36db1c4a..7dd419bac54fc 100644 --- a/ddl/placement_sql_test.go +++ b/ddl/placement_sql_test.go @@ -21,6 +21,7 @@ import ( . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/ddl/placement" mysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" @@ -472,3 +473,107 @@ func (s *testDBSuite6) TestEnablePlacementCheck(c *C) { tk.MustGetErrCode("create table m (c int) partition by range (c) (partition p1 values less than (200) followers=2);", mysql.ErrUnsupportedDDLOperation) tk.MustGetErrCode("alter table t partition p1 placement policy=\"placement_x\";", mysql.ErrUnsupportedDDLOperation) } + +func (s *testDBSuite6) TestPlacementTiflashCheck(c *C) { + tk := testkit.NewTestKit(c, s.store) + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "set @@global.tidb_enable_alter_placement=1") + c.Assert(err, IsNil) + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount", `return(true)`), IsNil) + defer func() { + err := failpoint.Disable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount") + c.Assert(err, IsNil) + }() + + tk.MustExec("use test") + tk.MustExec("drop placement policy if exists p1") + tk.MustExec("drop table if exists tp") + + tk.MustExec("create placement policy p1 primary_region='r1' regions='r1'") + defer tk.MustExec("drop placement policy if exists p1") + + tk.MustExec(`CREATE TABLE tp (id INT) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (1000) + )`) + defer tk.MustExec("drop table if exists tp") + tk.MustExec("alter table tp set tiflash replica 1") + + err = tk.ExecToErr("alter table tp placement policy p1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + err = tk.ExecToErr("alter table tp primary_region='r2' regions='r2'") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + err = tk.ExecToErr("alter table tp partition p0 placement policy p1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + err = tk.ExecToErr("alter table tp partition p0 primary_region='r2' regions='r2'") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) placement policy p1 PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`p1` */\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100) placement policy p1 , + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show 
create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100) /*T![placement] PLACEMENT POLICY=`p1` */,\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) primary_region='r2' regions='r2' PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"r2\" REGIONS=\"r2\" */\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100) primary_region='r3' regions='r3', + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100) /*T![placement] PRIMARY_REGION=\"r3\" REGIONS=\"r3\" */,\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) +} diff --git a/ddl/table.go b/ddl/table.go index cd2b818450c57..83f7ad0b0e58a 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -954,6 +954,10 @@ func (w *worker) onSetTableFlashReplica(t *meta.Meta, job *model.Job) (ver int64 return ver, errors.Trace(err) } + if replicaInfo.Count > 0 && tableHasPlacementSettings(tblInfo) { + return ver, errors.Trace(ErrIncompatibleTiFlashAndPlacement) + } + // Ban setting replica count for tables in system database. if tidb_util.IsMemOrSysDB(job.SchemaName) { return ver, errors.Trace(errUnsupportedAlterReplicaForSysTable) @@ -1274,6 +1278,10 @@ func onAlterTablePartitionPlacement(t *meta.Meta, job *model.Job) (ver int64, er return 0, err } + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Count > 0 { + return 0, errors.Trace(ErrIncompatibleTiFlashAndPlacement) + } + ptInfo := tblInfo.GetPartitionInfo() var partitionDef *model.PartitionDefinition definitions := ptInfo.Definitions @@ -1341,6 +1349,10 @@ func onAlterTablePlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, return 0, err } + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Count > 0 { + return 0, errors.Trace(ErrIncompatibleTiFlashAndPlacement) + } + if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, policyRefInfo); err != nil { return 0, errors.Trace(err) }
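
The last patch enforces the TiFlash/placement mutual exclusion from both directions: `onSetTableFlashReplica` rejects a replica count when the table (or any partition) already carries placement settings, and the two `onAlterTablePlacement*` paths reject placement changes when a TiFlash replica count is already set. For readers who want to see the rule in isolation, here is a minimal, self-contained Go sketch of that guard. The `TableInfo` and `PartitionDef` structs, `errIncompatible`, and the two setter functions are simplified stand-ins invented for illustration, not the real `model.TableInfo` types or DDL job handlers; only the check logic mirrors the diff above.

```go
package main

import (
	"errors"
	"fmt"
)

// errIncompatible stands in for ddl.ErrIncompatibleTiFlashAndPlacement.
var errIncompatible = errors.New("Placement and tiflash replica options cannot be set at the same time")

// PartitionDef and TableInfo are simplified stand-ins for the model types.
type PartitionDef struct {
	PlacementPolicyRef  *string
	DirectPlacementOpts *string
}

type TableInfo struct {
	PlacementPolicyRef  *string
	DirectPlacementOpts *string
	Partitions          []PartitionDef
	TiFlashReplicaCount uint64
}

// hasPlacementSettings mirrors tableHasPlacementSettings: a table counts as
// "placed" if it, or any of its partitions, has a policy reference or direct
// placement options.
func hasPlacementSettings(t *TableInfo) bool {
	if t.DirectPlacementOpts != nil || t.PlacementPolicyRef != nil {
		return true
	}
	for _, def := range t.Partitions {
		if def.DirectPlacementOpts != nil || def.PlacementPolicyRef != nil {
			return true
		}
	}
	return false
}

// setTiFlashReplica mirrors the guard added to onSetTableFlashReplica.
func setTiFlashReplica(t *TableInfo, count uint64) error {
	if count > 0 && hasPlacementSettings(t) {
		return errIncompatible
	}
	t.TiFlashReplicaCount = count
	return nil
}

// setPlacementPolicy mirrors the guard added to onAlterTablePlacement.
func setPlacementPolicy(t *TableInfo, policy string) error {
	if t.TiFlashReplicaCount > 0 {
		return errIncompatible
	}
	t.PlacementPolicyRef = &policy
	return nil
}

func main() {
	p1 := "p1"

	withPlacement := &TableInfo{PlacementPolicyRef: &p1}
	fmt.Println(setTiFlashReplica(withPlacement, 1)) // rejected

	withTiFlash := &TableInfo{TiFlashReplicaCount: 1}
	fmt.Println(setPlacementPolicy(withTiFlash, p1)) // rejected

	plain := &TableInfo{}
	fmt.Println(setTiFlashReplica(plain, 1)) // nil: allowed
}
```

The guard is deliberately symmetric: whichever option is applied first blocks the other, so the invariant that a table never has both TiFlash replicas and placement settings holds regardless of DDL ordering, which is what the `TestPlacementTiflashCheck` cases above exercise.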