Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support require-secure-transport startup option #15341

Merged
merged 4 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,15 @@ func (l *Log) getDisableErrorStack() bool {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
RequireSecureTransport bool `toml:"require-secure-transport" json:"require-secure-transport"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// The ErrConfigValidationFailed error is used so that external callers can do a type assertion
Expand Down
1 change: 1 addition & 0 deletions errno/errcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ const (
ErrInvalidJSONContainsPathType = 3150
ErrJSONUsedAsKey = 3152
ErrJSONDocumentNULLKey = 3158
ErrSecureTransportRequired = 3159
ErrBadUser = 3162
ErrUserAlreadyExists = 3163
ErrInvalidJSONPathArrayCell = 3165
Expand Down
1 change: 1 addition & 0 deletions errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ var MySQLErrName = map[uint16]string{
ErrInvalidJSONContainsPathType: "The second argument can only be either 'one' or 'all'.",
ErrJSONUsedAsKey: "JSON column '%-.192s' cannot be used in key specification.",
ErrJSONDocumentNULLKey: "JSON documents may not contain NULL member names.",
ErrSecureTransportRequired: "Connections using insecure transport are prohibited while --require_secure_transport=ON.",
ErrBadUser: "User %s does not exist.",
ErrUserAlreadyExists: "User %s already exists.",
ErrInvalidJSONPathArrayCell: "A path expression is not a path to a cell in an array.",
Expand Down
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
if !s.NoRollbackOnError || config.GetGlobalConfig().Security.RequireSecureTransport {
return err
}
logutil.BgLogger().Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
Expand Down
3 changes: 3 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -521,6 +522,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}
}
} else if config.GetGlobalConfig().Security.RequireSecureTransport {
return errSecureTransportRequired.FastGenByArgs()
}

// Read the remaining part of the packet.
Expand Down
16 changes: 9 additions & 7 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(errno.ErrUnknownFieldType, errno.MySQLErrName[errno.ErrUnknownFieldType])
errInvalidSequence = terror.ClassServer.New(errno.ErrInvalidSequence, errno.MySQLErrName[errno.ErrInvalidSequence])
errInvalidType = terror.ClassServer.New(errno.ErrInvalidType, errno.MySQLErrName[errno.ErrInvalidType])
errNotAllowedCommand = terror.ClassServer.New(errno.ErrNotAllowedCommand, errno.MySQLErrName[errno.ErrNotAllowedCommand])
errAccessDenied = terror.ClassServer.New(errno.ErrAccessDenied, errno.MySQLErrName[errno.ErrAccessDenied])
errConCount = terror.ClassServer.New(errno.ErrConCount, errno.MySQLErrName[errno.ErrConCount])
errUnknownFieldType = terror.ClassServer.New(errno.ErrUnknownFieldType, errno.MySQLErrName[errno.ErrUnknownFieldType])
errInvalidSequence = terror.ClassServer.New(errno.ErrInvalidSequence, errno.MySQLErrName[errno.ErrInvalidSequence])
errInvalidType = terror.ClassServer.New(errno.ErrInvalidType, errno.MySQLErrName[errno.ErrInvalidType])
errNotAllowedCommand = terror.ClassServer.New(errno.ErrNotAllowedCommand, errno.MySQLErrName[errno.ErrNotAllowedCommand])
errAccessDenied = terror.ClassServer.New(errno.ErrAccessDenied, errno.MySQLErrName[errno.ErrAccessDenied])
errConCount = terror.ClassServer.New(errno.ErrConCount, errno.MySQLErrName[errno.ErrConCount])
errSecureTransportRequired = terror.ClassServer.New(errno.ErrSecureTransportRequired, errno.MySQLErrName[errno.ErrSecureTransportRequired])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down Expand Up @@ -209,12 +210,13 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
if err != nil {
logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err))
return nil, err
jackysp marked this conversation as resolved.
Show resolved Hide resolved
}
if tlsConfig != nil {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.BgLogger().Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
}

setSystemTimeZoneVariable()
Expand Down
8 changes: 4 additions & 4 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,11 +663,11 @@ func (ts *tidbTestSerialSuite) TestErrorNoRollback(c *C) {
cfg.Port = cli.port
cfg.Status.ReportStatus = false

// test cannot startup with wrong tls config
cfg.Security = config.Security{
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
RequireSecureTransport: true,
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
}
_, err = NewServer(cfg, ts.tidbdrv)
c.Assert(err, NotNil)
Expand Down
59 changes: 32 additions & 27 deletions tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,33 +68,34 @@ import (

// Flag Names
const (
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRepairMode = "repair-mode"
nmRepairList = "repair-list"
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRepairMode = "repair-mode"
nmRepairList = "repair-list"
nmRequireSecureTransport = "require-secure-transport"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it is the same name as MySQL's, but I'm not sure adding it is approved.
\cc @siddontang

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use mysql --help but don't find hat it uses this in command flags.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a flag of mysqld, try mysqld --verbose --help


nmProxyProtocolNetworks = "proxy-protocol-networks"
nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout"
Expand Down Expand Up @@ -124,6 +125,7 @@ var (
affinityCPU = flag.String(nmAffinityCPU, "", "affinity cpu (cpu-no. separated by comma, e.g. 1,2,3)")
repairMode = flagBoolean(nmRepairMode, false, "enable admin repair mode")
repairList = flag.String(nmRepairList, "", "admin repair table list")
requireTLS = flag.Bool(nmRequireSecureTransport, false, "require client use secure transport")

// Log
logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
Expand Down Expand Up @@ -420,6 +422,9 @@ func overrideConfig(cfg *config.Config) {
if actualFlags[nmPluginDir] {
cfg.Plugin.Dir = *pluginDir
}
if actualFlags[nmRequireSecureTransport] {
cfg.Security.RequireSecureTransport = *requireTLS
}
if actualFlags[nmRepairMode] {
cfg.RepairMode = *repairMode
}
Expand Down
12 changes: 11 additions & 1 deletion util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tipb/go-tipb"
Expand Down Expand Up @@ -366,8 +367,13 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
return
}

requireTLS := config.GetGlobalConfig().Security.RequireSecureTransport

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
if requireTLS {
clientAuthPolicy = tls.RequestClientCert
}
var certPool *x509.CertPool
if len(ca) > 0 {
var caCert []byte
Expand All @@ -379,7 +385,11 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
}
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
if requireTLS {
clientAuthPolicy = tls.RequireAndVerifyClientCert
} else {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
}
}
tlsConfig = &tls.Config{
Expand Down