Skip to content

Commit

Permalink
executor: optimize load data assignment expressions (#46082) (#46110)
Browse files Browse the repository at this point in the history
close #46081
  • Loading branch information
ti-chi-bot authored Aug 15, 2023
1 parent 6e593be commit ab110b5
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 3 deletions.
5 changes: 5 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,11 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor {
b.err = err
return nil
}
err = loadDataInfo.initColAssignExprs()
if err != nil {
b.err = err
return nil
}
loadDataExec := &LoadDataExec{
baseExecutor: newBaseExecutor(b.ctx, nil, v.ID()),
IsLocal: v.IsLocal,
Expand Down
32 changes: 29 additions & 3 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -116,7 +117,11 @@ type LoadDataInfo struct {
rows [][]types.Datum
Drained bool

ColumnAssignments []*ast.Assignment
ColumnAssignments []*ast.Assignment
ColumnAssignmentExprs []expression.Expression
// The session context generates warnings when an AST node is rewritten into an expression.
// We should generate such warnings for each row encoded.
exprWarnings []stmtctx.SQLWarn
ColumnsAndUserVars []*ast.ColumnNameOrUserVar
FieldMappings []*FieldMapping

Expand Down Expand Up @@ -211,6 +216,23 @@ func (e *LoadDataInfo) initLoadColumns(columnNames []string) error {
return nil
}

// initColAssignExprs builds the expressions for the column assignments using the
// session context: every entry of e.ColumnAssignments has its AST expression
// rewritten once, and the result is appended, in order, to e.ColumnAssignmentExprs.
//
// RewriteAstExpr writes the AST node in place (due to xxNode.Accept), but it
// does not change the node's content.
//
// It returns the first rewrite error as-is; anything already appended for
// earlier assignments is left in place.
func (e *LoadDataInfo) initColAssignExprs() error {
	for idx := range e.ColumnAssignments {
		rewritten, err := expression.RewriteAstExpr(e.Ctx, e.ColumnAssignments[idx].Expr, nil, nil)
		if err != nil {
			return err
		}

		// The warnings of a column assignment expression are static, yet they
		// should be generated for each row processed. Move whatever the
		// statement context currently holds into e.exprWarnings and clear it.
		sessVars := e.Ctx.GetSessionVars()
		pending := sessVars.StmtCtx.GetWarnings()
		e.exprWarnings = append(e.exprWarnings, pending...)
		sessVars.StmtCtx.SetWarnings(nil)

		e.ColumnAssignmentExprs = append(e.ColumnAssignmentExprs, rewritten)
	}
	return nil
}

// initFieldMappings makes a field mapping slice to implicitly map each input field to a table column or a user-defined variable.
// The slice's order is the same as the order of the input fields.
// Returns a slice of same-ordered column names without user-defined variable names.
Expand Down Expand Up @@ -664,15 +686,19 @@ func (e *LoadDataInfo) colsToRow(ctx context.Context, cols []field) []types.Datu

row = append(row, types.NewDatum(string(cols[i].str)))
}
for i := 0; i < len(e.ColumnAssignments); i++ {

for i := 0; i < len(e.ColumnAssignmentExprs); i++ {
// eval expression of `SET` clause
d, err := expression.EvalAstExpr(e.Ctx, e.ColumnAssignments[i].Expr)
d, err := e.ColumnAssignmentExprs[i].Eval(chunk.Row{})
if err != nil {
e.handleWarning(err)
return nil
}
row = append(row, d)
}
if len(e.exprWarnings) > 0 {
e.Ctx.GetSessionVars().StmtCtx.AppendWarnings(e.exprWarnings)
}

// a new row buffer will be allocated in getRow
newRow, err := e.getRow(ctx, row)
Expand Down
67 changes: 67 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,73 @@ func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) {
require.NoError(t, rows.Close())
dbt.MustExec("drop table if exists pn")
})

err = fp.Close()
require.NoError(t, err)
err = os.Remove(path)
require.NoError(t, err)

fp, err = os.Create(path)
require.NoError(t, err)
require.NotNil(t, fp)

_, err = fp.WriteString(
`1,2` + "\n" +
`1,2,,4` + "\n" +
`1,2,3` + "\n" +
`,,,` + "\n" +
`,,3` + "\n" +
`1,,,4` + "\n")
require.NoError(t, err)

nullInt32 := func(val int32, valid bool) sql.NullInt32 {
return sql.NullInt32{Int32: val, Valid: valid}
}
expects := []struct {
col1 sql.NullInt32
col2 sql.NullInt32
col3 sql.NullInt32
col4 sql.NullInt32
}{
{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(0, false)},
{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(4, true)},
{nullInt32(1, true), nullInt32(2, true), nullInt32(3, true), nullInt32(0, false)},
{nullInt32(0, true), nullInt32(0, false), nullInt32(0, false), nullInt32(0, false)},
{nullInt32(0, true), nullInt32(0, false), nullInt32(3, true), nullInt32(0, false)},
{nullInt32(1, true), nullInt32(0, false), nullInt32(0, false), nullInt32(4, true)},
}

cli.runTestsOnNewDB(t, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Params["sql_mode"] = "''"
}, "LoadData", func(dbt *testkit.DBTestKit) {
dbt.MustExec("drop table if exists pn")
dbt.MustExec("create table pn (c1 int, c2 int, c3 int, c4 int)")
dbt.MustExec("set @@tidb_dml_batch_size = 1")
_, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @val2, @val3, @val4)
SET c2 = NULLIF(@val2, ''), c3 = NULLIF(@val3, ''), c4 = NULLIF(@val4, '')`, path))
require.NoError(t, err1)
var (
a sql.NullInt32
b sql.NullInt32
c sql.NullInt32
d sql.NullInt32
)
rows := dbt.MustQuery("select * from pn")
for _, expect := range expects {
require.Truef(t, rows.Next(), "unexpected data")
err = rows.Scan(&a, &b, &c, &d)
require.NoError(t, err)
require.Equal(t, expect.col1, a)
require.Equal(t, expect.col2, b)
require.Equal(t, expect.col3, c)
require.Equal(t, expect.col4, d)
}

require.Falsef(t, rows.Next(), "unexpected data")
require.NoError(t, rows.Close())
dbt.MustExec("drop table if exists pn")
})
}

func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) {
Expand Down

0 comments on commit ab110b5

Please sign in to comment.