diff --git a/server/conn.go b/server/conn.go index 821beec213f29..c98d488fb90f0 100644 --- a/server/conn.go +++ b/server/conn.go @@ -57,6 +57,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" @@ -391,16 +392,17 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { if err != nil { return errors.Trace(err) } - if !cc.server.skipAuth() { - // Do Auth. + host := variable.DefHostname + if !cc.server.isUnixSocket() { addr := cc.bufReadConn.RemoteAddr().String() - host, _, err1 := net.SplitHostPort(addr) - if err1 != nil { + // Do Auth. + host, _, err = net.SplitHostPort(addr) + if err != nil { return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, addr, "YES")) } - if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) { - return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES")) - } + } + if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) { + return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES")) } if cc.dbname != "" { err = cc.useDB(context.Background(), cc.dbname) diff --git a/server/server.go b/server/server.go index 8854f9cb12050..49455fe1d4e5c 100644 --- a/server/server.go +++ b/server/server.go @@ -129,7 +129,7 @@ func (s *Server) newConn(conn net.Conn) *clientConn { return cc } -func (s *Server) skipAuth() bool { +func (s *Server) isUnixSocket() bool { return s.cfg.Socket != "" } diff --git a/server/server_test.go b/server/server_test.go index b488a67d6db6e..7bb5b43635226 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -161,6 +161,9 @@ func (dbt *DBTest) mustQueryRows(query string, args ...interface{}) { func runTestRegression(c *C, overrider configOverrider, dbName string) { runTestsOnNewDB(c, overrider, dbName, func(dbt *DBTest) { + // Show the user + dbt.mustExec("select user()") + // Create Table dbt.mustExec("CREATE TABLE test (val TINYINT)") diff --git a/session/session.go b/session/session.go index f5d218f67db64..00e58ce3044ea 100644 --- a/session/session.go +++ b/session/session.go @@ -1100,12 +1100,15 @@ func (s *session) GetSessionVars() *variable.SessionVars { func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) - // Check IP. + // Check IP or localhost. var success bool user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) if success { s.sessionVars.User = user return true + } else if user.Hostname == variable.DefHostname { + log.Errorf("User connection verification failed %s", user) + return false } // Check Hostname. @@ -1128,7 +1131,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by func getHostByIP(ip string) []string { if ip == "127.0.0.1" { - return []string{"localhost"} + return []string{variable.DefHostname} } addrs, err := net.LookupAddr(ip) terror.Log(errors.Trace(err)) diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 248ab3e5cd8b6..6bbfda2d9fc4e 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -219,6 +219,7 @@ const ( // Default TiDB system variable values. const ( + DefHostname = "localhost" DefIndexLookupConcurrency = 4 DefIndexLookupJoinConcurrency = 4 DefIndexSerialScanConcurrency = 1