diff --git a/.codecov.yml b/.codecov.yml index f2482097c10a9..07178456a6804 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -43,4 +43,5 @@ ignore: - "executor/seqtest/.*" - "metrics/.*" - "expression/generator/.*" + - "br/pkg/mock/.*" diff --git a/bindinfo/handle.go b/bindinfo/handle.go index b1920047a20b6..d757c9fb578c9 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -131,7 +131,7 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { } exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime) if err != nil { return err @@ -697,7 +697,7 @@ func (h *BindHandle) extractCaptureFilterFromStorage() (filter *captureFilter) { tables: make(map[stmtctx.TableEntry]struct{}), } exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`) if err != nil { logutil.BgLogger().Warn("[sql-bind] failed to parse query for mysql.capture_plan_baselines_blacklist load", zap.Error(err)) return @@ -923,7 +923,7 @@ func (h *BindHandle) SaveEvolveTasksToStore() { } func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time.Time, error) { - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams( + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal( context.TODO(), "SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name IN (%?, %?, %?)", variable.TiDBEvolvePlanTaskMaxTime, diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index f4921da0fa55e..eb7ab37802e4e 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -82,7 +82,8 @@ const ( gRPCBackOffMaxDelay = 10 * time.Minute // See: https://github.com/tikv/tikv/blob/e030a0aae9622f3774df89c62f21b2171a72a69e/etc/config-template.toml#L360 - regionMaxKeyCount = 1_440_000 + // lower the max-key-count to avoid tikv trigger region auto split + regionMaxKeyCount = 1_280_000 defaultRegionSplitSize = 96 * units.MiB propRangeIndex = "tikv.range_index" @@ -782,7 +783,12 @@ func (local *local) WriteToTiKV( size := int64(0) totalCount := int64(0) firstLoop := true - regionMaxSize := regionSplitSize * 4 / 3 + // if region-split-size <= 96MiB, we bump the threshold a bit to avoid too many retry split + // because the range-properties is not 100% accurate + regionMaxSize := regionSplitSize + if regionSplitSize <= defaultRegionSplitSize { + regionMaxSize = regionSplitSize * 4 / 3 + } for iter.First(); iter.Valid(); iter.Next() { size += int64(len(iter.Key()) + len(iter.Value())) diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go index 8664943e75199..a60d34dcfa20c 100644 --- a/br/pkg/lightning/restore/table_restore.go +++ b/br/pkg/lightning/restore/table_restore.go @@ -998,8 +998,8 @@ func estimateCompactionThreshold(cp *checkpoints.TableCheckpoint, factor int64) threshold := totalRawFileSize / 512 threshold = utils.NextPowerOfTwo(threshold) if threshold < compactionLowerThreshold { - // disable compaction if threshold is smaller than lower bound - threshold = 0 + // too may small SST files will cause inaccuracy of region range estimation, + threshold = compactionLowerThreshold } else if threshold > compactionUpperThreshold { threshold = compactionUpperThreshold } diff --git a/cmd/explaintest/r/new_character_set_builtin.result b/cmd/explaintest/r/new_character_set_builtin.result index fa733d990ef1e..f587a5ac1370e 100644 --- a/cmd/explaintest/r/new_character_set_builtin.result +++ b/cmd/explaintest/r/new_character_set_builtin.result @@ -1,3 +1,4 @@ +set @@sql_mode = ''; drop table if exists t; create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); insert into t values ('一二三', '一二三', '一二三'); @@ -244,8 +245,8 @@ insert into t values ('65'), ('123456'), ('123456789'); select char(a using gbk), char(a using utf8), char(a) from t; char(a using gbk) char(a using utf8) char(a) A A A -釦 @ @ -NULL [ [ +釦  @ +[ [ [ select char(12345678 using gbk); char(12345678 using gbk) 糰N @@ -253,8 +254,8 @@ set @@tidb_enable_vectorized_expression = true; select char(a using gbk), char(a using utf8), char(a) from t; char(a using gbk) char(a using utf8) char(a) A A A -釦 @ @ -NULL [ [ +釦  @ +[ [ [ select char(12345678 using gbk); char(12345678 using gbk) 糰N diff --git a/cmd/explaintest/t/new_character_set_builtin.test b/cmd/explaintest/t/new_character_set_builtin.test index 09b823cdcfaa9..bb0a6321e8a53 100644 --- a/cmd/explaintest/t/new_character_set_builtin.test +++ b/cmd/explaintest/t/new_character_set_builtin.test @@ -1,3 +1,4 @@ +set @@sql_mode = ''; -- test for builtin function hex(), length(), ascii(), octet_length() drop table if exists t; create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); diff --git a/ddl/column.go b/ddl/column.go index ee1eed6c6d669..0bebd5ac6ea60 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -996,7 +996,7 @@ func (w *worker) doModifyColumnTypeWithData( } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), valStr) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(context.Background(), valStr) if err != nil { job.State = model.JobStateCancelled failpoint.Return(ver, err) @@ -1703,7 +1703,7 @@ func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTrunc } } buf.WriteString(" limit 1") - stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, buf.String(), paramsList...) + stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(ctx, buf.String(), paramsList...) if err != nil { return errors.Trace(err) } diff --git a/ddl/error.go b/ddl/error.go index f95b5baf0e4a9..6819920860e39 100644 --- a/ddl/error.go +++ b/ddl/error.go @@ -310,4 +310,6 @@ var ( errDependentByFunctionalIndex = dbterror.ClassDDL.NewStd(mysql.ErrDependentByFunctionalIndex) // errFunctionalIndexOnBlob when the expression of expression index returns blob or text. errFunctionalIndexOnBlob = dbterror.ClassDDL.NewStd(mysql.ErrFunctionalIndexOnBlob) + // ErrIncompatibleTiFlashAndPlacement when placement and tiflash replica options are set at the same time + ErrIncompatibleTiFlashAndPlacement = dbterror.ClassDDL.NewStdErr(mysql.ErrUnsupportedDDLOperation, parser_mysql.Message("Placement and tiflash replica options cannot be set at the same time", nil)) ) diff --git a/ddl/partition.go b/ddl/partition.go index 6f6d3b8b0ff59..87a536d5181de 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1551,7 +1551,7 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(w.ddlJobCtx, sql, paramList...) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(w.ddlJobCtx, sql, paramList...) if err != nil { return errors.Trace(err) } @@ -1569,7 +1569,7 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { var buf strings.Builder paramList := make([]interface{}, 0, 4) - // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) + // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParamsInternal(...) // So we write it to the origin sql string here. if index == 0 { buf.WriteString("select 1 from %n.%n where ") diff --git a/ddl/placement_policy.go b/ddl/placement_policy.go index 3ccba4ef346f6..80faeead391a5 100644 --- a/ddl/placement_policy.go +++ b/ddl/placement_policy.go @@ -381,3 +381,19 @@ func checkPlacementPolicyNotUsedByTable(tblInfo *model.TableInfo, policy *model. return nil } + +func tableHasPlacementSettings(tblInfo *model.TableInfo) bool { + if tblInfo.DirectPlacementOpts != nil || tblInfo.PlacementPolicyRef != nil { + return true + } + + if tblInfo.Partition != nil { + for _, def := range tblInfo.Partition.Definitions { + if def.DirectPlacementOpts != nil || def.PlacementPolicyRef != nil { + return true + } + } + } + + return false +} diff --git a/ddl/placement_sql_test.go b/ddl/placement_sql_test.go index 62d6f36db1c4a..7dd419bac54fc 100644 --- a/ddl/placement_sql_test.go +++ b/ddl/placement_sql_test.go @@ -21,6 +21,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/ddl/placement" mysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" @@ -472,3 +473,107 @@ func (s *testDBSuite6) TestEnablePlacementCheck(c *C) { tk.MustGetErrCode("create table m (c int) partition by range (c) (partition p1 values less than (200) followers=2);", mysql.ErrUnsupportedDDLOperation) tk.MustGetErrCode("alter table t partition p1 placement policy=\"placement_x\";", mysql.ErrUnsupportedDDLOperation) } + +func (s *testDBSuite6) TestPlacementTiflashCheck(c *C) { + tk := testkit.NewTestKit(c, s.store) + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "set @@global.tidb_enable_alter_placement=1") + c.Assert(err, IsNil) + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount", `return(true)`), IsNil) + defer func() { + err := failpoint.Disable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount") + c.Assert(err, IsNil) + }() + + tk.MustExec("use test") + tk.MustExec("drop placement policy if exists p1") + tk.MustExec("drop table if exists tp") + + tk.MustExec("create placement policy p1 primary_region='r1' regions='r1'") + defer tk.MustExec("drop placement policy if exists p1") + + tk.MustExec(`CREATE TABLE tp (id INT) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (1000) + )`) + defer tk.MustExec("drop table if exists tp") + tk.MustExec("alter table tp set tiflash replica 1") + + err = tk.ExecToErr("alter table tp placement policy p1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + err = tk.ExecToErr("alter table tp primary_region='r2' regions='r2'") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + err = tk.ExecToErr("alter table tp partition p0 placement policy p1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + err = tk.ExecToErr("alter table tp partition p0 primary_region='r2' regions='r2'") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) placement policy p1 PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`p1` */\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100) placement policy p1 , + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100) /*T![placement] PLACEMENT POLICY=`p1` */,\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) primary_region='r2' regions='r2' PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"r2\" REGIONS=\"r2\" */\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) + + tk.MustExec("drop table tp") + tk.MustExec(`CREATE TABLE tp (id INT) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100) primary_region='r3' regions='r3', + PARTITION p1 VALUES LESS THAN (1000) + )`) + err = tk.ExecToErr("alter table tp set tiflash replica 1") + c.Assert(ddl.ErrIncompatibleTiFlashAndPlacement.Equal(err), IsTrue) + tk.MustQuery("show create table tp").Check(testkit.Rows("" + + "tp CREATE TABLE `tp` (\n" + + " `id` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin\n" + + "PARTITION BY RANGE (`id`)\n" + + "(PARTITION `p0` VALUES LESS THAN (100) /*T![placement] PRIMARY_REGION=\"r3\" REGIONS=\"r3\" */,\n" + + " PARTITION `p1` VALUES LESS THAN (1000))")) +} diff --git a/ddl/reorg.go b/ddl/reorg.go index d1a8f5b0b4b59..54cda19a7974d 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -341,7 +341,7 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { return statistics.PseudoRowCount } sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" - stmt, err := executor.ParseWithParams(w.ddlJobCtx, sql, tblInfo.ID) + stmt, err := executor.ParseWithParamsInternal(w.ddlJobCtx, sql, tblInfo.ID) if err != nil { return statistics.PseudoRowCount } diff --git a/ddl/table.go b/ddl/table.go index cd2b818450c57..83f7ad0b0e58a 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -954,6 +954,10 @@ func (w *worker) onSetTableFlashReplica(t *meta.Meta, job *model.Job) (ver int64 return ver, errors.Trace(err) } + if replicaInfo.Count > 0 && tableHasPlacementSettings(tblInfo) { + return ver, errors.Trace(ErrIncompatibleTiFlashAndPlacement) + } + // Ban setting replica count for tables in system database. if tidb_util.IsMemOrSysDB(job.SchemaName) { return ver, errors.Trace(errUnsupportedAlterReplicaForSysTable) @@ -1274,6 +1278,10 @@ func onAlterTablePartitionPlacement(t *meta.Meta, job *model.Job) (ver int64, er return 0, err } + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Count > 0 { + return 0, errors.Trace(ErrIncompatibleTiFlashAndPlacement) + } + ptInfo := tblInfo.GetPartitionInfo() var partitionDef *model.PartitionDefinition definitions := ptInfo.Definitions @@ -1341,6 +1349,10 @@ func onAlterTablePlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, return 0, err } + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Count > 0 { + return 0, errors.Trace(ErrIncompatibleTiFlashAndPlacement) + } + if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, policyRefInfo); err != nil { return 0, errors.Trace(err) } diff --git a/ddl/util/util.go b/ddl/util/util.go index 62dba84cb4e62..993c0c226f6a0 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -176,7 +176,7 @@ func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []str paramNames = append(paramNames, name) } buf.WriteString(")") - stmt, err := e.ParseWithParams(ctx, buf.String(), paramNames...) + stmt, err := e.ParseWithParamsInternal(ctx, buf.String(), paramNames...) if err != nil { return errors.Trace(err) } diff --git a/domain/sysvar_cache.go b/domain/sysvar_cache.go index d89ba88a76ee0..bb0ff2d0c9ab0 100644 --- a/domain/sysvar_cache.go +++ b/domain/sysvar_cache.go @@ -94,7 +94,7 @@ func (do *Domain) fetchTableValues(ctx sessionctx.Context) (map[string]string, e tableContents := make(map[string]string) // Copy all variables from the table to tableContents exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), `SELECT variable_name, variable_value FROM mysql.global_variables`) + stmt, err := exec.ParseWithParamsInternal(context.Background(), `SELECT variable_name, variable_value FROM mysql.global_variables`) if err != nil { return tableContents, err } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index a99d7e71a69f4..976775ce5dc4c 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -1483,6 +1483,7 @@ func TestAvgDecimal(t *testing.T) { tk.MustExec("insert into td values (0,29815);") tk.MustExec("insert into td values (10017,-32661);") tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260")) + tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000")) tk.MustExec("drop table td;") } diff --git a/executor/analyze.go b/executor/analyze.go index 372108764babd..5397c1ee0608c 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -1570,7 +1570,7 @@ type AnalyzeFastExec struct { func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) var stmt ast.StmtNode - stmt, err = exec.ParseWithParams(context.TODO(), "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID()) + stmt, err = exec.ParseWithParamsInternal(context.TODO(), "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID()) if err != nil { return } diff --git a/executor/brie.go b/executor/brie.go index 9be7349c494a8..e72958d9d5f7d 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -462,7 +462,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error) // These queries execute without privilege checking, since the calling statements // such as BACKUP and RESTORE have already been privilege checked. func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { - stmt, err := gs.se.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, sql) + stmt, err := gs.se.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(ctx, sql) if err != nil { return err } diff --git a/executor/builder.go b/executor/builder.go index 73d20b93eb9bb..78c94fd02f1aa 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4192,7 +4192,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) Executor { partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) for _, desc := range v.WindowFuncDescs { - aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(b.ctx, desc, false) if err != nil { b.err = err return nil diff --git a/executor/ddl.go b/executor/ddl.go index ac97d95d2fa5a..579df42b70fcb 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -506,7 +506,7 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx zap.String("table", fullti.Name.O), ) exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) if err != nil { return err } diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index 1e4fcae3829ba..06d35d24fce1d 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -190,7 +190,7 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex func getRowCountAllTable(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select table_id, count from mysql.stats_meta") + stmt, err := exec.ParseWithParamsInternal(ctx, "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } @@ -215,7 +215,7 @@ type tableHistID struct { func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[tableHistID]uint64, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") + stmt, err := exec.ParseWithParamsInternal(ctx, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } diff --git a/executor/inspection_profile.go b/executor/inspection_profile.go index 90ea2bc224c5f..b8ab8b383df43 100644 --- a/executor/inspection_profile.go +++ b/executor/inspection_profile.go @@ -167,7 +167,7 @@ func (n *metricNode) getLabelValue(label string) *metricValue { func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error { exec := pb.sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), query) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), query) if err != nil { return err } diff --git a/executor/inspection_result.go b/executor/inspection_result.go index bb26993824684..4062dab3c3176 100644 --- a/executor/inspection_result.go +++ b/executor/inspection_result.go @@ -140,7 +140,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct e.statusToInstanceAddress = make(map[string]string) var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select instance,status_address from information_schema.cluster_info;") + stmt, err := exec.ParseWithParamsInternal(ctx, "select instance,status_address from information_schema.cluster_info;") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -251,7 +251,7 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C } var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) + stmt, err := exec.ParseWithParamsInternal(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -261,7 +261,7 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C generateDetail := func(tp, item string) string { var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) + stmt, err := exec.ParseWithParamsInternal(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -347,7 +347,7 @@ func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionct } sql.Reset() fmt.Fprintf(sql, "select type,instance,value from information_schema.%s where %s", cas.table, cas.cond) - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -378,7 +378,7 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct } var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") + stmt, err := exec.ParseWithParamsInternal(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -405,7 +405,7 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct ipToCount[ip]++ } - stmt, err = exec.ParseWithParams(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") + stmt, err = exec.ParseWithParamsInternal(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -473,7 +473,7 @@ func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, f exec := sctx.(sqlexec.RestrictedSQLExecutor) var rows []chunk.Row // check the configuration consistent - stmt, err := exec.ParseWithParams(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") + stmt, err := exec.ParseWithParamsInternal(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -643,7 +643,7 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx sql.Reset() fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -698,7 +698,7 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -726,7 +726,7 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se // Check from log. sql.Reset() fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - stmt, err = exec.ParseWithParams(ctx, sql.String()) + stmt, err = exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -863,7 +863,7 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) } - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1036,7 +1036,7 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess } else { fmt.Fprintf(sql, "select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) } - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1222,7 +1222,7 @@ func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionF continue } sql := rule.genSQL(filter.timeRange) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1245,7 +1245,7 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx exec := sctx.(sqlexec.RestrictedSQLExecutor) var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1259,7 +1259,7 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sql.Reset() fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) var subRows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { subRows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } diff --git a/executor/inspection_summary.go b/executor/inspection_summary.go index ffd235451cebb..e709373363e74 100644 --- a/executor/inspection_summary.go +++ b/executor/inspection_summary.go @@ -460,7 +460,7 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc util.MetricSchemaName.L, name, cond) } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/metrics_reader.go b/executor/metrics_reader.go index d6df64dc7a377..2ef9f1196ad94 100644 --- a/executor/metrics_reader.go +++ b/executor/metrics_reader.go @@ -233,7 +233,7 @@ func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx. } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } @@ -318,7 +318,7 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess util.MetricSchemaName.L, name, cond) } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go index 0d915c9eb6966..b013087d925f2 100644 --- a/executor/opt_rule_blacklist.go +++ b/executor/opt_rule_blacklist.go @@ -37,7 +37,7 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e // LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") if err != nil { return err } diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index 81c8ea4f3cccb..0284edba08f38 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -39,7 +39,7 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu // LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist. func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") if err != nil { return err } diff --git a/executor/show.go b/executor/show.go index 2a4eac148f74e..ea4bdf4f1aee8 100644 --- a/executor/show.go +++ b/executor/show.go @@ -342,7 +342,7 @@ func (e *ShowExec) fetchShowBind() error { func (e *ShowExec) fetchShowEngines(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT * FROM information_schema.engines`) + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } @@ -473,7 +473,7 @@ func (e *ShowExec) fetchShowTableStatus(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT table_name, engine, version, row_format, table_rows, avg_row_length, data_length, max_data_length, index_length, data_free, auto_increment, create_time, update_time, check_time, @@ -1433,7 +1433,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, strings.ToLower(hostName)) + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, strings.ToLower(hostName)) if err != nil { return errors.Trace(err) } @@ -1453,7 +1453,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { authplugin = rows[0].GetString(0) } - stmt, err = exec.ParseWithParams(ctx, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) + stmt, err = exec.ParseWithParamsInternal(ctx, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } diff --git a/executor/show_placement.go b/executor/show_placement.go index d77c0a31000f1..a53a86ffba019 100644 --- a/executor/show_placement.go +++ b/executor/show_placement.go @@ -107,7 +107,7 @@ func (b *showPlacementLabelsResultBuilder) sortMapKeys(m map[string]interface{}) func (e *ShowExec) fetchShowPlacementLabels(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "SELECT DISTINCT LABEL FROM %n.%n", "INFORMATION_SCHEMA", infoschema.TableTiKVStoreStatus) + stmt, err := exec.ParseWithParamsInternal(ctx, "SELECT DISTINCT LABEL FROM %n.%n", "INFORMATION_SCHEMA", infoschema.TableTiKVStoreStatus) if err != nil { return errors.Trace(err) } diff --git a/executor/simple.go b/executor/simple.go index 82e941eca143f..ee6932ad1dc91 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -966,7 +966,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) if !ok { return errors.Trace(ErrPasswordFormat) } - stmt, err := exec.ParseWithParams(ctx, + stmt, err := exec.ParseWithParamsInternal(ctx, `UPDATE %n.%n SET authentication_string=%?, plugin=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.AuthOpt.AuthPlugin, strings.ToLower(spec.User.Hostname), spec.User.Username, ) @@ -980,7 +980,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) } if len(privData) > 0 { - stmt, err := exec.ParseWithParams(ctx, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) + stmt, err := exec.ParseWithParamsInternal(ctx, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) if err != nil { return err } @@ -1358,7 +1358,7 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) if err != nil { return false, err } @@ -1441,7 +1441,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error // update mysql.user exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, strings.ToLower(h)) + stmt, err := exec.ParseWithParamsInternal(ctx, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, strings.ToLower(h)) if err != nil { return err } diff --git a/executor/simple_test.go b/executor/simple_test.go index b8dc034076ec7..6d83b91b08485 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -988,3 +988,25 @@ func (s *testSuite3) TestDropRoleAfterRevoke(c *C) { tk.MustExec("revoke r1, r3 from root;") tk.MustExec("drop role r1;") } + +func (s *testSuiteWithCliBaseCharset) TestUserWithSetNames(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("set names gbk;") + + gbkString := string([]byte{0xD2, 0xBB}) + + tk.MustExec("drop user if exists '一'@'localhost';") + tk.MustExec("create user '一'@'localhost' IDENTIFIED BY '" + gbkString + "';") + + result := tk.MustQuery(`SELECT authentication_string FROM mysql.User WHERE User="一" and Host="localhost";`) + result.Check(testkit.Rows(auth.EncodePassword("一"))) + + tk.MustExec(`ALTER USER '一'@'localhost' IDENTIFIED BY '` + gbkString + gbkString + `';`) + result = tk.MustQuery(`SELECT authentication_string FROM mysql.User WHERE User="一" and Host="localhost";`) + result.Check(testkit.Rows(auth.EncodePassword("一一"))) + + tk.MustExec(`RENAME USER '一'@'localhost' to '一'`) + + tk.MustExec("drop user '一';") +} diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 10d51bb1a27d7..14e7c13e9c17e 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -829,6 +829,7 @@ func (s *tiflashTestSuite) TestAvgOverflow(c *C) { tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") tk.MustExec("set @@session.tidb_enforce_mpp=ON") tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260")) + tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000")) tk.MustExec("drop table if exists td;") } diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 30f020e7dfdf2..8882b93ec05d7 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -43,6 +43,7 @@ type AggFuncDesc struct { } // NewAggFuncDesc creates an aggregation function signature descriptor. +// this func cannot be called twice as the TypeInfer has changed the type of args in the first time. func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { b, err := newBaseFuncDesc(ctx, name, args) if err != nil { @@ -51,6 +52,14 @@ func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expre return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct}, nil } +// NewAggFuncDescForWindowFunc creates an aggregation function from window functions, where baseFuncDesc may be ready. +func NewAggFuncDescForWindowFunc(ctx sessionctx.Context, Desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) { + if Desc.RetTp == nil { // safety check + return NewAggFuncDesc(ctx, Desc.Name, Desc.Args, hasDistinct) + } + return &AggFuncDesc{baseFuncDesc: baseFuncDesc{Desc.Name, Desc.Args, Desc.RetTp}, HasDistinct: hasDistinct}, nil +} + // String implements the fmt.Stringer interface. func (a *AggFuncDesc) String() string { buffer := bytes.NewBufferString(a.Name) diff --git a/expression/builtin_convert_charset.go b/expression/builtin_convert_charset.go index fc709cc7c61f0..a24ed138eb578 100644 --- a/expression/builtin_convert_charset.go +++ b/expression/builtin_convert_charset.go @@ -16,7 +16,6 @@ package expression import ( "fmt" - "unicode/utf8" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/ast" @@ -27,6 +26,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/dbterror" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tipb/go-tipb" ) @@ -92,9 +92,9 @@ func (b *builtinInternalToBinarySig) evalString(row chunk.Row) (res string, isNu return res, isNull, err } tp := b.args[0].GetType() - enc := charset.NewEncoding(tp.Charset) - res, err = enc.EncodeString(val) - return res, false, err + enc := charset.FindEncoding(tp.Charset) + ret, err := enc.Transform(nil, hack.Slice(val), charset.OpEncode) + return string(ret), false, err } func (b *builtinInternalToBinarySig) vectorized() bool { @@ -111,7 +111,7 @@ func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *c if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } - enc := charset.NewEncoding(b.args[0].GetType().Charset) + enc := charset.FindEncoding(b.args[0].GetType().Charset) result.ReserveString(n) var encodedBuf []byte for i := 0; i < n; i++ { @@ -119,11 +119,11 @@ func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - strBytes, err := enc.Encode(encodedBuf, buf.GetBytes(i)) + encodedBuf, err = enc.Transform(encodedBuf, buf.GetBytes(i), charset.OpEncode) if err != nil { return err } - result.AppendBytes(strBytes) + result.AppendBytes(encodedBuf) } return nil } @@ -170,9 +170,13 @@ func (b *builtinInternalFromBinarySig) evalString(row chunk.Row) (res string, is if isNull || err != nil { return val, isNull, err } - transferString := b.getTransferFunc() - tBytes, err := transferString([]byte(val)) - return string(tBytes), false, err + enc := charset.FindEncoding(b.tp.Charset) + ret, err := enc.Transform(nil, hack.Slice(val), charset.OpDecode) + if err != nil { + strHex := fmt.Sprintf("%X", val) + err = errCannotConvertString.GenWithStackByArgs(strHex, charset.CharsetBin, b.tp.Charset) + } + return string(ret), false, err } func (b *builtinInternalFromBinarySig) vectorized() bool { @@ -189,45 +193,25 @@ func (b *builtinInternalFromBinarySig) vecEvalString(input *chunk.Chunk, result if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } - transferString := b.getTransferFunc() + enc := charset.FindEncoding(b.tp.Charset) + var encBuf []byte result.ReserveString(n) for i := 0; i < n; i++ { if buf.IsNull(i) { result.AppendNull() continue } - str, err := transferString(buf.GetBytes(i)) + str := buf.GetBytes(i) + encBuf, err = enc.Transform(encBuf, str, charset.OpDecode) if err != nil { - return err + strHex := fmt.Sprintf("%X", str) + return errCannotConvertString.GenWithStackByArgs(strHex, charset.CharsetBin, b.tp.Charset) } - result.AppendBytes(str) + result.AppendBytes(encBuf) } return nil } -func (b *builtinInternalFromBinarySig) getTransferFunc() func([]byte) ([]byte, error) { - var transferString func([]byte) ([]byte, error) - if b.tp.Charset == charset.CharsetUTF8MB4 || b.tp.Charset == charset.CharsetUTF8 { - transferString = func(s []byte) ([]byte, error) { - if !utf8.Valid(s) { - return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset) - } - return s, nil - } - } else { - enc := charset.NewEncoding(b.tp.Charset) - var buf []byte - transferString = func(s []byte) ([]byte, error) { - str, err := enc.Decode(buf, s) - if err != nil { - return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset) - } - return str, nil - } - } - return transferString -} - // BuildToBinaryFunction builds to_binary function. func BuildToBinaryFunction(ctx sessionctx.Context, expr Expression) (res Expression) { fc := &tidbToBinaryFunctionClass{baseFunctionClass{InternalFuncToBinary, 1, 1}} @@ -258,26 +242,94 @@ func BuildFromBinaryFunction(ctx sessionctx.Context, expr Expression, tp *types. return FoldConstant(res) } +type funcProp int8 + +const ( + funcPropNone funcProp = iota + // The arguments of these functions are wrapped with to_binary(). + // For compatibility reason, legacy charsets arguments are not wrapped. + // Legacy charsets: utf8mb4, utf8, latin1, ascii, binary. + funcPropBinAware + // The arguments of these functions are wrapped with to_binary() or from_binary() according to + // the evaluated result charset and the argument charset. + // For binary argument && string result, wrap it with from_binary(). + // For string argument && binary result, wrap it with to_binary(). + funcPropAuto +) + +// convertActionMap collects from https://dev.mysql.com/doc/refman/8.0/en/string-functions.html. +var convertActionMap = map[funcProp][]string{ + funcPropNone: { + /* args != strings */ + ast.Bin, ast.CharFunc, ast.DateFormat, ast.Oct, ast.Space, + /* only 1 string arg, no implicit conversion */ + ast.CharLength, ast.CharacterLength, ast.FromBase64, ast.Lcase, ast.Left, ast.LoadFile, + ast.Lower, ast.LTrim, ast.Mid, ast.Ord, ast.Quote, ast.Repeat, ast.Reverse, ast.Right, + ast.RTrim, ast.Soundex, ast.Substr, ast.Substring, ast.Ucase, ast.Unhex, ast.Upper, ast.WeightString, + /* args are independent, no implicit conversion */ + ast.Elt, + }, + funcPropBinAware: { + /* result is binary-aware */ + ast.ASCII, ast.BitLength, ast.Hex, ast.Length, ast.OctetLength, ast.ToBase64, + /* encrypt functions */ + ast.AesDecrypt, ast.Decode, ast.Encode, ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1, + ast.SHA2, ast.Compress, ast.AesEncrypt, + }, + funcPropAuto: { + /* string functions */ ast.Concat, ast.ConcatWS, ast.ExportSet, ast.Field, ast.FindInSet, + ast.InsertFunc, ast.Instr, ast.Lpad, ast.Locate, ast.Lpad, ast.MakeSet, ast.Position, + ast.Replace, ast.Rpad, ast.SubstringIndex, ast.Trim, + /* operators */ + ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ, ast.If, ast.Ifnull, ast.In, + ast.Case, + /* string comparing */ + ast.Like, ast.Strcmp, + /* regex */ + ast.Regexp, + }, +} + +var convertFuncsMap = map[string]funcProp{} + +func init() { + for k, fns := range convertActionMap { + for _, f := range fns { + convertFuncsMap[f] = k + } + } +} + // HandleBinaryLiteral wraps `expr` with to_binary or from_binary sig. func HandleBinaryLiteral(ctx sessionctx.Context, expr Expression, ec *ExprCollation, funcName string) Expression { - switch funcName { - case ast.Concat, ast.ConcatWS, ast.Lower, ast.Lcase, ast.Reverse, ast.Upper, ast.Ucase, ast.Quote, ast.Coalesce, - ast.Left, ast.Right, ast.Repeat, ast.Trim, ast.LTrim, ast.RTrim, ast.Substr, ast.SubstringIndex, ast.Replace, - ast.Substring, ast.Mid, ast.Translate, ast.InsertFunc, ast.Lpad, ast.Rpad, ast.Elt, ast.ExportSet, ast.MakeSet, - ast.FindInSet, ast.Regexp, ast.Field, ast.Locate, ast.Instr, ast.Position, ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, - ast.NE, ast.NullEQ, ast.Strcmp, ast.If, ast.Ifnull, ast.Like, ast.In, ast.DateFormat, ast.TimeFormat: - if ec.Charset == charset.CharsetBin && expr.GetType().Charset != charset.CharsetBin { + argChs, dstChs := expr.GetType().Charset, ec.Charset + switch convertFuncsMap[funcName] { + case funcPropNone: + return expr + case funcPropBinAware: + if isLegacyCharset(argChs) { + return expr + } + return BuildToBinaryFunction(ctx, expr) + case funcPropAuto: + if argChs != charset.CharsetBin && dstChs == charset.CharsetBin { + if isLegacyCharset(argChs) { + return expr + } return BuildToBinaryFunction(ctx, expr) - } else if ec.Charset != charset.CharsetBin && expr.GetType().Charset == charset.CharsetBin { + } else if argChs == charset.CharsetBin && dstChs != charset.CharsetBin { ft := expr.GetType().Clone() ft.Charset, ft.Collate = ec.Charset, ec.Collation return BuildFromBinaryFunction(ctx, expr, ft) } - case ast.Hex, ast.Length, ast.OctetLength, ast.ASCII, ast.ToBase64, ast.AesEncrypt, ast.AesDecrypt, ast.Decode, ast.Encode, - ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1, ast.SHA2, ast.Compress: - if _, err := charset.GetDefaultCollationLegacy(expr.GetType().Charset); err != nil { - return BuildToBinaryFunction(ctx, expr) - } } return expr } + +func isLegacyCharset(chs string) bool { + switch chs { + case charset.CharsetUTF8, charset.CharsetUTF8MB4, charset.CharsetASCII, charset.CharsetLatin1, charset.CharsetBin: + return true + } + return false +} diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 2dec9405f03db..bda56e7bc4bdb 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -16,6 +16,7 @@ package expression import ( "encoding/hex" + "fmt" "strings" "testing" @@ -91,9 +92,10 @@ func TestSQLEncode(t *testing.T) { d, err := f.Eval(chunk.Row{}) require.NoError(t, err) if test.origin != nil { - result, err := charset.NewEncoding(test.chs).EncodeString(test.origin.(string)) + enc := charset.FindEncoding(test.chs) + result, err := enc.Transform(nil, []byte(test.origin.(string)), charset.OpEncode) require.NoError(t, err) - require.Equal(t, types.NewCollationStringDatum(result, test.chs), d) + require.Equal(t, types.NewCollationStringDatum(string(result), test.chs), d) } else { result := types.NewDatum(test.origin) require.Equal(t, result.GetBytes(), d.GetBytes()) @@ -163,7 +165,8 @@ func TestAESEncrypt(t *testing.T) { testAmbiguousInput(t, ctx, ast.AesEncrypt) // Test GBK String - gbkStr, _ := charset.NewEncoding("gbk").EncodeString("你好") + enc := charset.FindEncoding("gbk") + gbkStr, _ := enc.Transform(nil, []byte("你好"), charset.OpEncode) gbkTests := []struct { mode string chs string @@ -188,19 +191,20 @@ func TestAESEncrypt(t *testing.T) { } for _, tt := range gbkTests { + msg := fmt.Sprintf("%v", tt) err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs) - require.NoError(t, err) + require.NoError(t, err, msg) err = variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode) - require.NoError(t, err) + require.NoError(t, err, msg) - args := datumsToConstants([]types.Datum{types.NewDatum(tt.origin)}) + args := primitiveValsToConstants(ctx, []interface{}{tt.origin}) args = append(args, primitiveValsToConstants(ctx, tt.params)...) f, err := fc.getFunction(ctx, args) - require.NoError(t, err) + require.NoError(t, err, msg) crypt, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) - require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt)) + require.NoError(t, err, msg) + require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt), msg) } } @@ -209,21 +213,22 @@ func TestAESDecrypt(t *testing.T) { fc := funcs[ast.AesDecrypt] for _, tt := range aesTests { + msg := fmt.Sprintf("%v", tt) err := variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode) - require.NoError(t, err) + require.NoError(t, err, msg) args := []types.Datum{fromHex(tt.crypt)} for _, param := range tt.params { args = append(args, types.NewDatum(param)) } f, err := fc.getFunction(ctx, datumsToConstants(args)) - require.NoError(t, err) + require.NoError(t, err, msg) str, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) + require.NoError(t, err, msg) if tt.origin == nil { require.True(t, str.IsNull()) continue } - require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str) + require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str, msg) } err := variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, "aes-128-ecb") require.NoError(t, err) @@ -231,7 +236,9 @@ func TestAESDecrypt(t *testing.T) { testAmbiguousInput(t, ctx, ast.AesDecrypt) // Test GBK String - gbkStr, _ := charset.NewEncoding("gbk").EncodeString("你好") + enc := charset.FindEncoding("gbk") + r, _ := enc.Transform(nil, []byte("你好"), charset.OpEncode) + gbkStr := string(r) gbkTests := []struct { mode string chs string @@ -256,18 +263,19 @@ func TestAESDecrypt(t *testing.T) { } for _, tt := range gbkTests { + msg := fmt.Sprintf("%v", tt) err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs) - require.NoError(t, err) + require.NoError(t, err, msg) err = variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode) - require.NoError(t, err) + require.NoError(t, err, msg) // Set charset and collate except first argument args := datumsToConstants([]types.Datum{fromHex(tt.crypt)}) args = append(args, primitiveValsToConstants(ctx, tt.params)...) f, err := fc.getFunction(ctx, args) - require.NoError(t, err) + require.NoError(t, err, msg) str, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) - require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str) + require.NoError(t, err, msg) + require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str, msg) } } diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 9ac2eb370d380..925aecb3dc33a 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -41,7 +41,6 @@ import ( "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" - "golang.org/x/text/transform" ) var ( @@ -706,7 +705,7 @@ func (c *lowerFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi sig = &builtinLowerSig{bf} sig.setPbCode(tipb.ScalarFuncSig_Lower) } else { - sig = &builtinLowerUTF8Sig{bf, charset.NewEncoding(argTp.Charset)} + sig = &builtinLowerUTF8Sig{bf} sig.setPbCode(tipb.ScalarFuncSig_LowerUTF8) } return sig, nil @@ -714,15 +713,11 @@ func (c *lowerFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi type builtinLowerUTF8Sig struct { baseBuiltinFunc - encoding *charset.Encoding } func (b *builtinLowerUTF8Sig) Clone() builtinFunc { newSig := &builtinLowerUTF8Sig{} newSig.cloneFrom(&b.baseBuiltinFunc) - if b.encoding != nil { - newSig.encoding = charset.NewEncoding(b.encoding.Name()) - } return newSig } @@ -733,8 +728,8 @@ func (b *builtinLowerUTF8Sig) evalString(row chunk.Row) (d string, isNull bool, if isNull || err != nil { return d, isNull, err } - - return b.encoding.ToLower(d), false, nil + enc := charset.FindEncoding(b.args[0].GetType().Charset) + return enc.ToLower(d), false, nil } type builtinLowerSig struct { @@ -905,7 +900,7 @@ func (c *upperFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi sig = &builtinUpperSig{bf} sig.setPbCode(tipb.ScalarFuncSig_Upper) } else { - sig = &builtinUpperUTF8Sig{bf, charset.NewEncoding(argTp.Charset)} + sig = &builtinUpperUTF8Sig{bf} sig.setPbCode(tipb.ScalarFuncSig_UpperUTF8) } return sig, nil @@ -913,15 +908,11 @@ func (c *upperFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi type builtinUpperUTF8Sig struct { baseBuiltinFunc - encoding *charset.Encoding } func (b *builtinUpperUTF8Sig) Clone() builtinFunc { newSig := &builtinUpperUTF8Sig{} newSig.cloneFrom(&b.baseBuiltinFunc) - if b.encoding != nil { - newSig.encoding = charset.NewEncoding(b.encoding.Name()) - } return newSig } @@ -932,8 +923,8 @@ func (b *builtinUpperUTF8Sig) evalString(row chunk.Row) (d string, isNull bool, if isNull || err != nil { return d, isNull, err } - - return b.encoding.ToUpper(d), false, nil + enc := charset.FindEncoding(b.args[0].GetType().Charset) + return enc.ToUpper(d), false, nil } type builtinUpperSig struct { @@ -1141,27 +1132,27 @@ func (b *builtinConvertSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, err } - - // Since charset is already validated and set from getFunction(), there's no - // need to get charset from args again. - encoding, _ := charset.Lookup(b.tp.Charset) - // However, if `b.tp.Charset` is abnormally set to a wrong charset, we still - // return with error. - if encoding == nil { - return "", true, errUnknownCharacterSet.GenWithStackByArgs(b.tp.Charset) + argTp, resultTp := b.args[0].GetType(), b.tp + if !charset.IsSupportedEncoding(resultTp.Charset) { + return "", false, errUnknownCharacterSet.GenWithStackByArgs(resultTp.Charset) } - // if expr is binary string and convert meet error, we should return NULL. - if types.IsBinaryStr(b.args[0].GetType()) { - exprInternal, _, err := transform.String(encoding.NewDecoder(), expr) - return exprInternal, err != nil, nil + if types.IsBinaryStr(argTp) { + // Convert charset binary -> utf8. If it meets error, NULL is returned. + enc := charset.FindEncoding(resultTp.Charset) + ret, err := enc.Transform(nil, hack.Slice(expr), charset.OpDecodeReplace) + return string(ret), err != nil, nil + } else if types.IsBinaryStr(resultTp) { + // Convert charset utf8 -> binary. + enc := charset.FindEncoding(argTp.Charset) + ret, err := enc.Transform(nil, hack.Slice(expr), charset.OpEncode) + return string(ret), false, err } - if types.IsBinaryStr(b.tp) { - enc := charset.NewEncoding(b.args[0].GetType().Charset) - expr, err = enc.EncodeString(expr) - return expr, false, err + enc := charset.FindEncoding(resultTp.Charset) + if !charset.IsValidString(enc, expr) { + replace, _ := enc.Transform(nil, hack.Slice(expr), charset.OpReplace) + return string(replace), false, nil } - enc := charset.NewEncoding(b.tp.Charset) - return string(enc.EncodeInternal(nil, []byte(expr))), false, nil + return expr, false, nil } type substringFunctionClass struct { @@ -2327,12 +2318,7 @@ func (b *builtinBitLengthSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull || err != nil { return 0, isNull, err } - argTp := b.args[0].GetType() - dBytes, err := charset.NewEncoding(argTp.Charset).Encode(nil, hack.Slice(val)) - if err != nil { - return 0, isNull, err - } - return int64(len(dBytes) * 8), false, nil + return int64(len(val) * 8), false, nil } type charFunctionClass struct { @@ -2421,12 +2407,15 @@ func (b *builtinCharSig) evalString(row chunk.Row) (string, bool, error) { } dBytes := b.convertToBytes(bigints) - resultBytes, err := charset.NewEncoding(b.tp.Charset).Decode(nil, dBytes) + enc := charset.FindEncoding(b.tp.Charset) + res, err := enc.Transform(nil, dBytes, charset.OpDecode) if err != nil { b.ctx.GetSessionVars().StmtCtx.AppendWarning(err) - return "", true, nil + if b.ctx.GetSessionVars().StrictSQLMode { + return "", true, nil + } } - return string(resultBytes), false, nil + return string(res), false, nil } type charLengthFunctionClass struct { @@ -2893,43 +2882,19 @@ func (b *builtinOrdSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull, err } - charSet := b.args[0].GetType().Charset - ord, err := chooseOrdFunc(charSet) + strBytes := hack.Slice(str) + enc := charset.FindEncoding(b.args[0].GetType().Charset) + w := len(charset.EncodingUTF8Impl.Peek(strBytes)) + res, err := enc.Transform(nil, strBytes[:w], charset.OpEncode) if err != nil { - return 0, false, err - } - - enc := charset.NewEncoding(charSet) - leftMost, err := enc.EncodeFirstChar(nil, hack.Slice(str)) - if err != nil { - return 0, false, err - } - return ord(leftMost), false, nil -} - -func chooseOrdFunc(charSet string) (func([]byte) int64, error) { - // use utf8 by default - if charSet == "" { - charSet = charset.CharsetUTF8 - } - desc, err := charset.GetCharsetInfo(charSet) - if err != nil { - return nil, err - } - if desc.Maxlen == 1 { - return ordSingleByte, nil - } - return ordOthers, nil -} - -func ordSingleByte(src []byte) int64 { - if len(src) == 0 { - return 0 + // Fallback to the first byte. + return calcOrd(strBytes[:1]), false, nil } - return int64(src[0]) + // Only the first character is considered. + return calcOrd(res[:len(enc.Peek(res))]), false, nil } -func ordOthers(leftMost []byte) int64 { +func calcOrd(leftMost []byte) int64 { var result int64 var factor int64 = 1 for i := len(leftMost) - 1; i >= 0; i-- { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 8aeea63add9e8..5cd3dbbb2757c 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -920,6 +920,7 @@ func TestConvert(t *testing.T) { wrongFunction := f.(*builtinConvertSig) wrongFunction.tp.Charset = "wrongcharset" _, err = evalBuiltinFunc(wrongFunction, chunk.Row{}) + require.Error(t, err) require.Equal(t, "[expression:1115]Unknown character set: 'wrongcharset'", err.Error()) } @@ -1402,14 +1403,10 @@ func TestBitLength(t *testing.T) { } func TestChar(t *testing.T) { + collate.SetCharsetFeatEnabledForTest(true) + defer collate.SetCharsetFeatEnabledForTest(false) ctx := createContext(t) - stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate - stmtCtx.IgnoreTruncate = true - defer func() { - stmtCtx.IgnoreTruncate = origin - }() - + ctx.GetSessionVars().StmtCtx.IgnoreTruncate = true tbl := []struct { str string iNum int64 @@ -1418,30 +1415,36 @@ func TestChar(t *testing.T) { result interface{} warnings int }{ - {"65", 66, 67.5, "utf8", "ABD", 0}, // float - {"65", 16740, 67.5, "utf8", "AAdD", 0}, // large num - {"65", -1, 67.5, nil, "A\xff\xff\xff\xffD", 0}, // nagtive int - {"a", -1, 67.5, nil, "\x00\xff\xff\xff\xffD", 0}, // invalid 'a' - // TODO: Uncomment it when issue #29685 be closed - // {"65", -1, 67.5, "utf8", nil, 1}, // with utf8, return nil - // {"a", -1, 67.5, "utf8", nil, 2}, // with utf8, return nil - // TODO: Uncomment it when gbk be added into charsetInfos - // {"1234567", 1234567, 1234567, "gbk", "謬謬謬", 0}, // test char for gbk - // {"123456789", 123456789, 123456789, "gbk", nil, 3}, // invalid 123456789 in gbk - } - for _, v := range tbl { + {"65", 66, 67.5, "utf8", "ABD", 0}, // float + {"65", 16740, 67.5, "utf8", "AAdD", 0}, // large num + {"65", -1, 67.5, nil, "A\xff\xff\xff\xffD", 0}, // negative int + {"a", -1, 67.5, nil, "\x00\xff\xff\xff\xffD", 0}, // invalid 'a' + {"65", -1, 67.5, "utf8", nil, 1}, // with utf8, return nil + {"a", -1, 67.5, "utf8", nil, 1}, // with utf8, return nil + {"1234567", 1234567, 1234567, "gbk", "\u0012謬\u0012謬\u0012謬", 0}, // test char for gbk + {"123456789", 123456789, 123456789, "gbk", nil, 1}, // invalid 123456789 in gbk + } + run := func(i int, result interface{}, warnCnt int, dts ...interface{}) { fc := funcs[ast.CharFunc] - f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(v.str, v.iNum, v.fNum, v.charset))) - require.NoError(t, err) - require.NotNil(t, f) + f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(dts...))) + require.NoError(t, err, i) + require.NotNil(t, f, i) r, err := evalBuiltinFunc(f, chunk.Row{}) - require.NoError(t, err) - trequire.DatumEqual(t, types.NewDatum(v.result), r) - if v.warnings != 0 { - warnings := ctx.GetSessionVars().StmtCtx.GetWarnings() - require.Equal(t, v.warnings, len(warnings)) + require.NoError(t, err, i) + trequire.DatumEqual(t, types.NewDatum(result), r, i) + if warnCnt != 0 { + warnings := ctx.GetSessionVars().StmtCtx.TruncateWarnings(0) + require.Equal(t, warnCnt, len(warnings), fmt.Sprintf("%d: %v", i, warnings)) } } + for i, v := range tbl { + run(i, v.result, v.warnings, v.str, v.iNum, v.fNum, v.charset) + } + // char() returns null only when the sql_mode is strict. + ctx.GetSessionVars().StrictSQLMode = true + run(-1, nil, 1, 123456, "utf8") + ctx.GetSessionVars().StrictSQLMode = false + run(-2, string([]byte{1}), 1, 123456, "utf8") } func TestCharLength(t *testing.T) { @@ -2205,11 +2208,11 @@ func TestOrd(t *testing.T) { {2.3, 50, "", false, false}, {nil, 0, "", true, false}, {"", 0, "", false, false}, - {"你好", 14990752, "", false, false}, - {"にほん", 14909867, "", false, false}, - {"한국", 15570332, "", false, false}, - {"👍", 4036989325, "", false, false}, - {"א", 55184, "", false, false}, + {"你好", 14990752, "utf8mb4", false, false}, + {"にほん", 14909867, "utf8mb4", false, false}, + {"한국", 15570332, "utf8mb4", false, false}, + {"👍", 4036989325, "utf8mb4", false, false}, + {"א", 55184, "utf8mb4", false, false}, {"abc", 97, "gbk", false, false}, {"一二三", 53947, "gbk", false, false}, {"àáèé", 43172, "gbk", false, false}, diff --git a/expression/builtin_string_vec.go b/expression/builtin_string_vec.go index 62c123faf07bd..3da555f9319ed 100644 --- a/expression/builtin_string_vec.go +++ b/expression/builtin_string_vec.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" - "golang.org/x/text/transform" ) func (b *builtinLowerSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { @@ -46,11 +45,10 @@ func (b *builtinLowerUTF8Sig) vecEvalString(input *chunk.Chunk, result *chunk.Co if err := b.args[0].VecEvalString(b.ctx, input, result); err != nil { return err } - + enc := charset.FindEncoding(b.args[0].GetType().Charset) for i := 0; i < input.NumRows(); i++ { - result.SetRaw(i, []byte(b.encoding.ToLower(result.GetString(i)))) + result.SetRaw(i, []byte(enc.ToLower(result.GetString(i)))) } - return nil } @@ -146,9 +144,9 @@ func (b *builtinUpperUTF8Sig) vecEvalString(input *chunk.Chunk, result *chunk.Co if err := b.args[0].VecEvalString(b.ctx, input, result); err != nil { return err } - + enc := charset.FindEncoding(b.args[0].GetType().Charset) for i := 0; i < input.NumRows(); i++ { - result.SetRaw(i, []byte(b.encoding.ToUpper(result.GetString(i)))) + result.SetRaw(i, []byte(enc.ToUpper(result.GetString(i)))) } return nil } @@ -677,49 +675,59 @@ func (b *builtinConvertSig) vecEvalString(input *chunk.Chunk, result *chunk.Colu if err := b.args[0].VecEvalString(b.ctx, input, expr); err != nil { return err } - // Since charset is already validated and set from getFunction(), there's no - // need to get charset from args again. - encoding, _ := charset.Lookup(b.tp.Charset) - // However, if `b.tp.Charset` is abnormally set to a wrong charset, we still - // return with error. - if encoding == nil { - return errUnknownCharacterSet.GenWithStackByArgs(b.tp.Charset) + argTp, resultTp := b.args[0].GetType(), b.tp + result.ReserveString(n) + done := vecEvalStringConvertBinary(result, n, expr, argTp, resultTp) + if done { + return nil } - decoder := encoding.NewDecoder() - isBinaryStr := types.IsBinaryStr(b.args[0].GetType()) - isRetBinary := types.IsBinaryStr(b.tp) - enc := charset.NewEncoding(b.tp.Charset) - if isRetBinary { - enc = charset.NewEncoding(b.args[0].GetType().Charset) + enc := charset.FindEncoding(resultTp.Charset) + var encBuf []byte + for i := 0; i < n; i++ { + if expr.IsNull(i) { + result.AppendNull() + continue + } + exprI := expr.GetBytes(i) + if !charset.IsValid(enc, exprI) { + encBuf, _ = enc.Transform(encBuf, exprI, charset.OpReplace) + result.AppendBytes(encBuf) + } else { + result.AppendBytes(exprI) + } } + return nil +} - result.ReserveString(n) +func vecEvalStringConvertBinary(result *chunk.Column, n int, expr *chunk.Column, + argTp, resultTp *types.FieldType) (done bool) { + var chs string + var op charset.Op + if types.IsBinaryStr(argTp) { + chs = resultTp.Charset + op = charset.OpDecode + } else if types.IsBinaryStr(resultTp) { + chs = argTp.Charset + op = charset.OpEncode + } else { + return false + } + enc := charset.FindEncoding(chs) + var encBuf []byte for i := 0; i < n; i++ { if expr.IsNull(i) { result.AppendNull() continue } - exprI := expr.GetString(i) - if isBinaryStr { - target, _, err := transform.String(decoder, exprI) - if err != nil { - result.AppendNull() - continue - } - result.AppendString(target) + encBuf, err := enc.Transform(encBuf, expr.GetBytes(i), op) + if err != nil { + result.AppendNull() } else { - if isRetBinary { - str, err := enc.EncodeString(exprI) - if err != nil { - return err - } - result.AppendString(str) - continue - } - result.AppendString(string(enc.EncodeInternal(nil, []byte(exprI)))) + result.AppendBytes(encBuf) } + continue } - return nil + return true } func (b *builtinSubstringIndexSig) vectorized() bool { @@ -2068,15 +2076,9 @@ func (b *builtinOrdSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) err return err } - charSet := b.args[0].GetType().Charset - ord, err := chooseOrdFunc(charSet) - if err != nil { - return err - } - - enc := charset.NewEncoding(charSet) - var encodedBuf []byte - + enc := charset.FindEncoding(b.args[0].GetType().Charset) + var x [4]byte + encBuf := x[:] result.ResizeInt64(n, false) result.MergeNulls(buf) i64s := result.Int64s() @@ -2084,12 +2086,15 @@ func (b *builtinOrdSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) err if result.IsNull(i) { continue } - str := buf.GetBytes(i) - encoded, err := enc.EncodeFirstChar(encodedBuf, str) + strBytes := buf.GetBytes(i) + w := len(charset.EncodingUTF8Impl.Peek(strBytes)) + encBuf, err = enc.Transform(encBuf, strBytes[:w], charset.OpEncode) if err != nil { - return err + i64s[i] = calcOrd(strBytes[:1]) + continue } - i64s[i] = ord(encoded) + // Only the first character is considered. + i64s[i] = calcOrd(encBuf[:len(enc.Peek(encBuf))]) } return nil } @@ -2231,9 +2236,6 @@ func (b *builtinBitLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum return err } - argTp := b.args[0].GetType() - enc := charset.NewEncoding(argTp.Charset) - result.ResizeInt64(n, false) result.MergeNulls(buf) i64s := result.Int64s() @@ -2242,11 +2244,7 @@ func (b *builtinBitLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum continue } str := buf.GetBytes(i) - dBytes, err := enc.Encode(nil, str) - if err != nil { - return err - } - i64s[i] = int64(len(dBytes) * 8) + i64s[i] = int64(len(str) * 8) } return nil } @@ -2282,7 +2280,8 @@ func (b *builtinCharSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) bufint[i] = buf[i].Int64s() } var resultBytes []byte - enc := charset.NewEncoding(b.tp.Charset) + enc := charset.FindEncoding(b.tp.Charset) + hasStrictMode := b.ctx.GetSessionVars().StrictSQLMode for i := 0; i < n; i++ { bigints = bigints[0:0] for j := 0; j < l-1; j++ { @@ -2292,12 +2291,13 @@ func (b *builtinCharSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) bigints = append(bigints, bufint[j][i]) } dBytes := b.convertToBytes(bigints) - - resultBytes, err := enc.Decode(resultBytes, dBytes) + resultBytes, err = enc.Transform(resultBytes, dBytes, charset.OpDecode) if err != nil { b.ctx.GetSessionVars().StmtCtx.AppendWarning(err) - result.AppendNull() - continue + if hasStrictMode { + result.AppendNull() + continue + } } result.AppendString(string(resultBytes)) } diff --git a/expression/collation.go b/expression/collation.go index 80a2720c8cfe4..8dc5df02e55e0 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -296,6 +296,7 @@ func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, e } func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) bool { + enc := charset.FindEncoding(ec.Charset) for _, arg := range args { if arg.GetType().Charset == ec.Charset { continue @@ -311,7 +312,10 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) if err != nil { return false } - if !isNull && !isValidString(str, ec.Charset) { + if isNull { + continue + } + if !charset.IsValidString(enc, str) { return false } } else { @@ -324,25 +328,6 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) return true } -// isValidString check if str can convert to dstChs charset without data loss. -func isValidString(str string, dstChs string) bool { - switch dstChs { - case charset.CharsetASCII: - return charset.StringValidatorASCII{}.Validate(str) == -1 - case charset.CharsetLatin1: - // For backward compatibility, we do not block SQL like select '啊' = convert('a' using latin1) collate latin1_bin; - return true - case charset.CharsetUTF8, charset.CharsetUTF8MB4: - // String in tidb is actually use utf8mb4 encoding. - return true - case charset.CharsetBinary: - // Convert to binary is always safe. - return true - default: - return charset.StringValidatorOther{Charset: dstChs}.Validate(str) == -1 - } -} - // inferCollation infers collation, charset, coercibility and check the legitimacy. func inferCollation(exprs ...Expression) *ExprCollation { if len(exprs) == 0 { diff --git a/expression/integration_test.go b/expression/integration_test.go index 79cb81daf90f6..95dc3157507ad 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -378,7 +378,6 @@ func TestConvertToBit(t *testing.T) { } func TestStringBuiltin(t *testing.T) { - t.Skip("it has been broken. Please fix it as soon as possible.") store, clean := testkit.CreateMockStore(t) defer clean() @@ -813,6 +812,25 @@ func TestStringBuiltin(t *testing.T) { "-38.04620119 38.04620115 -38.04620119,38.04620115")) } +func TestInvalidStrings(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // Test convert invalid string. + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a binary(5));") + tk.MustExec("insert into t values (0x1e240), ('ABCDE');") + tk.MustExec("set tidb_enable_vectorized_expression = on;") + tk.MustQuery("select convert(t.a using utf8) from t;").Check(testkit.Rows("", "ABCDE")) + tk.MustQuery("select convert(0x1e240 using utf8);").Check(testkit.Rows("")) + tk.MustExec("set tidb_enable_vectorized_expression = off;") + tk.MustQuery("select convert(t.a using utf8) from t;").Check(testkit.Rows("", "ABCDE")) + tk.MustQuery("select convert(0x1e240 using utf8);").Check(testkit.Rows("")) +} + func TestEncryptionBuiltin(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/expression/util.go b/expression/util.go index d7b92329d51f6..3a793cfddc640 100644 --- a/expression/util.go +++ b/expression/util.go @@ -1145,7 +1145,7 @@ func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, sctx s stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" } - stmtNode, err := exec.ParseWithParams(ctx, stmt, inValues...) + stmtNode, err := exec.ParseWithParamsInternal(ctx, stmt, inValues...) if err != nil { return nil, err } diff --git a/parser/charset/encoding.go b/parser/charset/encoding.go index 8bd1b92c9bcf6..25257c44e440b 100644 --- a/parser/charset/encoding.go +++ b/parser/charset/encoding.go @@ -13,212 +13,131 @@ package charset -import ( - "bytes" - "fmt" - "reflect" - "strings" - "unicode" - "unsafe" - - "github.com/cznic/mathutil" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/parser/terror" - "golang.org/x/text/encoding" - "golang.org/x/text/transform" +// Make sure all of them implement Encoding interface. +var ( + _ Encoding = &encodingUTF8{} + _ Encoding = &encodingUTF8MB3Strict{} + _ Encoding = &encodingASCII{} + _ Encoding = &encodingLatin1{} + _ Encoding = &encodingBin{} + _ Encoding = &encodingGBK{} ) -var errInvalidCharacterString = terror.ClassParser.NewStd(mysql.ErrInvalidCharacterString) - -type EncodingLabel string - -// Format trim and change the label to lowercase. -func Format(label string) EncodingLabel { - return EncodingLabel(strings.ToLower(strings.Trim(label, "\t\n\r\f "))) -} - -// Formatted is used when the label is already trimmed and it is lowercase. -func Formatted(label string) EncodingLabel { - return EncodingLabel(label) -} - -// Encoding provide a interface to encode/decode a string with specific encoding. -type Encoding struct { - enc encoding.Encoding - name string - charLength func([]byte) int - specialCase unicode.SpecialCase -} - -// enabled indicates whether the non-utf8 encoding is used. -func (e *Encoding) enabled() bool { - return e != UTF8Encoding -} - -// Name returns the name of the current encoding. -func (e *Encoding) Name() string { - return e.name -} - -// CharLength returns the next character length in bytes. -func (e *Encoding) CharLength(bs []byte) int { - return e.charLength(bs) +// IsSupportedEncoding checks if the charset is fully supported. +func IsSupportedEncoding(charset string) bool { + _, ok := encodingMap[charset] + return ok } -// NewEncoding creates a new Encoding. -func NewEncoding(label string) *Encoding { - if len(label) == 0 { - return UTF8Encoding +// FindEncoding finds the encoding according to charset. +func FindEncoding(charset string) Encoding { + if len(charset) == 0 { + return EncodingBinImpl } - - if e, exist := encodingMap[Format(label)]; exist { + if e, exist := encodingMap[charset]; exist { return e } - return UTF8Encoding -} - -// Encode convert bytes from utf-8 charset to a specific charset. -func (e *Encoding) Encode(dest, src []byte) ([]byte, error) { - if !e.enabled() { - return src, nil - } - return e.transform(e.enc.NewEncoder(), dest, src, false) -} - -// EncodeString convert a string from utf-8 charset to a specific charset. -func (e *Encoding) EncodeString(src string) (string, error) { - if !e.enabled() { - return src, nil - } - bs, err := e.transform(e.enc.NewEncoder(), nil, Slice(src), false) - return string(bs), err -} - -// EncodeFirstChar convert first code point of bytes from utf-8 charset to a specific charset. -func (e *Encoding) EncodeFirstChar(dest, src []byte) ([]byte, error) { - srcNextLen := e.nextCharLenInSrc(src, false) - srcEnd := mathutil.Min(srcNextLen, len(src)) - if !e.enabled() { - return src[:srcEnd], nil - } - return e.transform(e.enc.NewEncoder(), dest, src[:srcEnd], false) -} - -// EncodeInternal convert bytes from utf-8 charset to a specific charset, we actually do not do the real convert, just find the inconvertible character and use ? replace. -// The code below is equivalent to -// expr, _ := e.Encode(dest, src) -// ret, _ := e.Decode(nil, expr) -// return ret -func (e *Encoding) EncodeInternal(dest, src []byte) []byte { - if !e.enabled() { - return src - } - if dest == nil { - dest = make([]byte, 0, len(src)) - } - var srcOffset int + return EncodingBinImpl +} + +var encodingMap = map[string]Encoding{ + CharsetUTF8MB4: EncodingUTF8Impl, + CharsetUTF8: EncodingUTF8Impl, + CharsetGBK: EncodingGBKImpl, + CharsetLatin1: EncodingLatin1Impl, + CharsetBin: EncodingBinImpl, + CharsetASCII: EncodingASCIIImpl, +} + +// Encoding provide encode/decode functions for a string with a specific charset. +type Encoding interface { + // Name is the name of the encoding. + Name() string + // Tp is the type of the encoding. + Tp() EncodingTp + // Peek returns the next char. + Peek(src []byte) []byte + // Foreach iterates the characters in in current encoding. + Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) + // Transform map the bytes in src to dest according to Op. + Transform(dest, src []byte, op Op) ([]byte, error) + // ToUpper change a string to uppercase. + ToUpper(src string) string + // ToLower change a string to lowercase. + ToLower(src string) string +} + +type EncodingTp int8 + +const ( + EncodingTpNone EncodingTp = iota + EncodingTpUTF8 + EncodingTpUTF8MB3Strict + EncodingTpASCII + EncodingTpLatin1 + EncodingTpBin + EncodingTpGBK +) - var buf [4]byte - transformer := e.enc.NewEncoder() - for srcOffset < len(src) { - length := UTF8Encoding.CharLength(src[srcOffset:]) - _, _, err := transformer.Transform(buf[:], src[srcOffset:srcOffset+length], true) - if err != nil { - dest = append(dest, byte('?')) - } else { - dest = append(dest, src[srcOffset:srcOffset+length]...) - } - srcOffset += length - } +// Op is used by Encoding.Transform. +type Op int16 + +const ( + opFromUTF8 Op = 1 << iota + opToUTF8 + opTruncateTrim + opTruncateReplace + opCollectFrom + opCollectTo + opSkipError +) - return dest -} +const ( + OpReplace = opFromUTF8 | opTruncateReplace | opCollectFrom | opSkipError + OpEncode = opFromUTF8 | opTruncateTrim | opCollectTo + OpEncodeNoErr = OpEncode | opSkipError + OpEncodeReplace = opFromUTF8 | opTruncateReplace | opCollectTo + OpDecode = opToUTF8 | opTruncateTrim | opCollectTo + OpDecodeReplace = opToUTF8 | opTruncateReplace | opCollectTo +) -// Decode convert bytes from a specific charset to utf-8 charset. -func (e *Encoding) Decode(dest, src []byte) ([]byte, error) { - if !e.enabled() { - return src, nil - } - return e.transform(e.enc.NewDecoder(), dest, src, true) +// IsValid checks whether the bytes is valid in current encoding. +func IsValid(e Encoding, src []byte) bool { + isValid := true + e.Foreach(src, opFromUTF8, func(from, to []byte, ok bool) bool { + isValid = ok + return ok + }) + return isValid } -// DecodeString convert a string from a specific charset to utf-8 charset. -func (e *Encoding) DecodeString(src string) (string, error) { - if !e.enabled() { - return src, nil - } - bs, err := e.transform(e.enc.NewDecoder(), nil, Slice(src), true) - return string(bs), err +// IsValidString is a string version of IsValid. +func IsValidString(e Encoding, str string) bool { + return IsValid(e, Slice(str)) } -func (e *Encoding) transform(transformer transform.Transformer, dest, src []byte, isDecoding bool) ([]byte, error) { - if len(dest) < len(src) { - dest = make([]byte, len(src)*2) - } - if len(src) == 0 { - return src, nil - } - var destOffset, srcOffset int - var encodingErr error - for { - srcNextLen := e.nextCharLenInSrc(src[srcOffset:], isDecoding) - srcEnd := mathutil.Min(srcOffset+srcNextLen, len(src)) - nDest, nSrc, err := transformer.Transform(dest[destOffset:], src[srcOffset:srcEnd], false) - if err == transform.ErrShortDst { - dest = enlargeCapacity(dest) - } else if err != nil || isDecoding && beginWithReplacementChar(dest[destOffset:destOffset+nDest]) { - if encodingErr == nil { - encodingErr = e.generateErr(src[srcOffset:], srcNextLen) - } - dest[destOffset] = byte('?') - nDest, nSrc = 1, srcNextLen // skip the source bytes that cannot be decoded normally. +// CountValidBytes counts the first valid bytes in src that +// can be encode to the current encoding. +func CountValidBytes(e Encoding, src []byte) int { + nSrc := 0 + e.Foreach(src, opFromUTF8, func(from, to []byte, ok bool) bool { + if ok { + nSrc += len(from) } - destOffset += nDest - srcOffset += nSrc - // The source bytes are exhausted. - if srcOffset >= len(src) { - return dest[:destOffset], encodingErr + return ok + }) + return nSrc +} + +// CountValidBytesDecode counts the first valid bytes in src that +// can be decode to utf-8. +func CountValidBytesDecode(e Encoding, src []byte) int { + nSrc := 0 + e.Foreach(src, opToUTF8, func(from, to []byte, ok bool) bool { + if ok { + nSrc += len(from) } - } -} - -func (e *Encoding) nextCharLenInSrc(srcRest []byte, isDecoding bool) int { - if isDecoding { - if e.charLength != nil { - return e.charLength(srcRest) - } - return len(srcRest) - } - return UTF8Encoding.CharLength(srcRest) -} - -func enlargeCapacity(dest []byte) []byte { - newDest := make([]byte, len(dest)*2) - copy(newDest, dest) - return newDest -} - -func (e *Encoding) generateErr(srcRest []byte, srcNextLen int) error { - cutEnd := mathutil.Min(srcNextLen, len(srcRest)) - invalidBytes := fmt.Sprintf("%X", string(srcRest[:cutEnd])) - return errInvalidCharacterString.GenWithStackByArgs(e.name, invalidBytes) -} - -// replacementBytes are bytes for the replacement rune 0xfffd. -var replacementBytes = []byte{0xEF, 0xBF, 0xBD} - -// beginWithReplacementChar check if dst has the prefix '0xEFBFBD'. -func beginWithReplacementChar(dst []byte) bool { - return bytes.HasPrefix(dst, replacementBytes) -} - -// Slice converts string to slice without copy. -// Use at your own risk. -func Slice(s string) (b []byte) { - pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) - pBytes.Data = pString.Data - pBytes.Len = pString.Len - pBytes.Cap = pString.Len - return + return ok + }) + return nSrc } diff --git a/parser/charset/encoding_ascii.go b/parser/charset/encoding_ascii.go new file mode 100644 index 0000000000000..df5fed9c3bce2 --- /dev/null +++ b/parser/charset/encoding_ascii.go @@ -0,0 +1,71 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + go_unicode "unicode" + + "golang.org/x/text/encoding" +) + +// EncodingASCIIImpl is the instance of encodingASCII +var EncodingASCIIImpl = &encodingASCII{encodingBase{enc: encoding.Nop}} + +func init() { + EncodingASCIIImpl.self = EncodingASCIIImpl +} + +// encodingASCII is the ASCII encoding. +type encodingASCII struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingASCII) Name() string { + return CharsetASCII +} + +// Tp implements Encoding interface. +func (e *encodingASCII) Tp() EncodingTp { + return EncodingTpASCII +} + +// Peek implements Encoding interface. +func (e *encodingASCII) Peek(src []byte) []byte { + if len(src) == 0 { + return src + } + return src[:1] +} + +func (e *encodingASCII) Transform(dest, src []byte, op Op) ([]byte, error) { + if IsValid(e, src) { + return src, nil + } + return e.encodingBase.Transform(dest, src, op) +} + +func (e *encodingASCII) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + for i, w := 0, 0; i < len(src); i += w { + w = 1 + ok := true + if src[i] > go_unicode.MaxASCII { + w = len(EncodingUTF8Impl.Peek(src[i:])) + ok = false + } + if !fn(src[i:i+w], src[i:i+w], ok) { + return + } + } +} diff --git a/parser/charset/encoding_base.go b/parser/charset/encoding_base.go new file mode 100644 index 0000000000000..275db24c5a3d6 --- /dev/null +++ b/parser/charset/encoding_base.go @@ -0,0 +1,117 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "bytes" + "fmt" + "reflect" + "strings" + "unsafe" + + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" + "golang.org/x/text/encoding" + "golang.org/x/text/transform" +) + +var errInvalidCharacterString = terror.ClassParser.NewStd(mysql.ErrInvalidCharacterString) + +// encodingBase defines some generic functions. +type encodingBase struct { + enc encoding.Encoding + self Encoding +} + +func (b encodingBase) ToUpper(src string) string { + return strings.ToUpper(src) +} + +func (b encodingBase) ToLower(src string) string { + return strings.ToLower(src) +} + +func (b encodingBase) Transform(dest, src []byte, op Op) (result []byte, err error) { + if dest == nil { + dest = make([]byte, len(src)) + } + dest = dest[:0] + b.self.Foreach(src, op, func(from, to []byte, ok bool) bool { + if !ok { + if err == nil && (op&opSkipError == 0) { + err = generateEncodingErr(b.self.Name(), from) + } + if op&opTruncateTrim != 0 { + return false + } + if op&opTruncateReplace != 0 { + dest = append(dest, '?') + return true + } + } + if op&opCollectFrom != 0 { + dest = append(dest, from...) + } else if op&opCollectTo != 0 { + dest = append(dest, to...) + } + return true + }) + return dest, err +} + +func (b encodingBase) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + var tfm transform.Transformer + var peek func([]byte) []byte + if op&opFromUTF8 != 0 { + tfm = b.enc.NewEncoder() + peek = EncodingUTF8Impl.Peek + } else { + tfm = b.enc.NewDecoder() + peek = b.self.Peek + } + var buf [4]byte + for i, w := 0, 0; i < len(src); i += w { + w = len(peek(src[i:])) + nDst, _, err := tfm.Transform(buf[:], src[i:i+w], false) + meetErr := err != nil || (op&opToUTF8 != 0 && beginWithReplacementChar(buf[:nDst])) + if !fn(src[i:i+w], buf[:nDst], !meetErr) { + return + } + } +} + +// replacementBytes are bytes for the replacement rune 0xfffd. +var replacementBytes = []byte{0xEF, 0xBF, 0xBD} + +// beginWithReplacementChar check if dst has the prefix '0xEFBFBD'. +func beginWithReplacementChar(dst []byte) bool { + return bytes.HasPrefix(dst, replacementBytes) +} + +// generateEncodingErr generates an invalid string in charset error. +func generateEncodingErr(name string, invalidBytes []byte) error { + arg := fmt.Sprintf("%X", invalidBytes) + return errInvalidCharacterString.FastGenByArgs(name, arg) +} + +// Slice converts string to slice without copy. +// Use at your own risk. +func Slice(s string) (b []byte) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pBytes.Data = pString.Data + pBytes.Len = pString.Len + pBytes.Cap = pString.Len + return +} diff --git a/parser/charset/encoding_bin.go b/parser/charset/encoding_bin.go new file mode 100644 index 0000000000000..30fd87644c571 --- /dev/null +++ b/parser/charset/encoding_bin.go @@ -0,0 +1,61 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "golang.org/x/text/encoding" +) + +// EncodingBinImpl is the instance of encodingBin. +var EncodingBinImpl = &encodingBin{encodingBase{enc: encoding.Nop}} + +func init() { + EncodingBinImpl.self = EncodingBinImpl +} + +// encodingBin is the binary encoding. +type encodingBin struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingBin) Name() string { + return CharsetBin +} + +// Tp implements Encoding interface. +func (e *encodingBin) Tp() EncodingTp { + return EncodingTpBin +} + +// Peek implements Encoding interface. +func (e *encodingBin) Peek(src []byte) []byte { + if len(src) == 0 { + return src + } + return src[:1] +} + +// Foreach implements Encoding interface. +func (e *encodingBin) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + for i := 0; i < len(src); i++ { + if !fn(src[i:i+1], src[i:i+1], true) { + return + } + } +} + +func (e *encodingBin) Transform(dest, src []byte, op Op) ([]byte, error) { + return src, nil +} diff --git a/parser/charset/encoding_gbk.go b/parser/charset/encoding_gbk.go new file mode 100644 index 0000000000000..3dc3fe14fed6c --- /dev/null +++ b/parser/charset/encoding_gbk.go @@ -0,0 +1,93 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "strings" + "unicode" + + "golang.org/x/text/encoding/simplifiedchinese" +) + +// EncodingGBKImpl is the instance of encodingGBK +var EncodingGBKImpl = &encodingGBK{encodingBase{enc: simplifiedchinese.GBK}} + +func init() { + EncodingGBKImpl.self = EncodingGBKImpl +} + +// encodingGBK is GBK encoding. +type encodingGBK struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingGBK) Name() string { + return CharsetGBK +} + +// Tp implements Encoding interface. +func (e *encodingGBK) Tp() EncodingTp { + return EncodingTpGBK +} + +// Peek implements Encoding interface. +func (e *encodingGBK) Peek(src []byte) []byte { + charLen := 2 + if len(src) == 0 || src[0] < 0x80 { + // A byte in the range 00–7F is a single byte that means the same thing as it does in ASCII. + charLen = 1 + } + if charLen < len(src) { + return src[:charLen] + } + return src +} + +// ToUpper implements Encoding interface. +func (e *encodingGBK) ToUpper(d string) string { + return strings.ToUpperSpecial(GBKCase, d) +} + +// ToLower implements Encoding interface. +func (e *encodingGBK) ToLower(d string) string { + return strings.ToLowerSpecial(GBKCase, d) +} + +// GBKCase follows https://dev.mysql.com/worklog/task/?id=4583. +var GBKCase = unicode.SpecialCase{ + unicode.CaseRange{Lo: 0x00E0, Hi: 0x00E1, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00E8, Hi: 0x00EA, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00EC, Hi: 0x00ED, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00F2, Hi: 0x00F3, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00F9, Hi: 0x00FA, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x00FC, Hi: 0x00FC, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0101, Hi: 0x0101, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0113, Hi: 0x0113, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x011B, Hi: 0x011B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x012B, Hi: 0x012B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0144, Hi: 0x0144, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x0148, Hi: 0x0148, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x014D, Hi: 0x014D, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x016B, Hi: 0x016B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01CE, Hi: 0x01CE, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D0, Hi: 0x01D0, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D2, Hi: 0x01D2, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D4, Hi: 0x01D4, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D6, Hi: 0x01D6, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01D8, Hi: 0x01D8, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01DA, Hi: 0x01DA, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x01DC, Hi: 0x01DC, Delta: [unicode.MaxCase]rune{0, 0, 0}}, + unicode.CaseRange{Lo: 0x216A, Hi: 0x216B, Delta: [unicode.MaxCase]rune{0, 0, 0}}, +} diff --git a/parser/charset/encoding_latin1.go b/parser/charset/encoding_latin1.go new file mode 100644 index 0000000000000..1d2992b87642d --- /dev/null +++ b/parser/charset/encoding_latin1.go @@ -0,0 +1,51 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import "golang.org/x/text/encoding" + +// EncodingLatin1Impl is the instance of encodingLatin1. +// TiDB uses utf8 implementation for latin1 charset because of the backward compatibility. +var EncodingLatin1Impl = &encodingLatin1{encodingUTF8{encodingBase{enc: encoding.Nop}}} + +func init() { + EncodingLatin1Impl.self = EncodingLatin1Impl +} + +// encodingLatin1 compatibles with latin1 in old version TiDB. +type encodingLatin1 struct { + encodingUTF8 +} + +// Name implements Encoding interface. +func (e *encodingLatin1) Name() string { + return CharsetLatin1 +} + +// Peek implements Encoding interface. +func (e *encodingLatin1) Peek(src []byte) []byte { + if len(src) == 0 { + return src + } + return src[:1] +} + +// Tp implements Encoding interface. +func (e *encodingLatin1) Tp() EncodingTp { + return EncodingTpLatin1 +} + +func (e *encodingLatin1) Transform(dest, src []byte, op Op) ([]byte, error) { + return src, nil +} diff --git a/parser/charset/encoding_table.go b/parser/charset/encoding_table.go index 2de9d957d923a..2780272296acb 100644 --- a/parser/charset/encoding_table.go +++ b/parser/charset/encoding_table.go @@ -14,11 +14,6 @@ package charset import ( - "strings" - go_unicode "unicode" - "unicode/utf8" - - "github.com/cznic/mathutil" "golang.org/x/text/encoding" "golang.org/x/text/encoding/charmap" "golang.org/x/text/encoding/japanese" @@ -26,28 +21,20 @@ import ( "golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/encoding/traditionalchinese" "golang.org/x/text/encoding/unicode" + "strings" ) -var encodingMap = map[EncodingLabel]*Encoding{ - CharsetUTF8MB4: UTF8Encoding, - CharsetUTF8: UTF8Encoding, - CharsetGBK: GBKEncoding, - CharsetLatin1: LatinEncoding, - CharsetBin: BinaryEncoding, - CharsetASCII: ASCIIEncoding, -} - // Lookup returns the encoding with the specified label, and its canonical // name. It returns nil and the empty string if label is not one of the // standard encodings for HTML. Matching is case-insensitive and ignores // leading and trailing whitespace. func Lookup(label string) (e encoding.Encoding, name string) { label = strings.ToLower(strings.Trim(label, "\t\n\r\f ")) - return lookup(Formatted(label)) + return lookup(label) } -func lookup(label EncodingLabel) (e encoding.Encoding, name string) { - enc := encodings[string(label)] +func lookup(label string) (e encoding.Encoding, name string) { + enc := encodings[label] return enc.e, enc.name } @@ -274,179 +261,3 @@ var encodings = map[string]struct { "utf-16le": {unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), "utf-16le"}, "x-user-defined": {charmap.XUserDefined, "x-user-defined"}, } - -// TruncateStrategy indicates the way to handle the invalid strings in specific charset. -// - TruncateStrategyEmpty: returns an empty string. -// - TruncateStrategyTrim: returns the valid prefix part of string. -// - TruncateStrategyReplace: returns the whole string, but the invalid characters are replaced with '?'. -type TruncateStrategy int8 - -const ( - TruncateStrategyEmpty TruncateStrategy = iota - TruncateStrategyTrim - TruncateStrategyReplace -) - -var _ StringValidator = StringValidatorASCII{} -var _ StringValidator = StringValidatorUTF8{} -var _ StringValidator = StringValidatorOther{} - -// StringValidator is used to check if a string is valid in the specific charset. -type StringValidator interface { - Validate(str string) (invalidPos int) - Truncate(str string, strategy TruncateStrategy) (result string, invalidPos int) -} - -// StringValidatorASCII checks whether a string is valid ASCII string. -type StringValidatorASCII struct{} - -// Validate checks whether the string is valid in the given charset. -func (s StringValidatorASCII) Validate(str string) int { - _, invalidPos := s.Truncate(str, TruncateStrategyEmpty) - return invalidPos -} - -// Truncate implement the interface StringValidator. -func (s StringValidatorASCII) Truncate(str string, strategy TruncateStrategy) (string, int) { - invalidPos := -1 - for i := 0; i < len(str); i++ { - if str[i] > go_unicode.MaxASCII { - invalidPos = i - break - } - } - if invalidPos == -1 { - // Quick check passed. - return str, -1 - } - switch strategy { - case TruncateStrategyEmpty: - return "", invalidPos - case TruncateStrategyTrim: - return str[:invalidPos], invalidPos - case TruncateStrategyReplace: - result := make([]byte, 0, len(str)) - for i, w := 0, 0; i < len(str); i += w { - w = 1 - if str[i] > go_unicode.MaxASCII { - w = UTF8Encoding.CharLength(Slice(str)[i:]) - w = mathutil.Min(w, len(str)-i) - result = append(result, '?') - continue - } - result = append(result, str[i:i+w]...) - } - return string(result), invalidPos - } - return str, -1 -} - -// StringValidatorUTF8 checks whether a string is valid UTF8 string. -type StringValidatorUTF8 struct { - IsUTF8MB4 bool // Distinguish between "utf8" and "utf8mb4" - CheckMB4ValueInUTF8 bool -} - -// Validate checks whether the string is valid in the given charset. -func (s StringValidatorUTF8) Validate(str string) int { - _, invalidPos := s.Truncate(str, TruncateStrategyEmpty) - return invalidPos -} - -// Truncate implement the interface StringValidator. -func (s StringValidatorUTF8) Truncate(str string, strategy TruncateStrategy) (string, int) { - if str == "" { - return str, -1 - } - if s.IsUTF8MB4 && utf8.ValidString(str) { - // Quick check passed. - return str, -1 - } - doMB4CharCheck := !s.IsUTF8MB4 && s.CheckMB4ValueInUTF8 - var result []byte - if strategy == TruncateStrategyReplace { - result = make([]byte, 0, len(str)) - } - invalidPos := -1 - for i, w := 0, 0; i < len(str); i += w { - var rv rune - rv, w = utf8.DecodeRuneInString(str[i:]) - if (rv == utf8.RuneError && w == 1) || (w > 3 && doMB4CharCheck) { - if invalidPos == -1 { - invalidPos = i - } - switch strategy { - case TruncateStrategyEmpty: - return "", invalidPos - case TruncateStrategyTrim: - return str[:i], invalidPos - case TruncateStrategyReplace: - result = append(result, '?') - continue - } - } - if strategy == TruncateStrategyReplace { - result = append(result, str[i:i+w]...) - } - } - if strategy == TruncateStrategyReplace { - return string(result), invalidPos - } - return str, -1 -} - -// StringValidatorOther checks whether a string is valid string in given charset. -type StringValidatorOther struct { - Charset string -} - -// Validate checks whether the string is valid in the given charset. -func (s StringValidatorOther) Validate(str string) int { - _, invalidPos := s.Truncate(str, TruncateStrategyEmpty) - return invalidPos -} - -// Truncate implement the interface StringValidator. -func (s StringValidatorOther) Truncate(str string, strategy TruncateStrategy) (string, int) { - if str == "" { - return str, -1 - } - enc := NewEncoding(s.Charset) - if !enc.enabled() { - return str, -1 - } - var result []byte - if strategy == TruncateStrategyReplace { - result = make([]byte, 0, len(str)) - } - var buf [4]byte - strBytes := Slice(str) - transformer := enc.enc.NewEncoder() - invalidPos := -1 - for i, w := 0, 0; i < len(str); i += w { - w = UTF8Encoding.CharLength(strBytes[i:]) - w = mathutil.Min(w, len(str)-i) - _, _, err := transformer.Transform(buf[:], strBytes[i:i+w], true) - if err != nil { - if invalidPos == -1 { - invalidPos = i - } - switch strategy { - case TruncateStrategyEmpty: - return "", invalidPos - case TruncateStrategyTrim: - return str[:i], invalidPos - case TruncateStrategyReplace: - result = append(result, '?') - continue - } - } - if strategy == TruncateStrategyReplace { - result = append(result, strBytes[i:i+w]...) - } - } - if strategy == TruncateStrategyReplace { - return string(result), invalidPos - } - return str, -1 -} diff --git a/parser/charset/encoding_test.go b/parser/charset/encoding_test.go index 51f5b53b3e2fd..a78aa640d8be5 100644 --- a/parser/charset/encoding_test.go +++ b/parser/charset/encoding_test.go @@ -24,21 +24,21 @@ import ( ) func TestEncoding(t *testing.T) { - enc := charset.NewEncoding(charset.CharsetGBK) + enc := charset.FindEncoding(charset.CharsetGBK) require.Equal(t, charset.CharsetGBK, enc.Name()) txt := []byte("一二三四") e, _ := charset.Lookup("gbk") gbkEncodedTxt, _, err := transform.Bytes(e.NewEncoder(), txt) require.NoError(t, err) - result, err := enc.Decode(nil, gbkEncodedTxt) + result, err := enc.Transform(nil, gbkEncodedTxt, charset.OpDecode) require.NoError(t, err) require.Equal(t, txt, result) - gbkEncodedTxt2, err := enc.Encode(nil, txt) + gbkEncodedTxt2, err := enc.Transform(nil, txt, charset.OpEncode) require.NoError(t, err) require.Equal(t, gbkEncodedTxt2, gbkEncodedTxt) - result, err = enc.Decode(nil, gbkEncodedTxt2) + result, err = enc.Transform(nil, gbkEncodedTxt2, charset.OpDecode) require.NoError(t, err) require.Equal(t, txt, result) @@ -58,7 +58,7 @@ func TestEncoding(t *testing.T) { } for _, tc := range GBKCases { cmt := fmt.Sprintf("%v", tc) - result, err = enc.Decode(nil, []byte(tc.utf8Str)) + result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpDecodeReplace) if tc.isValid { require.NoError(t, err, cmt) } else { @@ -78,7 +78,7 @@ func TestEncoding(t *testing.T) { } for _, tc := range utf8Cases { cmt := fmt.Sprintf("%v", tc) - result, err = enc.Encode(nil, []byte(tc.utf8Str)) + result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpEncodeReplace) if tc.isValid { require.NoError(t, err, cmt) } else { @@ -88,111 +88,54 @@ func TestEncoding(t *testing.T) { } } -func TestStringValidatorASCII(t *testing.T) { - v := charset.StringValidatorASCII{} - testCases := []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int - }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"qwerty", charset.TruncateStrategyEmpty, "qwerty", -1}, - {"qwÊrty", charset.TruncateStrategyEmpty, "", 2}, - {"qwÊrty", charset.TruncateStrategyTrim, "qw", 2}, - {"qwÊrty", charset.TruncateStrategyReplace, "qw?rty", 2}, - {"中文", charset.TruncateStrategyEmpty, "", 0}, - {"中文?qwert", charset.TruncateStrategyTrim, "", 0}, - {"中文?qwert", charset.TruncateStrategyReplace, "???qwert", 0}, - } - for _, tc := range testCases { - msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) - } - require.Equal(t, -1, v.Validate("qwerty")) - require.Equal(t, 2, v.Validate("qwÊrty")) - require.Equal(t, 0, v.Validate("中文")) -} - -func TestStringValidatorUTF8(t *testing.T) { - // Test charset "utf8mb4". - v := charset.StringValidatorUTF8{IsUTF8MB4: true} +func TestEncodingValidate(t *testing.T) { oxfffefd := string([]byte{0xff, 0xfe, 0xfd}) testCases := []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int - }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"qwerty", charset.TruncateStrategyEmpty, "qwerty", -1}, - {"qwÊrty", charset.TruncateStrategyEmpty, "qwÊrty", -1}, - {"qwÊ合法字符串", charset.TruncateStrategyEmpty, "qwÊ合法字符串", -1}, - {"😂", charset.TruncateStrategyEmpty, "😂", -1}, - {oxfffefd, charset.TruncateStrategyEmpty, "", 0}, - {oxfffefd, charset.TruncateStrategyReplace, "???", 0}, - {"中文" + oxfffefd, charset.TruncateStrategyTrim, "中文", 6}, - {"中文" + oxfffefd, charset.TruncateStrategyReplace, "中文???", 6}, - {string(utf8.RuneError), charset.TruncateStrategyEmpty, "�", -1}, - } - for _, tc := range testCases { - msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) - } - // Test charset "utf8" with checking mb4 value. - v = charset.StringValidatorUTF8{IsUTF8MB4: false, CheckMB4ValueInUTF8: true} - testCases = []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int - }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"qwerty", charset.TruncateStrategyEmpty, "qwerty", -1}, - {"qwÊrty", charset.TruncateStrategyEmpty, "qwÊrty", -1}, - {"qwÊ合法字符串", charset.TruncateStrategyEmpty, "qwÊ合法字符串", -1}, - {"😂", charset.TruncateStrategyEmpty, "", 0}, - {"😂", charset.TruncateStrategyReplace, "?", 0}, - {"valid_str😂", charset.TruncateStrategyReplace, "valid_str?", 9}, - {oxfffefd, charset.TruncateStrategyEmpty, "", 0}, - {oxfffefd, charset.TruncateStrategyReplace, "???", 0}, - {"中文" + oxfffefd, charset.TruncateStrategyTrim, "中文", 6}, - {"中文" + oxfffefd, charset.TruncateStrategyReplace, "中文???", 6}, - {string(utf8.RuneError), charset.TruncateStrategyEmpty, "�", -1}, - } - for _, tc := range testCases { - msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) - } -} - -func TestStringValidatorGBK(t *testing.T) { - v := charset.StringValidatorOther{Charset: "gbk"} - testCases := []struct { - str string - strategy charset.TruncateStrategy - expected string - invalidPos int + chs string + str string + expected string + nSrc int + ok bool }{ - {"", charset.TruncateStrategyEmpty, "", -1}, - {"asdf", charset.TruncateStrategyEmpty, "asdf", -1}, - {"中文", charset.TruncateStrategyEmpty, "中文", -1}, - {"À", charset.TruncateStrategyEmpty, "", 0}, - {"À", charset.TruncateStrategyReplace, "?", 0}, - {"中文À中文", charset.TruncateStrategyTrim, "中文", 6}, - {"中文À中文", charset.TruncateStrategyReplace, "中文?中文", 6}, - {"asdfÀ", charset.TruncateStrategyReplace, "asdf?", 4}, + {charset.CharsetASCII, "", "", 0, true}, + {charset.CharsetASCII, "qwerty", "qwerty", 6, true}, + {charset.CharsetASCII, "qwÊrty", "qw?rty", 2, false}, + {charset.CharsetASCII, "中文", "??", 0, false}, + {charset.CharsetASCII, "中文?qwert", "???qwert", 0, false}, + {charset.CharsetUTF8MB4, "", "", 0, true}, + {charset.CharsetUTF8MB4, "qwerty", "qwerty", 6, true}, + {charset.CharsetUTF8MB4, "qwÊrty", "qwÊrty", 7, true}, + {charset.CharsetUTF8MB4, "qwÊ合法字符串", "qwÊ合法字符串", 19, true}, + {charset.CharsetUTF8MB4, "😂", "😂", 4, true}, + {charset.CharsetUTF8MB4, oxfffefd, "???", 0, false}, + {charset.CharsetUTF8MB4, "中文" + oxfffefd, "中文???", 6, false}, + {charset.CharsetUTF8MB4, string(utf8.RuneError), "�", 3, true}, + {charset.CharsetUTF8, "", "", 0, true}, + {charset.CharsetUTF8, "qwerty", "qwerty", 6, true}, + {charset.CharsetUTF8, "qwÊrty", "qwÊrty", 7, true}, + {charset.CharsetUTF8, "qwÊ合法字符串", "qwÊ合法字符串", 19, true}, + {charset.CharsetUTF8, "😂", "?", 0, false}, + {charset.CharsetUTF8, "valid_str😂", "valid_str?", 9, false}, + {charset.CharsetUTF8, oxfffefd, "???", 0, false}, + {charset.CharsetUTF8, "中文" + oxfffefd, "中文???", 6, false}, + {charset.CharsetUTF8, string(utf8.RuneError), "�", 3, true}, + {charset.CharsetGBK, "", "", 0, true}, + {charset.CharsetGBK, "asdf", "asdf", 4, true}, + {charset.CharsetGBK, "中文", "中文", 6, true}, + {charset.CharsetGBK, "À", "?", 0, false}, + {charset.CharsetGBK, "中文À中文", "中文?中文", 6, false}, + {charset.CharsetGBK, "asdfÀ", "asdf?", 4, false}, } for _, tc := range testCases { msg := fmt.Sprintf("%v", tc) - actual, invalidPos := v.Truncate(tc.str, tc.strategy) - require.Equal(t, tc.expected, actual, msg) - require.Equal(t, tc.invalidPos, invalidPos, msg) + enc := charset.FindEncoding(tc.chs) + if tc.chs == charset.CharsetUTF8 { + enc = charset.EncodingUTF8MB3StrictImpl + } + strBytes := []byte(tc.str) + ok := charset.IsValid(enc, strBytes) + require.Equal(t, tc.ok, ok, msg) + replace, _ := enc.Transform(nil, strBytes, charset.OpReplace) + require.Equal(t, tc.expected, string(replace), msg) } } diff --git a/parser/charset/encoding_utf8.go b/parser/charset/encoding_utf8.go new file mode 100644 index 0000000000000..871a5e5ec33c1 --- /dev/null +++ b/parser/charset/encoding_utf8.go @@ -0,0 +1,114 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "unicode/utf8" + + "golang.org/x/text/encoding" +) + +// EncodingUTF8Impl is the instance of encodingUTF8. +var EncodingUTF8Impl = &encodingUTF8{encodingBase{enc: encoding.Nop}} + +// EncodingUTF8MB3StrictImpl is the instance of encodingUTF8MB3Strict. +var EncodingUTF8MB3StrictImpl = &encodingUTF8MB3Strict{ + encodingUTF8{ + encodingBase{ + enc: encoding.Nop, + }, + }, +} + +func init() { + EncodingUTF8Impl.self = EncodingUTF8Impl + EncodingUTF8MB3StrictImpl.self = EncodingUTF8MB3StrictImpl +} + +// encodingUTF8 is TiDB's default encoding. +type encodingUTF8 struct { + encodingBase +} + +// Name implements Encoding interface. +func (e *encodingUTF8) Name() string { + return CharsetUTF8MB4 +} + +// Tp implements Encoding interface. +func (e *encodingUTF8) Tp() EncodingTp { + return EncodingTpUTF8 +} + +// Peek implements Encoding interface. +func (e *encodingUTF8) Peek(src []byte) []byte { + nextLen := 4 + if len(src) == 0 || src[0] < 0x80 { + nextLen = 1 + } else if src[0] < 0xe0 { + nextLen = 2 + } else if src[0] < 0xf0 { + nextLen = 3 + } + if len(src) < nextLen { + return src + } + return src[:nextLen] +} + +// Transform implements Encoding interface. +func (e *encodingUTF8) Transform(dest, src []byte, op Op) ([]byte, error) { + if IsValid(e, src) { + return src, nil + } + return e.encodingBase.Transform(dest, src, op) +} + +// Foreach implements Encoding interface. +func (e *encodingUTF8) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { + var rv rune + for i, w := 0, 0; i < len(src); i += w { + rv, w = utf8.DecodeRune(src[i:]) + meetErr := rv == utf8.RuneError && w == 1 + if !fn(src[i:i+w], src[i:i+w], !meetErr) { + return + } + } +} + +// encodingUTF8MB3Strict is the strict mode of EncodingUTF8MB3. +// MB4 characters are considered invalid. +type encodingUTF8MB3Strict struct { + encodingUTF8 +} + +// Foreach implements Encoding interface. +func (e *encodingUTF8MB3Strict) Foreach(src []byte, op Op, fn func(srcCh, dstCh []byte, ok bool) bool) { + for i, w := 0, 0; i < len(src); i += w { + var rv rune + rv, w = utf8.DecodeRune(src[i:]) + meetErr := (rv == utf8.RuneError && w == 1) || w > 3 + if !fn(src[i:i+w], src[i:i+w], !meetErr) { + return + } + } +} + +// Transform implements Encoding interface. +func (e *encodingUTF8MB3Strict) Transform(dest, src []byte, op Op) ([]byte, error) { + if IsValid(e, src) { + return src, nil + } + return e.encodingBase.Transform(dest, src, op) +} diff --git a/parser/charset/gbk.go b/parser/charset/gbk.go deleted file mode 100644 index 5686c6e1b50f0..0000000000000 --- a/parser/charset/gbk.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import "golang.org/x/text/encoding/simplifiedchinese" - -var GBKEncoding = &Encoding{ - enc: simplifiedchinese.GBK, - name: CharsetGBK, - charLength: func(bs []byte) int { - if len(bs) == 0 || bs[0] < 0x80 { - // A byte in the range 00–7F is a single byte that means the same thing as it does in ASCII. - return 1 - } - return 2 - }, - specialCase: GBKCase, -} diff --git a/parser/charset/latin.go b/parser/charset/latin.go deleted file mode 100644 index 04de80d250aef..0000000000000 --- a/parser/charset/latin.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import ( - "golang.org/x/text/encoding" - "golang.org/x/text/encoding/charmap" -) - -var ( - LatinEncoding = &Encoding{ - enc: charmap.Windows1252, - name: CharsetLatin1, - charLength: func(bytes []byte) int { - return 1 - }, - specialCase: nil, - } - - BinaryEncoding = &Encoding{ - enc: encoding.Nop, - name: CharsetBin, - charLength: func(bytes []byte) int { - return 1 - }, - specialCase: nil, - } - - ASCIIEncoding = &Encoding{ - enc: encoding.Nop, - name: CharsetASCII, - charLength: func(bytes []byte) int { - return 1 - }, - specialCase: nil, - } -) diff --git a/parser/charset/special_case_tables.go b/parser/charset/special_case_tables.go deleted file mode 100644 index 8a92ee717c566..0000000000000 --- a/parser/charset/special_case_tables.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import ( - "strings" - "unicode" -) - -func (e *Encoding) ToUpper(d string) string { - return strings.ToUpperSpecial(e.specialCase, d) -} - -func (e *Encoding) ToLower(d string) string { - return strings.ToLowerSpecial(e.specialCase, d) -} - -func LookupSpecialCase(label string) unicode.SpecialCase { - label = strings.ToLower(strings.Trim(label, "\t\n\r\f ")) - return specailCases[label].c -} - -var specailCases = map[string]struct { - c unicode.SpecialCase -}{ - "utf-8": {nil}, - "ibm866": {nil}, - "iso-8859-2": {nil}, - "iso-8859-3": {nil}, - "iso-8859-4": {nil}, - "iso-8859-5": {nil}, - "iso-8859-6": {nil}, - "iso-8859-7": {nil}, - "iso-8859-8": {nil}, - "iso-8859-8-i": {nil}, - "iso-8859-10": {nil}, - "iso-8859-13": {nil}, - "iso-8859-14": {nil}, - "iso-8859-15": {nil}, - "iso-8859-16": {nil}, - "koi8-r": {nil}, - "macintosh": {nil}, - "windows-874": {nil}, - "windows-1250": {nil}, - "windows-1251": {nil}, - "windows-1252": {nil}, - "windows-1253": {nil}, - "windows-1254": {nil}, - "windows-1255": {nil}, - "windows-1256": {nil}, - "windows-1257": {nil}, - "windows-1258": {nil}, - "x-mac-cyrillic": {nil}, - "gbk": {GBKCase}, - "gb18030": {nil}, - "hz-gb-2312": {nil}, - "big5": {nil}, - "euc-jp": {nil}, - "iso-2022-jp": {nil}, - "shift_jis": {nil}, - "euc-kr": {nil}, - "replacement": {nil}, - "utf-16be": {nil}, - "utf-16le": {nil}, - "x-user-defined": {nil}, -} - -// follow https://dev.mysql.com/worklog/task/?id=4583 for GBK -var GBKCase = unicode.SpecialCase{ - unicode.CaseRange{0x00E0, 0x00E1, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00E8, 0x00EA, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00EC, 0x00ED, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00F2, 0x00F3, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00F9, 0x00FA, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x00FC, 0x00FC, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0101, 0x0101, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0113, 0x0113, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x011B, 0x011B, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x012B, 0x012B, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0144, 0x0144, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x0148, 0x0148, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x014D, 0x014D, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x016B, 0x016B, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01CE, 0x01CE, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D0, 0x01D0, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D2, 0x01D2, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D4, 0x01D4, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D6, 0x01D6, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01D8, 0x01D8, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01DA, 0x01DA, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x01DC, 0x01DC, [unicode.MaxCase]rune{0, 0, 0}}, - unicode.CaseRange{0x216A, 0x216B, [unicode.MaxCase]rune{0, 0, 0}}, -} diff --git a/parser/charset/utf.go b/parser/charset/utf.go deleted file mode 100644 index 301aaba49d19a..0000000000000 --- a/parser/charset/utf.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package charset - -import ( - "golang.org/x/text/encoding" -) - -var UTF8Encoding = &Encoding{ - enc: encoding.Nop, - name: CharsetUTF8MB4, - charLength: func(bs []byte) int { - if len(bs) == 0 || bs[0] < 0x80 { - return 1 - } else if bs[0] < 0xe0 { - return 2 - } else if bs[0] < 0xf0 { - return 3 - } - return 4 - }, - specialCase: nil, -} diff --git a/parser/lexer.go b/parser/lexer.go index 94358fe51a962..c274a53f9f049 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -40,7 +40,7 @@ type Scanner struct { r reader buf bytes.Buffer - encoding *charset.Encoding + encoding charset.Encoding errs []error warns []error @@ -146,12 +146,18 @@ func (s *Scanner) AppendWarn(err error) { } func (s *Scanner) tryDecodeToUTF8String(sql string) string { - utf8Lit, err := s.encoding.DecodeString(sql) + if mysql.IsUTF8Charset(s.encoding.Name()) { + // Skip utf8 encoding because `ToUTF8` validates the whole SQL. + // This can cause failure when the SQL contains BLOB values. + // TODO: Convert charset on every token and use 'binary' encoding to decode token. + return sql + } + utf8Lit, err := s.encoding.Transform(nil, charset.Slice(sql), charset.OpDecodeReplace) if err != nil { s.AppendError(err) s.lastErrorAsWarn() } - return utf8Lit + return string(utf8Lit) } func (s *Scanner) getNextToken() int { diff --git a/parser/parser.go b/parser/parser.go index 5f167d88ee7e8..a4fa6c451eed0 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -14989,7 +14989,7 @@ yynewstate: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", yyS[yypt-1].ident)) return 1 } - expr := ast.NewValueExpr(yyS[yypt-0].ident, parser.charset, parser.collation) + expr := ast.NewValueExpr(yyS[yypt-0].ident, yyS[yypt-1].ident, co) tp := expr.GetType() tp.Charset = yyS[yypt-1].ident tp.Collate = co @@ -15013,7 +15013,7 @@ yynewstate: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", yyS[yypt-1].ident)) return 1 } - expr := ast.NewValueExpr(yyS[yypt-0].item, parser.charset, parser.collation) + expr := ast.NewValueExpr(yyS[yypt-0].item, yyS[yypt-1].ident, co) tp := expr.GetType() tp.Charset = yyS[yypt-1].ident tp.Collate = co @@ -15029,7 +15029,7 @@ yynewstate: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", yyS[yypt-1].ident)) return 1 } - expr := ast.NewValueExpr(yyS[yypt-0].item, parser.charset, parser.collation) + expr := ast.NewValueExpr(yyS[yypt-0].item, yyS[yypt-1].ident, co) tp := expr.GetType() tp.Charset = yyS[yypt-1].ident tp.Collate = co diff --git a/parser/parser.y b/parser/parser.y index 4c2dc969450bd..91e3b05918fc9 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -6456,7 +6456,7 @@ Literal: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", $1)) return 1 } - expr := ast.NewValueExpr($2, parser.charset, parser.collation) + expr := ast.NewValueExpr($2, $1, co) tp := expr.GetType() tp.Charset = $1 tp.Collate = co @@ -6480,7 +6480,7 @@ Literal: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", $1)) return 1 } - expr := ast.NewValueExpr($2, parser.charset, parser.collation) + expr := ast.NewValueExpr($2, $1, co) tp := expr.GetType() tp.Charset = $1 tp.Collate = co @@ -6496,7 +6496,7 @@ Literal: yylex.AppendError(ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", $1)) return 1 } - expr := ast.NewValueExpr($2, parser.charset, parser.collation) + expr := ast.NewValueExpr($2, $1, co) tp := expr.GetType() tp.Charset = $1 tp.Collate = co diff --git a/parser/yy_parser.go b/parser/yy_parser.go index df3f416fad2e7..58e18083b28cb 100644 --- a/parser/yy_parser.go +++ b/parser/yy_parser.go @@ -396,7 +396,7 @@ var ( func resetParams(p *Parser) { p.charset = mysql.DefaultCharset p.collation = mysql.DefaultCollationName - p.lexer.encoding = charset.UTF8Encoding + p.lexer.encoding = charset.EncodingUTF8Impl } // ParseParam represents the parameter of parsing. @@ -436,6 +436,6 @@ type CharsetClient string // ApplyOn implements ParseParam interface. func (c CharsetClient) ApplyOn(p *Parser) error { - p.lexer.encoding = charset.NewEncoding(string(c)) + p.lexer.encoding = charset.FindEncoding(string(c)) return nil } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 2d1d58a6984ae..0dfc4532a89bc 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -4987,7 +4987,7 @@ func (s *testIntegrationSuite) TestIssue30094(c *C) { )) tk.MustQuery(`explain format = 'brief' select * from t30094 where concat(a,'1') = _binary 0xe59388e59388e59388 collate binary and concat(a,'1') = _binary 0xe598bfe598bfe598bf collate binary;`).Check(testkit.Rows( "TableReader 8000.00 root data:Selection", - "└─Selection 8000.00 cop[tikv] eq(to_binary(concat(test.t30094.a, \"1\")), \"0xe59388e59388e59388\"), eq(to_binary(concat(test.t30094.a, \"1\")), \"0xe598bfe598bfe598bf\")", + "└─Selection 8000.00 cop[tikv] eq(concat(test.t30094.a, \"1\"), \"0xe59388e59388e59388\"), eq(concat(test.t30094.a, \"1\"), \"0xe598bfe598bfe598bf\")", " └─TableFullScan 10000.00 cop[tikv] table:t30094 keep order:false, stats:pseudo", )) } diff --git a/server/util.go b/server/util.go index d84b4cd566519..3bf27665e4d5c 100644 --- a/server/util.go +++ b/server/util.go @@ -291,14 +291,14 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resul } type inputDecoder struct { - encoding *charset.Encoding + encoding charset.Encoding buffer []byte } func newInputDecoder(chs string) *inputDecoder { return &inputDecoder{ - encoding: charset.NewEncoding(chs), + encoding: charset.FindEncoding(chs), buffer: nil, } } @@ -309,7 +309,7 @@ func (i *inputDecoder) clean() { } func (i *inputDecoder) decodeInput(src []byte) []byte { - result, err := i.encoding.Decode(i.buffer, src) + result, err := i.encoding.Transform(i.buffer, src, charset.OpDecode) if err != nil { return src } @@ -320,22 +320,23 @@ type resultEncoder struct { // chsName and encoding are unchanged after the initialization from // session variable @@character_set_results. chsName string - encoding *charset.Encoding + encoding charset.Encoding // dataEncoding can be updated to match the column data charset. - dataEncoding *charset.Encoding + dataEncoding charset.Encoding buffer []byte - isBinary bool - isNull bool + isBinary bool + isNull bool + dataIsBinary bool } // newResultEncoder creates a new resultEncoder. func newResultEncoder(chs string) *resultEncoder { return &resultEncoder{ chsName: chs, - encoding: charset.NewEncoding(chs), + encoding: charset.FindEncoding(chs), buffer: nil, isBinary: chs == charset.CharsetBinary, isNull: len(chs) == 0, @@ -352,7 +353,8 @@ func (d *resultEncoder) updateDataEncoding(chsID uint16) { if err != nil { logutil.BgLogger().Warn("unknown charset ID", zap.Error(err)) } - d.dataEncoding = charset.NewEncoding(chs) + d.dataEncoding = charset.FindEncoding(chs) + d.dataIsBinary = chsID == mysql.BinaryDefaultCollationID } func (d *resultEncoder) columnTypeInfoCharsetID(info *ColumnInfo) uint16 { @@ -367,24 +369,30 @@ func (d *resultEncoder) columnTypeInfoCharsetID(info *ColumnInfo) uint16 { return uint16(mysql.CharsetNameToID(d.chsName)) } +// encodeMeta encodes bytes for meta info like column names. +// Note that the result should be consumed immediately. func (d *resultEncoder) encodeMeta(src []byte) []byte { return d.encodeWith(src, d.encoding) } +// encodeData encodes bytes for row data. +// Note that the result should be consumed immediately. func (d *resultEncoder) encodeData(src []byte) []byte { - if d.isNull || d.isBinary { + if d.isNull || d.isBinary || d.dataIsBinary { // Use the column charset to encode. return d.encodeWith(src, d.dataEncoding) } return d.encodeWith(src, d.encoding) } -func (d *resultEncoder) encodeWith(src []byte, enc *charset.Encoding) []byte { - result, err := enc.Encode(d.buffer, src) +func (d *resultEncoder) encodeWith(src []byte, enc charset.Encoding) []byte { + var err error + d.buffer, err = enc.Transform(d.buffer, src, charset.OpEncode) if err != nil { logutil.BgLogger().Debug("encode error", zap.Error(err)) } - return result + // The buffer will be reused. + return d.buffer } func dumpTextRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resultEncoder) ([]byte, error) { diff --git a/session/session.go b/session/session.go index c57628d090f44..a7976a5001bee 100644 --- a/session/session.go +++ b/session/session.go @@ -124,9 +124,9 @@ type Session interface { Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. // ExecuteStmt executes a parsed statement. ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) - // Parse is deprecated, use ParseWithParams() instead. + // Parse is deprecated, use ParseWithParams() or ParseWithParamsInternal() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) - // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. + // ExecuteInternal is a helper around ParseWithParamsInternal() and ExecuteStmt(). It is not allowed to execute multiple statements. ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error) String() string // String is used to debug. CommitTxn(context.Context) error @@ -1153,7 +1153,7 @@ func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet, allo // getTableValue executes restricted sql and the result is one column. // It returns a string value. func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { - stmt, err := s.ParseWithParams(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + stmt, err := s.ParseWithParamsInternal(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) if err != nil { return "", err } @@ -1175,7 +1175,7 @@ func (s *session) getTableValue(ctx context.Context, tblName string, varName str // replaceGlobalVariablesTableValue executes restricted sql updates the variable value // It will then notify the etcd channel that the value has changed. func (s *session) replaceGlobalVariablesTableValue(ctx context.Context, varName, val string) error { - stmt, err := s.ParseWithParams(ctx, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) + stmt, err := s.ParseWithParamsInternal(ctx, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) if err != nil { return err } @@ -1250,7 +1250,7 @@ func (s *session) SetGlobalSysVarOnly(name, value string) (err error) { // SetTiDBTableValue implements GlobalVarAccessor.SetTiDBTableValue interface. func (s *session) SetTiDBTableValue(name, value, comment string) error { - stmt, err := s.ParseWithParams(context.TODO(), `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) + stmt, err := s.ParseWithParamsInternal(context.TODO(), `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) if err != nil { return err } @@ -1520,6 +1520,16 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter return stmts[0], nil } +// ParseWithParamsInternal is same as ParseWithParams except set `s.sessionVars.InRestrictedSQL = true` +func (s *session) ParseWithParamsInternal(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + }() + return s.ParseWithParams(ctx, sql, args...) +} + // ExecRestrictedStmt implements RestrictedSQLExecutor interface. func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( []chunk.Row, []*ast.ResultField, error) { diff --git a/session/session_test.go b/session/session_test.go index 7b5febe0d18e0..4602758eeaa28 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4377,22 +4377,17 @@ func (s *testSessionSerialSuite) TestProcessInfoIssue22068(c *C) { wg.Wait() } -func (s *testSessionSerialSuite) TestParseWithParams(c *C) { +func (s *testSessionSerialSuite) TestParseWithParamsInternal(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) se := tk.Se exec := se.(sqlexec.RestrictedSQLExecutor) // test compatibility with ExcuteInternal - origin := se.GetSessionVars().InRestrictedSQL - se.GetSessionVars().InRestrictedSQL = true - defer func() { - se.GetSessionVars().InRestrictedSQL = origin - }() - _, err := exec.ParseWithParams(context.TODO(), "SELECT 4") + _, err := exec.ParseWithParamsInternal(context.TODO(), "SELECT 4") c.Assert(err, IsNil) // test charset attack - stmt, err := exec.ParseWithParams(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) var sb strings.Builder @@ -4402,15 +4397,15 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") // test invalid sql - _, err = exec.ParseWithParams(context.TODO(), "SELECT") + _, err = exec.ParseWithParamsInternal(context.TODO(), "SELECT") c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") // test invalid arguments to escape - _, err = exec.ParseWithParams(context.TODO(), "SELECT %?, %?", 3) + _, err = exec.ParseWithParamsInternal(context.TODO(), "SELECT %?, %?", 3) c.Assert(err, ErrorMatches, "missing arguments.*") // test noescape - stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") + stmt, err = exec.ParseWithParamsInternal(context.TODO(), "SELECT 3") c.Assert(err, IsNil) sb.Reset() diff --git a/session/tidb_test.go b/session/tidb_test.go index 661a8f19d5f47..759eaa02702f4 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -37,7 +37,7 @@ func TestSysSessionPoolGoroutineLeak(t *testing.T) { count := 200 stmts := make([]ast.StmtNode, count) for i := 0; i < count; i++ { - stmt, err := se.ParseWithParams(context.Background(), "select * from mysql.user limit 1") + stmt, err := se.ParseWithParamsInternal(context.Background(), "select * from mysql.user limit 1") require.NoError(t, err) stmts[i] = stmt } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 0889d00e431e5..50876ccdf8661 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -128,7 +128,7 @@ func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context. func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, sql, params...) + stmt, err := exec.ParseWithParamsInternal(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) } @@ -138,7 +138,7 @@ func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...in func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, sql, params...) + stmt, err := exec.ParseWithParamsInternal(ctx, sql, params...) // TODO: An ugly way to set @@tidb_partition_prune_mode. Need to be improved. if _, ok := stmt.(*ast.AnalyzeTableStmt); ok { pruneMode := h.CurrentPruneMode() @@ -155,7 +155,7 @@ func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int func (h *Handle) execRestrictedSQLWithSnapshot(ctx context.Context, sql string, snapshot uint64, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, sql, params...) + stmt, err := exec.ParseWithParamsInternal(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) } @@ -1385,7 +1385,7 @@ type statsReader struct { func (sr *statsReader) read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { ctx := context.TODO() - stmt, err := sr.ctx.ParseWithParams(ctx, sql, args...) + stmt, err := sr.ctx.ParseWithParamsInternal(ctx, sql, args...) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/table/column.go b/table/column.go index 0225a88556c0b..d7e9a9ec5dadb 100644 --- a/table/column.go +++ b/table/column.go @@ -170,7 +170,7 @@ func truncateTrailingSpaces(v *types.Datum) { v.SetString(str, v.Collation()) } -func handleWrongCharsetValue(ctx sessionctx.Context, col *model.ColumnInfo, str string, i int) error { +func handleWrongCharsetValue(ctx sessionctx.Context, col *model.ColumnInfo, str []byte, i int) error { sc := ctx.GetSessionVars().StmtCtx var strval strings.Builder for j := 0; j < 6; j++ { @@ -328,46 +328,57 @@ func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, r truncateTrailingSpaces(&casted) } - if v := makeStringValidator(ctx, col); v != nil { - str := casted.GetString() - strategy := charset.TruncateStrategyReplace - if val.Collation() == charset.CollationBin { - strategy = charset.TruncateStrategyTrim - } - if newStr, invalidPos := v.Truncate(str, strategy); invalidPos >= 0 { - casted = types.NewStringDatum(newStr) - err = handleWrongCharsetValue(ctx, col, str, invalidPos) - } - } + err = validateStringDatum(ctx, &val, &casted, col) if forceIgnoreTruncate { err = nil } return casted, err } -func makeStringValidator(ctx sessionctx.Context, col *model.ColumnInfo) charset.StringValidator { - switch col.Charset { - case charset.CharsetASCII: - if ctx.GetSessionVars().SkipASCIICheck { - return nil - } - return charset.StringValidatorASCII{} - case charset.CharsetUTF8: - if ctx.GetSessionVars().SkipUTF8Check { - return nil - } - needCheckMB4 := config.GetGlobalConfig().CheckMb4ValueInUTF8 - return charset.StringValidatorUTF8{IsUTF8MB4: false, CheckMB4ValueInUTF8: needCheckMB4} - case charset.CharsetUTF8MB4: - if ctx.GetSessionVars().SkipUTF8Check { - return nil +func validateStringDatum(ctx sessionctx.Context, origin, casted *types.Datum, col *model.ColumnInfo) error { + // Only strings are need to validate. + if !types.IsString(col.Tp) { + return nil + } + fromBinary := origin.Kind() == types.KindBinaryLiteral || + (origin.Kind() == types.KindString && origin.Collation() == charset.CollationBin) + toBinary := types.IsTypeBlob(col.Tp) || col.Charset == charset.CharsetBin + if fromBinary && toBinary { + return nil + } + enc := charset.FindEncoding(col.Charset) + // Skip utf8 check if possible. + if enc.Tp() == charset.EncodingTpUTF8 && ctx.GetSessionVars().SkipUTF8Check { + return nil + } + // Skip ascii check if possible. + if enc.Tp() == charset.EncodingTpASCII && ctx.GetSessionVars().SkipASCIICheck { + return nil + } + if col.Charset == charset.CharsetUTF8 && config.GetGlobalConfig().CheckMb4ValueInUTF8 { + // Use a strict mode implementation. 4 bytes characters are invalid. + enc = charset.EncodingUTF8MB3StrictImpl + } + if fromBinary { + src := casted.GetBytes() + encBytes, err := enc.Transform(nil, src, charset.OpDecode) + if err != nil { + casted.SetBytesAsString(encBytes, charset.CollationUTF8MB4, 0) + nSrc := charset.CountValidBytesDecode(enc, src) + return handleWrongCharsetValue(ctx, col, src, nSrc) } - return charset.StringValidatorUTF8{IsUTF8MB4: true} - case charset.CharsetLatin1, charset.CharsetBinary: + casted.SetBytesAsString(encBytes, charset.CollationUTF8MB4, 0) return nil - default: - return charset.StringValidatorOther{Charset: col.Charset} } + // Check if the string is valid in the given column charset. + str := casted.GetBytes() + if !charset.IsValid(enc, str) { + replace, _ := enc.Transform(nil, str, charset.OpReplace) + casted.SetBytesAsString(replace, charset.CollationUTF8MB4, 0) + nSrc := charset.CountValidBytes(enc, str) + return handleWrongCharsetValue(ctx, col, str, nSrc) + } + return nil } // ColDesc describes column information like MySQL desc and show columns do. diff --git a/telemetry/data_cluster_hardware.go b/telemetry/data_cluster_hardware.go index 611ae2e005384..3c2dca4928d78 100644 --- a/telemetry/data_cluster_hardware.go +++ b/telemetry/data_cluster_hardware.go @@ -69,7 +69,7 @@ func normalizeFieldName(name string) string { func getClusterHardware(ctx sessionctx.Context) ([]*clusterHardwareItem, error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_info.go b/telemetry/data_cluster_info.go index a1569c3e67634..d1d645a3803d1 100644 --- a/telemetry/data_cluster_info.go +++ b/telemetry/data_cluster_info.go @@ -37,7 +37,7 @@ type clusterInfoItem struct { func getClusterInfo(ctx sessionctx.Context) ([]*clusterInfoItem, error) { // Explicitly list all field names instead of using `*` to avoid potential leaking sensitive info when adding new fields in future. exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_feature_usage.go b/telemetry/data_feature_usage.go index cf32510097423..6255eeb2ec9df 100644 --- a/telemetry/data_feature_usage.go +++ b/telemetry/data_feature_usage.go @@ -77,7 +77,7 @@ func getClusterIndexUsageInfo(ctx sessionctx.Context) (cu *ClusterIndexUsage, er exec := ctx.(sqlexec.RestrictedSQLExecutor) // query INFORMATION_SCHEMA.tables to get the latest table information about ClusterIndex - stmt, err := exec.ParseWithParams(context.TODO(), ` + stmt, err := exec.ParseWithParamsInternal(context.TODO(), ` SELECT left(sha2(TABLE_NAME, 256), 6) table_name_hash, TIDB_PK_TYPE, TABLE_SCHEMA, TABLE_NAME FROM information_schema.tables WHERE table_schema not in ('INFORMATION_SCHEMA', 'METRICS_SCHEMA', 'PERFORMANCE_SCHEMA', 'mysql') diff --git a/util/admin/admin.go b/util/admin/admin.go index 3f68393f52833..9b6bb8c5168ce 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -328,7 +328,7 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices }() // Add `` for some names like `table name`. exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) + stmt, err := exec.ParseWithParamsInternal(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) if err != nil { return 0, 0, errors.Trace(err) } @@ -350,7 +350,7 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices return 0, 0, errors.Trace(err) } for i, idx := range indices { - stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) + stmt, err := exec.ParseWithParamsInternal(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) if err != nil { return 0, i, errors.Trace(err) } diff --git a/util/gcutil/gcutil.go b/util/gcutil/gcutil.go index 9216d37b1ad9a..c11a6ca66d996 100644 --- a/util/gcutil/gcutil.go +++ b/util/gcutil/gcutil.go @@ -72,7 +72,7 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error { // GetGCSafePoint loads GC safe point time from mysql.tidb. func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") + stmt, err := exec.ParseWithParamsInternal(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") if err != nil { return 0, errors.Trace(err) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 4be2cae5ce12a..165eae7e8c452 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -46,6 +46,8 @@ type RestrictedSQLExecutor interface { // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. // This function only saves you from processing potentially unsafe parameters. ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) + // ParseWithParamsInternal is same as ParseWithParams except set `s.sessionVars.InRestrictedSQL = true` + ParseWithParamsInternal(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) // ExecRestrictedStmt run sql statement in ctx with some restriction. ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) }