diff --git a/server/conn.go b/server/conn.go index fd38018a354d6..9728668a72f26 100644 --- a/server/conn.go +++ b/server/conn.go @@ -56,6 +56,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" @@ -391,16 +392,17 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { if err != nil { return errors.Trace(err) } - if !cc.server.skipAuth() { - // Do Auth. + host := variable.DefHostname + if !cc.server.isUnixSocket() { addr := cc.bufReadConn.RemoteAddr().String() - host, _, err1 := net.SplitHostPort(addr) - if err1 != nil { + // Do Auth. + host, _, err = net.SplitHostPort(addr) + if err != nil { return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, addr, "YES")) } - if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) { - return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES")) - } + } + if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) { + return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES")) } if cc.dbname != "" { err = cc.useDB(context.Background(), cc.dbname) diff --git a/server/server.go b/server/server.go index 51cc6930a6c9d..5a064ac5698e9 100644 --- a/server/server.go +++ b/server/server.go @@ -129,7 +129,7 @@ func (s *Server) newConn(conn net.Conn) *clientConn { return cc } -func (s *Server) skipAuth() bool { +func (s *Server) isUnixSocket() bool { return s.cfg.Socket != "" } diff --git a/server/server_test.go b/server/server_test.go index 7ff59ae67008e..be202a06a1006 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -161,6 +161,9 @@ func (dbt *DBTest) mustQueryRows(query string, args ...interface{}) { func runTestRegression(c *C, overrider configOverrider, dbName string) { runTestsOnNewDB(c, overrider, dbName, func(dbt *DBTest) { + // Show the user + dbt.mustExec("select user()") + // Create Table dbt.mustExec("CREATE TABLE test (val TINYINT)") diff --git a/session/session.go b/session/session.go index 01c0e9d9a2151..5ef1e8ac52d4a 100644 --- a/session/session.go +++ b/session/session.go @@ -1024,10 +1024,13 @@ func (s *session) GetSessionVars() *variable.SessionVars { func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) - // Check IP. + // Check IP or localhost. if pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) { s.sessionVars.User = user return true + } else if user.Hostname == variable.DefHostname { + log.Errorf("User connection verification failed %s", user) + return false } // Check Hostname. @@ -1047,7 +1050,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by func getHostByIP(ip string) []string { if ip == "127.0.0.1" { - return []string{"localhost"} + return []string{variable.DefHostname} } addrs, err := net.LookupAddr(ip) terror.Log(errors.Trace(err)) diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 50a2cae5f5bfc..7b2801d4212ee 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -204,6 +204,7 @@ const ( // Default TiDB system variable values. const ( + DefHostname = "localhost" DefIndexLookupConcurrency = 4 DefIndexLookupJoinConcurrency = 4 DefIndexSerialScanConcurrency = 1