diff --git a/internal/helpers.go b/internal/helpers.go index 1f75687199..0214aa7164 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -9,6 +9,7 @@ import ( "math/rand" "os" "path/filepath" + "reflect" "strings" "sync" "testing" @@ -16,6 +17,8 @@ import ( "github.com/databricks/cli/cmd/root" _ "github.com/databricks/cli/cmd/version" + "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/stretchr/testify/require" _ "github.com/databricks/cli/cmd/workspace" @@ -82,6 +85,33 @@ func consumeLines(ctx context.Context, wg *sync.WaitGroup, r io.Reader) <-chan s return ch } +func (t *cobraTestRunner) registerFlagCleanup(c *cobra.Command) { + // Find target command that will be run. Example: if the command run is `databricks fs cp`, + // target command corresponds to `cp` + targetCmd, _, err := c.Find(t.args) + require.NoError(t, err) + + // Force initialization of default flags. + // These are initialized by cobra at execution time and would otherwise + // not be cleaned up by the cleanup function below. + targetCmd.InitDefaultHelpFlag() + targetCmd.InitDefaultVersionFlag() + + // Restore flag values to their original value on test completion. + targetCmd.Flags().VisitAll(func(f *pflag.Flag) { + v := reflect.ValueOf(f.Value) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + // Store copy of the current flag value. + reset := reflect.New(v.Type()).Elem() + reset.Set(v) + t.Cleanup(func() { + v.Set(reset) + }) + }) +} + func (t *cobraTestRunner) RunBackground() { var stdoutR, stderrR io.Reader var stdoutW, stderrW io.WriteCloser @@ -92,6 +122,12 @@ func (t *cobraTestRunner) RunBackground() { root.SetErr(stderrW) root.SetArgs(t.args) + // Register cleanup function to restore flags to their original values + // once test has been executed. This is needed because flag values reside + // in a global singleton data-structure, and thus subsequent tests might + // otherwise interfere with each other + t.registerFlagCleanup(root) + errch := make(chan error) ctx, cancel := context.WithCancel(context.Background())