From 67efcb0b84d82cb0f8532af71fa94f71dcfa1c5e Mon Sep 17 00:00:00 2001 From: lysu Date: Wed, 10 Jul 2019 15:59:48 +0800 Subject: [PATCH] plugin: support dynamic enable/disable plugins (#11122) (#11157) --- executor/admin_plugins.go | 52 +++++++ executor/builder.go | 6 + executor/show.go | 2 +- go.mod | 2 +- go.sum | 4 +- planner/core/common_plans.go | 17 +++ planner/core/planbuilder.go | 4 + plugin/audit.go | 3 +- plugin/const_test.go | 42 ++++++ plugin/helper_test.go | 54 +++++++ plugin/plugin.go | 156 +++++++++++++------ plugin/plugin_test.go | 281 +++++++++++++++++++++++++++++++++++ server/conn.go | 2 +- server/server.go | 7 +- 14 files changed, 574 insertions(+), 58 deletions(-) create mode 100644 executor/admin_plugins.go create mode 100644 plugin/const_test.go create mode 100644 plugin/helper_test.go create mode 100644 plugin/plugin_test.go diff --git a/executor/admin_plugins.go b/executor/admin_plugins.go new file mode 100644 index 0000000000000..440c1c0852306 --- /dev/null +++ b/executor/admin_plugins.go @@ -0,0 +1,52 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/util/chunk" +) + +// AdminPluginsExec indicates AdminPlugins executor. +type AdminPluginsExec struct { + baseExecutor + Action core.AdminPluginsAction + Plugins []string +} + +// Next implements the Executor Next interface. +func (e *AdminPluginsExec) Next(ctx context.Context, _ *chunk.Chunk) error { + switch e.Action { + case core.Enable: + return e.changeDisableFlagAndFlush(false) + case core.Disable: + return e.changeDisableFlagAndFlush(true) + } + return nil +} + +func (e *AdminPluginsExec) changeDisableFlagAndFlush(disabled bool) error { + dom := domain.GetDomain(e.ctx) + for _, pluginName := range e.Plugins { + err := plugin.ChangeDisableFlagAndFlush(dom, pluginName, disabled) + if err != nil { + return err + } + } + return nil +} diff --git a/executor/builder.go b/executor/builder.go index 2bec6ffa8b769..94aa7fd37fdf4 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -108,6 +108,8 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor { return b.buildChecksumTable(v) case *plannercore.ReloadExprPushdownBlacklist: return b.buildReloadExprPushdownBlacklist(v) + case *plannercore.AdminPlugins: + return b.buildAdminPlugins(v) case *plannercore.DDL: return b.buildDDL(v) case *plannercore.Deallocate: @@ -467,6 +469,10 @@ func (b *executorBuilder) buildReloadExprPushdownBlacklist(v *plannercore.Reload return &ReloadExprPushdownBlacklistExec{baseExecutor{ctx: b.ctx}} } +func (b *executorBuilder) buildAdminPlugins(v *plannercore.AdminPlugins) Executor { + return &AdminPluginsExec{baseExecutor: baseExecutor{ctx: b.ctx}, Action: v.Action, Plugins: v.Plugins} +} + func (b *executorBuilder) buildDeallocate(v *plannercore.Deallocate) Executor { base := newBaseExecutor(b.ctx, nil, v.ExplainID()) base.initCap = chunk.ZeroCapacity diff --git a/executor/show.go b/executor/show.go index 37fad01831aaf..0380d1e0e4f9a 100644 --- a/executor/show.go +++ b/executor/show.go @@ -1030,7 +1030,7 @@ func (e *ShowExec) fetchShowPlugins() error { tiPlugins := plugin.GetAll() for _, ps := range tiPlugins { for _, p := range ps { - e.appendRow([]interface{}{p.Name, p.State.String(), p.Kind.String(), p.Path, p.License, strconv.Itoa(int(p.Version))}) + e.appendRow([]interface{}{p.Name, p.StateValue(), p.Kind.String(), p.Path, p.License, strconv.Itoa(int(p.Version))}) } } return nil diff --git a/go.mod b/go.mod index aa59b390e05ee..793b28ed4e091 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20190703131923-d9830856b531 github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 - github.com/pingcap/parser v0.0.0-20190613045206-37cc370a20a4 + github.com/pingcap/parser v0.0.0-20190710031629-52a9d3a79f41 github.com/pingcap/pd v0.0.0-20190424024702-bd1e2496a669 github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330 diff --git a/go.sum b/go.sum index 08361b9f04a2f..7c97eeb4d21c2 100644 --- a/go.sum +++ b/go.sum @@ -170,8 +170,8 @@ github.com/pingcap/kvproto v0.0.0-20190703131923-d9830856b531/go.mod h1:QMdbTAXC github.com/pingcap/log v0.0.0-20190214045112-b37da76f67a7/go.mod h1:xsfkWVaFVV5B8e1K9seWfyJWFrIhbtUTAD8NV1Pq3+w= github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ= github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= -github.com/pingcap/parser v0.0.0-20190613045206-37cc370a20a4 h1:r5BvCTM1R9U9EjJntFREb67GMsgn8IK9vLTQ/HzRZBc= -github.com/pingcap/parser v0.0.0-20190613045206-37cc370a20a4/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v0.0.0-20190710031629-52a9d3a79f41 h1:hsCjAYfXliEMyRQTiNAYHyYATfURKNSK1J0eaKfOm1w= +github.com/pingcap/parser v0.0.0-20190710031629-52a9d3a79f41/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v0.0.0-20190424024702-bd1e2496a669 h1:ZoKjndm/Ig7Ru/wojrQkc/YLUttUdQXoH77gtuWCvL4= github.com/pingcap/pd v0.0.0-20190424024702-bd1e2496a669/go.mod h1:MUCxRzOkYiWZtlyi4MhxjCIj9PgQQ/j+BLNGm7aUsnM= github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible h1:MkWCxgZpJBgY2f4HtwWMMFzSBb3+JPzeJgF3VrXE/bU= diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 9116e0aa8ba70..bb2f1730ce776 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -133,6 +133,23 @@ type ReloadExprPushdownBlacklist struct { baseSchemaProducer } +// AdminPluginsAction indicate action will be taken on plugins. +type AdminPluginsAction int + +const ( + // Enable indicates enable plugins. + Enable AdminPluginsAction = iota + 1 + // Disable indicates disable plugins. + Disable +) + +// AdminPlugins administrates tidb plugins. +type AdminPlugins struct { + baseSchemaProducer + Action AdminPluginsAction + Plugins []string +} + // Change represents a change plan. type Change struct { baseSchemaProducer diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 6ed3eb6c397dc..81d6af8cbb34e 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -670,6 +670,10 @@ func (b *PlanBuilder) buildAdmin(as *ast.AdminStmt) (Plan, error) { ret = p case ast.AdminReloadExprPushdownBlacklist: return &ReloadExprPushdownBlacklist{}, nil + case ast.AdminPluginEnable: + return &AdminPlugins{Action: Enable, Plugins: as.Plugins}, nil + case ast.AdminPluginDisable: + return &AdminPlugins{Action: Disable, Plugins: as.Plugins}, nil default: return nil, ErrUnsupportedType.GenWithStack("Unsupported ast.AdminStmt(%T) for buildAdmin", as) } diff --git a/plugin/audit.go b/plugin/audit.go index 8ad556495ac62..603b7e0f8982f 100644 --- a/plugin/audit.go +++ b/plugin/audit.go @@ -16,7 +16,6 @@ package plugin import ( "context" - "github.com/pingcap/parser/auth" "github.com/pingcap/tidb/sessionctx/variable" ) @@ -77,7 +76,7 @@ type AuditManifest struct { Manifest // OnConnectionEvent will be called when TiDB receive or disconnect from client. // return error will ignore and close current connection. - OnConnectionEvent func(ctx context.Context, identity *auth.UserIdentity, event ConnectionEvent, info *variable.ConnectionInfo) error + OnConnectionEvent func(ctx context.Context, event ConnectionEvent, info *variable.ConnectionInfo) error // OnGeneralEvent will be called during TiDB execution. OnGeneralEvent func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) // OnGlobalVariableEvent will be called when Change GlobalVariable. diff --git a/plugin/const_test.go b/plugin/const_test.go new file mode 100644 index 0000000000000..dd366b41d2c4e --- /dev/null +++ b/plugin/const_test.go @@ -0,0 +1,42 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "fmt" + "testing" +) + +func TestConstToString(t *testing.T) { + kinds := map[fmt.Stringer]string{ + Audit: "Audit", + Authentication: "Authentication", + Schema: "Schema", + Daemon: "Daemon", + Uninitialized: "Uninitialized", + Ready: "Ready", + Dying: "Dying", + Disable: "Disable", + Connected: "Connected", + Disconnect: "Disconnect", + ChangeUser: "ChangeUser", + PreAuth: "PreAuth", + ConnectionEvent(byte(15)): "", + } + for key, value := range kinds { + if key.String() != value { + t.Errorf("kind %s != %s", key.String(), kinds) + } + } +} diff --git a/plugin/helper_test.go b/plugin/helper_test.go new file mode 100644 index 0000000000000..1bb3fc71420ec --- /dev/null +++ b/plugin/helper_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import "testing" + +func TestPluginDeclare(t *testing.T) { + auditRaw := &AuditManifest{Manifest: Manifest{}} + auditExport := ExportManifest(auditRaw) + audit2 := DeclareAuditManifest(auditExport) + if audit2 != auditRaw { + t.Errorf("declare audit fail") + } + + authRaw := &AuthenticationManifest{Manifest: Manifest{}} + authExport := ExportManifest(authRaw) + auth2 := DeclareAuthenticationManifest(authExport) + if auth2 != authRaw { + t.Errorf("declare auth fail") + } + + schemaRaw := &SchemaManifest{Manifest: Manifest{}} + schemaExport := ExportManifest(schemaRaw) + schema2 := DeclareSchemaManifest(schemaExport) + if schema2 != schemaRaw { + t.Errorf("declare schema fail") + } + + daemonRaw := &DaemonManifest{Manifest: Manifest{}} + daemonExport := ExportManifest(daemonRaw) + daemon2 := DeclareDaemonManifest(daemonExport) + if daemon2 != daemonRaw { + t.Errorf("declare daemon fail") + } +} + +func TestDecode(t *testing.T) { + failID := ID("fail") + _, _, err := failID.Decode() + if err == nil { + t.Errorf("'fail' should not decode success") + } +} diff --git a/plugin/plugin.go b/plugin/plugin.go index a7c61cbd5819c..3821ccb401d3e 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -49,8 +49,9 @@ type plugins struct { // clone deep copies plugins info. func (p *plugins) clone() *plugins { np := &plugins{ - plugins: make(map[Kind][]Plugin, len(p.plugins)), - versions: make(map[string]uint16, len(p.versions)), + plugins: make(map[Kind][]Plugin, len(p.plugins)), + versions: make(map[string]uint16, len(p.versions)), + dyingPlugins: make([]Plugin, len(p.dyingPlugins)), } for key, value := range p.plugins { np.plugins[key] = append([]Plugin(nil), value...) @@ -94,39 +95,31 @@ type Config struct { // Plugin presents a TiDB plugin. type Plugin struct { *Manifest - library *gplugin.Plugin - State State - Path string + library *gplugin.Plugin + State State + Path string + Disabled uint32 } -type validateMode int - -const ( - initMode validateMode = iota - reloadMode -) +// StateValue returns readable state string. +func (p *Plugin) StateValue() string { + flag := "enable" + if atomic.LoadUint32(&p.Disabled) == 1 { + flag = "disable" + } + return p.State.String() + "-" + flag +} -func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins, mode validateMode) error { - if mode == reloadMode { - var oldPlugin *Plugin - for i, item := range tiPlugins.plugins[p.Kind] { - if item.Name == p.Name { - oldPlugin = &tiPlugins.plugins[p.Kind][i] - break - } - } - if oldPlugin == nil { - return errUnsupportedReloadPlugin.GenWithStackByArgs(p.Name) - } - if len(p.SysVars) != len(oldPlugin.SysVars) { - return errUnsupportedReloadPluginVar.GenWithStackByArgs("") - } - for varName, varVal := range p.SysVars { - if oldPlugin.SysVars[varName] == nil || *oldPlugin.SysVars[varName] != *varVal { - return errUnsupportedReloadPluginVar.GenWithStackByArgs(varVal) - } - } +// DisableFlag changes the disable flag of plugin. +func (p *Plugin) DisableFlag(disable bool) { + if disable { + atomic.StoreUint32(&p.Disabled, 1) + } else { + atomic.StoreUint32(&p.Disabled, 0) } +} + +func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins) error { if p.RequireVersion != nil { for component, reqVer := range p.RequireVersion { if ver, ok := tiPlugins.versions[component]; !ok || ver < reqVer { @@ -197,7 +190,7 @@ func Load(ctx context.Context, cfg Config) (err error) { // Cross validate & Load plugins. for kind := range tiPlugins.plugins { for i := range tiPlugins.plugins[kind] { - if err = tiPlugins.plugins[kind][i].validate(ctx, tiPlugins, initMode); err != nil { + if err = tiPlugins.plugins[kind][i].validate(ctx, tiPlugins); err != nil { if cfg.SkipWhenFail { logutil.Logger(ctx).Warn("validate plugin fail and disable plugin", zap.String("plugin", tiPlugins.plugins[kind][i].Name), zap.Error(err)) @@ -251,6 +244,7 @@ func Init(ctx context.Context, cfg Config) (err error) { path: pluginWatchPrefix + tiPlugins.plugins[kind][i].Name, etcd: cfg.EtcdClient, manifest: tiPlugins.plugins[kind][i].Manifest, + plugin: &tiPlugins.plugins[kind][i], } tiPlugins.plugins[kind][i].flushWatcher = watcher go util.WithRecovery(watcher.watchLoop, nil) @@ -267,6 +261,7 @@ type flushWatcher struct { path string etcd *clientv3.Client manifest *Manifest + plugin *Plugin } func (w *flushWatcher) watchLoop() { @@ -276,7 +271,16 @@ func (w *flushWatcher) watchLoop() { case <-w.ctx.Done(): return case <-watchChan: - err := w.manifest.OnFlush(w.ctx, w.manifest) + disabled, err := w.getPluginDisabledFlag() + if err != nil { + logutil.Logger(context.Background()).Error("get plugin disabled flag failure", zap.String("plugin", w.manifest.Name), zap.Error(err)) + } + if disabled { + atomic.StoreUint32(&w.manifest.flushWatcher.plugin.Disabled, 1) + } else { + atomic.StoreUint32(&w.manifest.flushWatcher.plugin.Disabled, 0) + } + err = w.manifest.OnFlush(w.ctx, w.manifest) if err != nil { logutil.Logger(context.Background()).Error("notify plugin flush event failed", zap.String("plugin", w.manifest.Name), zap.Error(err)) } @@ -284,26 +288,39 @@ func (w *flushWatcher) watchLoop() { } } -func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { - plugin.Path = filepath.Join(dir, string(pluginID)+LibrarySuffix) - plugin.library, err = gplugin.Open(plugin.Path) +func (w *flushWatcher) getPluginDisabledFlag() (bool, error) { + if w == nil || w.etcd == nil { + return true, errors.New("etcd is need to get plugin enable status") + } + resp, err := w.etcd.Get(context.Background(), w.manifest.flushWatcher.path) if err != nil { - err = errors.Trace(err) - return + return true, errors.Trace(err) } - manifestSym, err := plugin.library.Lookup(ManifestSymbol) + if len(resp.Kvs) == 0 { + return false, nil + } + return string(resp.Kvs[0].Value) == "1", nil +} + +type loadFn func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) + +var testHook *struct { + loadOne loadFn +} + +func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { + pName, pVersion, err := pluginID.Decode() if err != nil { err = errors.Trace(err) return } - manifest, ok := manifestSym.(func() *Manifest) - if !ok { - err = errInvalidPluginManifest.GenWithStackByArgs(string(pluginID)) - return + var manifest func() *Manifest + if testHook == nil { + manifest, err = loadManifestByGoPlugin(&plugin, dir, pluginID) + } else { + manifest, err = testHook.loadOne(&plugin, dir, pluginID) } - pName, pVersion, err := pluginID.Decode() if err != nil { - err = errors.Trace(err) return } plugin.Manifest = manifest() @@ -318,6 +335,27 @@ func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { return } +func loadManifestByGoPlugin(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + plugin.Path = filepath.Join(dir, string(pluginID)+LibrarySuffix) + plugin.library, err = gplugin.Open(plugin.Path) + if err != nil { + err = errors.Trace(err) + return + } + manifestSym, err := plugin.library.Lookup(ManifestSymbol) + if err != nil { + err = errors.Trace(err) + return + } + var ok bool + manifest, ok = manifestSym.(func() *Manifest) + if !ok { + err = errInvalidPluginManifest.GenWithStackByArgs(string(pluginID)) + return + } + return +} + // Shutdown cleanups all plugin resources. // Notice: it just cleanups the resource of plugin, but cannot unload plugins(limited by go plugin). func Shutdown(ctx context.Context) { @@ -332,6 +370,9 @@ func Shutdown(ctx context.Context) { if p.flushWatcher != nil { p.flushWatcher.cancel() } + if p.OnShutdown == nil { + continue + } if err := p.OnShutdown(ctx, p.Manifest); err != nil { logutil.Logger(ctx).Error("call OnShutdown for failure", zap.String("plugin", p.Name), zap.Error(err)) @@ -369,6 +410,9 @@ func ForeachPlugin(kind Kind, fn func(plugin *Plugin) error) error { if p.State != Ready { continue } + if atomic.LoadUint32(&p.Disabled) == 1 { + continue + } err := fn(p) if err != nil { return err @@ -385,7 +429,7 @@ func IsEnable(kind Kind) bool { } for i := range plugins.plugins[kind] { p := &plugins.plugins[kind][i] - if p.State == Ready { + if p.State == Ready && atomic.LoadUint32(&p.Disabled) != 1 { return true } } @@ -407,7 +451,25 @@ func NotifyFlush(dom *domain.Domain, pluginName string) error { if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready { return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with PD", pluginName) } - _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, "") + _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, strconv.Itoa(int(p.Disabled))) + if err != nil { + return err + } + return nil +} + +// ChangeDisableFlagAndFlush changes plugin disable flag and notify other nodes to do same change. +func ChangeDisableFlagAndFlush(dom *domain.Domain, pluginName string, disable bool) error { + p := getByName(pluginName) + if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready { + return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with PD", pluginName) + } + disableInt := uint32(0) + if disable { + disableInt = 1 + } + atomic.StoreUint32(&p.Disabled, disableInt) + _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, strconv.Itoa(int(disableInt))) if err != nil { return err } diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 0000000000000..0f5acb6b26616 --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,281 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" + "io" + "strconv" + "testing" + + "github.com/pingcap/tidb/sessionctx/variable" +) + +func TestLoadPluginSuccess(t *testing.T) { + ctx := context.Background() + + pluginName := "tplugin" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + cfg := Config{ + Plugins: []string{pluginSign}, + PluginDir: "", + GlobalSysVar: &variable.SysVars, + PluginVarNames: &variable.PluginVarNames, + EnvVersion: map[string]uint16{"go": 1112}, + } + + // setup load test hook. + testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + return func() *Manifest { + m := &AuditManifest{ + Manifest: Manifest{ + Kind: Authentication, + Name: pluginName, + Version: pluginVersion, + SysVars: map[string]*variable.SysVar{pluginName + "_key": {Scope: variable.ScopeGlobal, Name: pluginName + "_key", Value: "v1"}}, + OnInit: func(ctx context.Context, manifest *Manifest) error { + return nil + }, + OnShutdown: func(ctx context.Context, manifest *Manifest) error { + return nil + }, + Validate: func(ctx context.Context, manifest *Manifest) error { + return nil + }, + }, + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) { + }, + } + return ExportManifest(m) + }, nil + }} + defer func() { + testHook = nil + }() + + // trigger load. + err := Load(ctx, cfg) + if err != nil { + t.Errorf("load plugin [%s] fail", pluginSign) + } + + err = Init(ctx, cfg) + if err != nil { + t.Errorf("init plugin [%s] fail", pluginSign) + } + + // load all. + ps := GetAll() + if len(ps) != 1 { + t.Errorf("loaded plugins is empty") + } + + // find plugin by type and name + p := Get(Authentication, "tplugin") + if p == nil { + t.Errorf("tplugin can not be load") + } + p = Get(Authentication, "tplugin2") + if p != nil { + t.Errorf("found miss plugin") + } + p = getByName("tplugin") + if p == nil { + t.Errorf("can not find miss plugin") + } + + // foreach plugin + err = ForeachPlugin(Authentication, func(plugin *Plugin) error { + return nil + }) + if err != nil { + t.Errorf("foreach error %v", err) + } + err = ForeachPlugin(Authentication, func(plugin *Plugin) error { + return io.EOF + }) + if err != io.EOF { + t.Errorf("foreach should return EOF error") + } + + Shutdown(ctx) +} + +func TestLoadPluginSkipError(t *testing.T) { + ctx := context.Background() + + pluginName := "tplugin" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + cfg := Config{ + Plugins: []string{pluginSign, pluginSign, "notExists-2"}, + PluginDir: "", + PluginVarNames: &variable.PluginVarNames, + EnvVersion: map[string]uint16{"go": 1112}, + SkipWhenFail: true, + } + + // setup load test hook. + testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + return func() *Manifest { + m := &AuditManifest{ + Manifest: Manifest{ + Kind: Audit, + Name: pluginName, + Version: pluginVersion, + SysVars: map[string]*variable.SysVar{pluginName + "_key": {Scope: variable.ScopeGlobal, Name: pluginName + "_key", Value: "v1"}}, + OnInit: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + OnShutdown: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + Validate: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + }, + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) { + }, + } + return ExportManifest(m) + }, nil + }} + defer func() { + testHook = nil + }() + + // trigger load. + err := Load(ctx, cfg) + if err != nil { + t.Errorf("load plugin [%s] fail %v", pluginSign, err) + } + + err = Init(ctx, cfg) + if err != nil { + t.Errorf("init plugin [%s] fail", pluginSign) + } + + // load all. + ps := GetAll() + if len(ps) != 1 { + t.Errorf("loaded plugins is empty") + } + + // find plugin by type and name + p := Get(Audit, "tplugin") + if p == nil { + t.Errorf("tplugin can not be load") + } + p = Get(Audit, "tplugin2") + if p != nil { + t.Errorf("found miss plugin") + } + p = getByName("tplugin") + if p == nil { + t.Errorf("can not find miss plugin") + } + p = getByName("not exists") + if p != nil { + t.Errorf("got not exists plugin") + } + + // foreach plugin + readyCount := 0 + err = ForeachPlugin(Authentication, func(plugin *Plugin) error { + readyCount++ + return nil + }) + if err != nil { + t.Errorf("foreach meet error %v", err) + } + if readyCount != 0 { + t.Errorf("validate fail can be load but no ready") + } + + Shutdown(ctx) +} + +func TestLoadFail(t *testing.T) { + ctx := context.Background() + + pluginName := "tplugin" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + cfg := Config{ + Plugins: []string{pluginSign, pluginSign, "notExists-2"}, + PluginDir: "", + PluginVarNames: &variable.PluginVarNames, + EnvVersion: map[string]uint16{"go": 1112}, + SkipWhenFail: false, + } + + // setup load test hook. + testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + return func() *Manifest { + m := &AuditManifest{ + Manifest: Manifest{ + Kind: Audit, + Name: pluginName, + Version: pluginVersion, + SysVars: map[string]*variable.SysVar{pluginName + "_key": {Scope: variable.ScopeGlobal, Name: pluginName + "_key", Value: "v1"}}, + OnInit: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + OnShutdown: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + Validate: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + }, + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) { + }, + } + return ExportManifest(m) + }, nil + }} + defer func() { + testHook = nil + }() + + err := Load(ctx, cfg) + if err == nil { + t.Errorf("load plugin should fail") + } +} + +func TestPluginsClone(t *testing.T) { + ps := &plugins{ + plugins: map[Kind][]Plugin{ + Audit: {{}}, + }, + versions: map[string]uint16{ + "whitelist": 1, + }, + dyingPlugins: []Plugin{{}}, + } + cps := ps.clone() + ps.dyingPlugins = append(ps.dyingPlugins, Plugin{}) + ps.versions["w"] = 2 + as := ps.plugins[Audit] + ps.plugins[Audit] = append(as, Plugin{}) + + if len(cps.plugins) != 1 || len(cps.plugins[Audit]) != 1 || len(cps.versions) != 1 || len(cps.dyingPlugins) != 1 { + t.Errorf("clone plugins failure") + } +} diff --git a/server/conn.go b/server/conn.go index 7a56f1f799f8f..eb23fb118c8c3 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1442,7 +1442,7 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { connInfo := cc.ctx.GetSessionVars().ConnectionInfo - err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: connInfo.Host}, plugin.ChangeUser, connInfo) + err = authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo) if err != nil { return err } diff --git a/server/server.go b/server/server.go index dfc608682023a..43ceefe65e746 100644 --- a/server/server.go +++ b/server/server.go @@ -49,7 +49,6 @@ import ( "github.com/blacktear23/go-proxyprotocol" "github.com/pingcap/errors" - "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/config" @@ -342,7 +341,7 @@ func (s *Server) Run() error { terror.Log(clientConn.Close()) return errors.Trace(err) } - err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: host}, plugin.PreAuth, nil) + err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth, &variable.ConnectionInfo{Host: host}) if err != nil { logutil.Logger(context.Background()).Info("do connection event failed", zap.Error(err)) terror.Log(clientConn.Close()) @@ -429,7 +428,7 @@ func (s *Server) onConn(conn *clientConn) { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { sessionVars := conn.ctx.GetSessionVars() - return authPlugin.OnConnectionEvent(context.Background(), sessionVars.User, plugin.Connected, sessionVars.ConnectionInfo) + return authPlugin.OnConnectionEvent(context.Background(), plugin.Connected, sessionVars.ConnectionInfo) } return nil }) @@ -445,7 +444,7 @@ func (s *Server) onConn(conn *clientConn) { if authPlugin.OnConnectionEvent != nil { sessionVars := conn.ctx.GetSessionVars() sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) - err := authPlugin.OnConnectionEvent(context.Background(), sessionVars.User, plugin.Disconnect, sessionVars.ConnectionInfo) + err := authPlugin.OnConnectionEvent(context.Background(), plugin.Disconnect, sessionVars.ConnectionInfo) if err != nil { logutil.Logger(context.Background()).Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err)) }