Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support require-secure-transport startup option (#15341) (#15415) #15442

Merged
merged 4 commits into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,15 @@ type Log struct {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
RequireSecureTransport bool `toml:"require-secure-transport" json:"require-secure-transport"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// The ErrConfigValidationFailed error is used so that external callers can do a type assertion
Expand Down
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
if !s.NoRollbackOnError || config.GetGlobalConfig().Security.RequireSecureTransport {
return err
}
logutil.Logger(context.Background()).Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ require (
github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e
github.com/pingcap/kvproto v0.0.0-20200317043902-2838e21ca222
github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd
github.com/pingcap/parser v3.1.0-beta.1.0.20200317043536-9ebea32e03a6+incompatible
github.com/pingcap/parser v3.1.0-beta.1.0.20200318061433-f0b8f6cdca0d+incompatible
github.com/pingcap/pd/v3 v3.1.0-beta.2.0.20200312100832-1206736bd050
github.com/pingcap/tidb-tools v4.0.0-beta.1.0.20200317092225-ed6b2a87af54+incompatible
github.com/pingcap/tipb v0.0.0-20191126033718-169898888b24
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9 h1:AJD9pZYm72vMgPcQDww
github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8=
github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd h1:CV3VsP3Z02MVtdpTMfEgRJ4T9NGgGTxdHpJerent7rM=
github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8=
github.com/pingcap/parser v3.1.0-beta.1.0.20200317043536-9ebea32e03a6+incompatible h1:LXg9RBvy+a6odBikQF0IcITVFYYDCgbVf5YPhM+AIIU=
github.com/pingcap/parser v3.1.0-beta.1.0.20200317043536-9ebea32e03a6+incompatible/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/parser v3.1.0-beta.1.0.20200318061433-f0b8f6cdca0d+incompatible h1:+Jibmc9uklKz9/prpBggFyjZpqRM8phc1AOOJGxkP48=
github.com/pingcap/parser v3.1.0-beta.1.0.20200318061433-f0b8f6cdca0d+incompatible/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd/v3 v3.1.0-beta.2.0.20200312100832-1206736bd050 h1:mxPdR0pxnUcRfRGX2JnaLyAd9SZWeR42SzvMp4Zv3YI=
github.com/pingcap/pd/v3 v3.1.0-beta.2.0.20200312100832-1206736bd050/go.mod h1:0HfF1LfWLMuGpui0PKhGvkXxfjv1JslMRY6B+cae3dg=
github.com/pingcap/tidb-tools v4.0.0-beta.1.0.20200317092225-ed6b2a87af54+incompatible h1:tYADqdmWwgDOwf/qEN0trJAy6H3c3Tt/QZx1z4qVrRQ=
Expand Down
3 changes: 3 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -516,6 +517,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}
}
} else if config.GetGlobalConfig().Security.RequireSecureTransport {
return errSecureTransportRequired.FastGenByArgs()
}

// Read the remaining part of the packet.
Expand Down
31 changes: 18 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
errSecureTransportRequired = terror.ClassServer.New(codeSecureTransportRequired, mysql.MySQLErrName[mysql.ErrSecureTransportRequired])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down Expand Up @@ -205,6 +206,8 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.Logger(context.Background()).Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
}

setSystemTimeZoneVariable()
Expand Down Expand Up @@ -595,16 +598,18 @@ const (
codeInvalidSequence = 3
codeInvalidType = 4

codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
codeSecureTransportRequired = mysql.ErrSecureTransportRequired
)

func init() {
serverMySQLErrCodes := map[terror.ErrCode]uint16{
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
codeSecureTransportRequired: mysql.ErrSecureTransportRequired,
}
terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}
9 changes: 9 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,15 @@ func (ts *TidbTestSuite) TestErrorNoRollback(c *C) {
cfg.Port = 4006
cfg.Status.ReportStatus = false

cfg.Security = config.Security{
RequireSecureTransport: true,
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
}
_, err = NewServer(cfg, ts.tidbdrv)
c.Assert(err, NotNil)

// test reload tls fail with/without "error no rollback option"
cfg.Security = config.Security{
SSLCA: "/tmp/ca-cert-rollback.pem",
Expand Down
56 changes: 30 additions & 26 deletions tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,31 +63,32 @@ import (

// Flag Names
const (
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRequireSecureTransport = "require-secure-transport"

nmProxyProtocolNetworks = "proxy-protocol-networks"
nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout"
Expand All @@ -113,6 +114,7 @@ var (
tokenLimit = flag.Int(nmTokenLimit, 1000, "the limit of concurrent executed sessions")
pluginDir = flag.String(nmPluginDir, "/data/deploy/plugin", "the folder that hold plugin")
pluginLoad = flag.String(nmPluginLoad, "", "wait load plugin name(separated by comma)")
requireTLS = flag.Bool(nmRequireSecureTransport, false, "require client use secure transport")

// Log
logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
Expand Down Expand Up @@ -439,7 +441,9 @@ func overrideConfig() {
if actualFlags[nmPluginDir] {
cfg.Plugin.Dir = *pluginDir
}

if actualFlags[nmRequireSecureTransport] {
cfg.Security.RequireSecureTransport = *requireTLS
}
// Log
if actualFlags[nmLogLevel] {
cfg.Log.Level = *logLevel
Expand Down
12 changes: 11 additions & 1 deletion util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -317,8 +318,13 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
return
}

requireTLS := config.GetGlobalConfig().Security.RequireSecureTransport

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
if requireTLS {
clientAuthPolicy = tls.RequestClientCert
}
var certPool *x509.CertPool
if len(ca) > 0 {
var caCert []byte
Expand All @@ -330,7 +336,11 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
}
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
if requireTLS {
clientAuthPolicy = tls.RequireAndVerifyClientCert
} else {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
}
}
tlsConfig = &tls.Config{
Expand Down