diff --git a/executor/distsql.go b/executor/distsql.go index ed5aae638c46f..3edb1dd709168 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -1128,6 +1128,13 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta chk := newFirstChunk(tableReader) tblInfo := w.idxLookup.table.Meta() vals := make([]types.Datum, 0, len(w.idxTblCols)) + + // Prepare collator for compare. + collators := make([]collate.Collator, 0, len(w.idxColTps)) + for _, tp := range w.idxColTps { + collators = append(collators, collate.GetCollator(tp.Collate)) + } + for { err := Next(ctx, tableReader, chk) if err != nil { @@ -1169,7 +1176,7 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta tp := &col.FieldType idxVal := idxRow.GetDatum(i, tp) tablecodec.TruncateIndexValue(&idxVal, w.idxLookup.index.Columns[i], col.ColumnInfo) - cmpRes, err := idxVal.CompareDatum(sctx, &vals[i]) + cmpRes, err := idxVal.Compare(sctx, &vals[i], collators[i]) if err != nil { return ErrDataInConsistentMisMatchIndex.GenWithStackByArgs(col.Name, handle, idxRow.GetDatum(i, tp), vals[i], err) diff --git a/types/datum.go b/types/datum.go index 05db79a64e18e..f4c46d5329d81 100644 --- a/types/datum.go +++ b/types/datum.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/parser/types" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" ) @@ -548,7 +549,63 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) { } } +// Compare compares datum to another datum. +// Notes: don't rely on datum.collation to get the collator, it's tend to buggy. +// TODO: use this function to replace CompareDatum. After we remove all of usage of CompareDatum, we can rename this function back to CompareDatum. +func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) { + if d.k == KindMysqlJSON && ad.k != KindMysqlJSON { + cmp, err := ad.Compare(sc, d, comparer) + return cmp * -1, errors.Trace(err) + } + switch ad.k { + case KindNull: + if d.k == KindNull { + return 0, nil + } + return 1, nil + case KindMinNotNull: + if d.k == KindNull { + return -1, nil + } else if d.k == KindMinNotNull { + return 0, nil + } + return 1, nil + case KindMaxValue: + if d.k == KindMaxValue { + return 0, nil + } + return -1, nil + case KindInt64: + return d.compareInt64(sc, ad.GetInt64()) + case KindUint64: + return d.compareUint64(sc, ad.GetUint64()) + case KindFloat32, KindFloat64: + return d.compareFloat64(sc, ad.GetFloat64()) + case KindString: + return d.compareStringNew(sc, ad.GetString(), comparer) + case KindBytes: + return comparer.Compare(d.GetString(), ad.GetString()), nil + case KindMysqlDecimal: + return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal()) + case KindMysqlDuration: + return d.compareMysqlDuration(sc, ad.GetMysqlDuration()) + case KindMysqlEnum: + return d.compareMysqlEnumNew(sc, ad.GetMysqlEnum(), comparer) + case KindBinaryLiteral, KindMysqlBit: + return d.compareBinaryLiteralNew(sc, ad.GetBinaryLiteral4Cmp(), comparer) + case KindMysqlSet: + return d.compareMysqlSetNew(sc, ad.GetMysqlSet(), comparer) + case KindMysqlJSON: + return d.compareMysqlJSON(sc, ad.GetMysqlJSON()) + case KindMysqlTime: + return d.compareMysqlTime(sc, ad.GetMysqlTime()) + default: + return 0, nil + } +} + // CompareDatum compares datum to another datum. +// Deprecated: will be replaced with Compare. // TODO: return error properly. func (d *Datum) CompareDatum(sc *stmtctx.StatementContext, ad *Datum) (int, error) { if d.k == KindMysqlJSON && ad.k != KindMysqlJSON { @@ -673,6 +730,39 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er } } +func (d *Datum) compareStringNew(sc *stmtctx.StatementContext, s string, comparer collate.Collator) (int, error) { + switch d.k { + case KindNull, KindMinNotNull: + return -1, nil + case KindMaxValue: + return 1, nil + case KindString, KindBytes: + return comparer.Compare(d.GetString(), s), nil + case KindMysqlDecimal: + dec := new(MyDecimal) + err := sc.HandleTruncate(dec.FromString(hack.Slice(s))) + return d.GetMysqlDecimal().Compare(dec), errors.Trace(err) + case KindMysqlTime: + dt, err := ParseDatetime(sc, s) + return d.GetMysqlTime().Compare(dt), errors.Trace(err) + case KindMysqlDuration: + dur, err := ParseDuration(sc, s, MaxFsp) + return d.GetMysqlDuration().Compare(dur), errors.Trace(err) + case KindMysqlSet: + return comparer.Compare(d.GetMysqlSet().String(), s), nil + case KindMysqlEnum: + return comparer.Compare(d.GetMysqlEnum().String(), s), nil + case KindBinaryLiteral, KindMysqlBit: + return comparer.Compare(d.GetBinaryLiteral4Cmp().String(), s), nil + default: + fVal, err := StrToFloat(sc, s, false) + if err != nil { + return 0, errors.Trace(err) + } + return d.compareFloat64(sc, fVal) + } +} + func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, retCollation string) (int, error) { switch d.k { case KindNull, KindMinNotNull: @@ -748,6 +838,52 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration) } } +func (d *Datum) compareMysqlEnumNew(sc *stmtctx.StatementContext, enum Enum, comparer collate.Collator) (int, error) { + switch d.k { + case KindNull, KindMinNotNull: + return -1, nil + case KindMaxValue: + return 1, nil + case KindString, KindBytes, KindMysqlEnum, KindMysqlSet: + return comparer.Compare(d.GetString(), enum.String()), nil + default: + return d.compareFloat64(sc, enum.ToNumber()) + } +} + +func (d *Datum) compareBinaryLiteralNew(sc *stmtctx.StatementContext, b BinaryLiteral, comparer collate.Collator) (int, error) { + switch d.k { + case KindNull, KindMinNotNull: + return -1, nil + case KindMaxValue: + return 1, nil + case KindString, KindBytes: + fallthrough // in this case, d is converted to Binary and then compared with b + case KindBinaryLiteral, KindMysqlBit: + return comparer.Compare(d.GetBinaryLiteral4Cmp().ToString(), b.ToString()), nil + default: + val, err := b.ToInt(sc) + if err != nil { + return 0, errors.Trace(err) + } + result, err := d.compareFloat64(sc, float64(val)) + return result, errors.Trace(err) + } +} + +func (d *Datum) compareMysqlSetNew(sc *stmtctx.StatementContext, set Set, comparer collate.Collator) (int, error) { + switch d.k { + case KindNull, KindMinNotNull: + return -1, nil + case KindMaxValue: + return 1, nil + case KindString, KindBytes, KindMysqlEnum, KindMysqlSet: + return comparer.Compare(d.GetString(), set.String()), nil + default: + return d.compareFloat64(sc, set.ToNumber()) + } +} + func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum) (int, error) { switch d.k { case KindNull, KindMinNotNull: