diff --git a/executor/simple.go b/executor/simple.go index 077b5b4f4a489..0ef744f642aca 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -230,18 +230,46 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } continue } + + // begin a transaction to delete a user. + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "begin"); err != nil { + return errors.Trace(err) + } sql := fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = "%s" and User = "%s";`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) - if err != nil { + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + failedUsers = append(failedUsers, user.String()) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + return errors.Trace(err) + } + continue + } + + // delete privileges from mysql.db + sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = "%s" and User = "%s";`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + failedUsers = append(failedUsers, user.String()) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + return errors.Trace(err) + } + continue + } + + // delete privileges from mysql.tables_priv + sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = "%s" and User = "%s";`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + failedUsers = append(failedUsers, user.String()) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + return errors.Trace(err) + } + continue + } + + //TODO: need delete columns_priv once we implement columns_priv functionality. + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil { failedUsers = append(failedUsers, user.String()) } } if len(failedUsers) > 0 { - // Commit the transaction even if we returns error - err := e.ctx.Txn().Commit(sessionctx.SetCommitCtx(context.Background(), e.ctx)) - if err != nil { - return errors.Trace(err) - } errMsg := "Operation DROP USER failed for " + strings.Join(failedUsers, ",") return terror.ClassExecutor.New(CodeCannotUser, errMsg) } diff --git a/executor/simple_test.go b/executor/simple_test.go index eabeea8af3de6..06ed2e99dda77 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -177,6 +177,25 @@ func (s *testSuite) TestUser(c *C) { tk.MustExec(createUserSQL) dropUserSQL = `DROP USER 'test1'@'localhost';` tk.MustExec(dropUserSQL) + tk.MustQuery("select * from mysql.db").Check(testkit.Rows( + "localhost test testDB Y Y Y Y Y Y Y N Y Y N N N N N N Y N N", + "localhost test testDB1 Y Y Y Y Y Y Y N Y Y N N N N N N Y N N] [% dddb_% dduser Y Y Y Y Y Y Y N Y Y N N N N N N Y N N", + "% test test Y N N N N N N N N N N N N N N N N N N", + "localhost test testDBRevoke N N N N N N N N N N N N N N N N N N N", + )) + + // Test drop user meet error + _, err = tk.Exec(dropUserSQL) + c.Assert(terror.ErrorEqual(err, terror.ClassExecutor.New(executor.CodeCannotUser, "")), IsTrue) + + createUserSQL = `CREATE USER 'test1'@'localhost'` + tk.MustExec(createUserSQL) + createUserSQL = `CREATE USER 'test2'@'localhost'` + tk.MustExec(createUserSQL) + + dropUserSQL = `DROP USER 'test1'@'localhost', 'test2'@'localhost', 'test3'@'localhost';` + _, err = tk.Exec(dropUserSQL) + c.Assert(terror.ErrorEqual(err, terror.ClassExecutor.New(executor.CodeCannotUser, "")), IsTrue) } func (s *testSuite) TestSetPwd(c *C) {