diff --git a/executor/simple.go b/executor/simple.go index b23c0b8d6edcf..616187ac50558 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -56,6 +56,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro return nil } switch x := e.Statement.(type) { + case *ast.GrantRoleStmt: + err = e.executeGrantRole(x) case *ast.UseStmt: err = e.executeUse(x) case *ast.FlushStmt: @@ -277,6 +279,52 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { return nil } +func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { + failedUsers := make([]string, 0, len(s.Users)) + for _, role := range s.Roles { + exists, err := userExists(e.ctx, role.Username, role.Hostname) + if err != nil { + return err + } + if !exists { + return ErrCannotUser.GenWithStackByArgs("GRANT ROLE", role.String()) + } + } + for _, user := range s.Users { + exists, err := userExists(e.ctx, user.Username, user.Hostname) + if err != nil { + return err + } + if !exists { + return ErrCannotUser.GenWithStackByArgs("GRANT ROLE", user.String()) + } + } + + // begin a transaction to insert role graph edges. + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "begin"); err != nil { + return err + } + + for _, user := range s.Users { + for _, role := range s.Roles { + sql := fmt.Sprintf(`INSERT IGNORE INTO %s.%s (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ('%s','%s','%s','%s')`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + failedUsers = append(failedUsers, user.String()) + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + return err + } + return ErrCannotUser.GenWithStackByArgs("GRANT ROLE", user.String()) + } + } + } + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil { + return err + } + err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context)) + return err +} + func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { failedUsers := make([]string, 0, len(s.UserList)) for _, user := range s.UserList { diff --git a/executor/simple_test.go b/executor/simple_test.go index 94305e7762fc1..b20a15dbef8d8 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -93,6 +93,7 @@ func (s *testSuite3) TestRole(c *C) { result := tk.MustQuery(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`) result.Check(nil) + // Test for DROP ROLE. createRoleSQL := `CREATE ROLE 'test'@'localhost';` tk.MustExec(createRoleSQL) // Make sure user test in mysql.User. @@ -119,6 +120,27 @@ func (s *testSuite3) TestRole(c *C) { result.Check(nil) result = tk.MustQuery(`SELECT * FROM mysql.default_roles WHERE DEFAULT_ROLE_USER="test" and DEFAULT_ROLE_HOST="localhost"`) result.Check(nil) + + // Test for GRANT ROLE + createRoleSQL = `CREATE ROLE 'r_1'@'localhost', 'r_2'@'localhost', 'r_3'@'localhost';` + tk.MustExec(createRoleSQL) + grantRoleSQL := `GRANT 'r_1'@'localhost' TO 'r_2'@'localhost';` + tk.MustExec(grantRoleSQL) + result = tk.MustQuery(`SELECT TO_USER FROM mysql.role_edges WHERE FROM_USER="r_1" and FROM_HOST="localhost"`) + result.Check(testkit.Rows("r_2")) + + grantRoleSQL = `GRANT 'r_1'@'localhost' TO 'r_3'@'localhost', 'r_4'@'localhost';` + _, err = tk.Exec(grantRoleSQL) + c.Check(err, NotNil) + result = tk.MustQuery(`SELECT FROM_USER FROM mysql.role_edges WHERE TO_USER="r_3" and TO_HOST="localhost"`) + result.Check(nil) + + dropRoleSQL := `DROP ROLE IF EXISTS 'r_1'@'localhost' ;` + tk.MustExec(dropRoleSQL) + dropRoleSQL = `DROP ROLE IF EXISTS 'r_2'@'localhost' ;` + tk.MustExec(dropRoleSQL) + dropRoleSQL = `DROP ROLE IF EXISTS 'r_3'@'localhost' ;` + tk.MustExec(dropRoleSQL) } func (s *testSuite3) TestUser(c *C) { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 011d402fffef7..6b33cbed6a8a5 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -232,8 +232,9 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) { case *ast.AnalyzeTableStmt: return b.buildAnalyze(x) case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, - *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, - *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.SetRoleStmt: + *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, + *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, + *ast.GrantRoleStmt, *ast.SetRoleStmt: return b.buildSimple(node.(ast.StmtNode)) case ast.DDLNode: return b.buildDDL(x) @@ -1056,6 +1057,9 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) case *ast.GrantStmt: b.visitInfo = collectVisitInfoFromGrantStmt(b.ctx, b.visitInfo, raw) + case *ast.GrantRoleStmt: + err := ErrSpecificAccessDenied.GenWithStackByArgs("GRANT ROLE") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.GrantPriv, "", "", "", err) case *ast.RevokeStmt: b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) case *ast.KillStmt: