diff --git a/assert/assertions.go b/assert/assertions.go index fa1245b18..2924cf3a1 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -8,7 +8,6 @@ import ( "fmt" "math" "os" - "path/filepath" "reflect" "regexp" "runtime" @@ -141,12 +140,11 @@ func CallerInfo() []string { } parts := strings.Split(file, "/") - file = parts[len(parts)-1] if len(parts) > 1 { + filename := parts[len(parts)-1] dir := parts[len(parts)-2] - if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" { - path, _ := filepath.Abs(file) - callers = append(callers, fmt.Sprintf("%s:%d", path, line)) + if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" { + callers = append(callers, fmt.Sprintf("%s:%d", file, line)) } } @@ -530,7 +528,7 @@ func isNil(object interface{}) bool { []reflect.Kind{ reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, - reflect.Ptr, reflect.Slice}, + reflect.Ptr, reflect.Slice, reflect.UnsafePointer}, kind) if isNilableKind && value.IsNil() { @@ -818,49 +816,44 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok return true // we consider nil to be equal to the nil set } - defer func() { - if e := recover(); e != nil { - ok = false - } - }() - listKind := reflect.TypeOf(list).Kind() - subsetKind := reflect.TypeOf(subset).Kind() - if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) } + subsetKind := reflect.TypeOf(subset).Kind() if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } - subsetValue := reflect.ValueOf(subset) if subsetKind == reflect.Map && listKind == reflect.Map { - listValue := reflect.ValueOf(list) - subsetKeys := subsetValue.MapKeys() + subsetMap := reflect.ValueOf(subset) + actualMap := reflect.ValueOf(list) - for i := 0; i < len(subsetKeys); i++ { - subsetKey := subsetKeys[i] - subsetElement := subsetValue.MapIndex(subsetKey).Interface() - listElement := listValue.MapIndex(subsetKey).Interface() + for _, k := range subsetMap.MapKeys() { + ev := subsetMap.MapIndex(k) + av := actualMap.MapIndex(k) - if !ObjectsAreEqual(subsetElement, listElement) { - return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...) + if !av.IsValid() { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, subset), msgAndArgs...) + } + if !ObjectsAreEqual(ev.Interface(), av.Interface()) { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, subset), msgAndArgs...) } } return true } - for i := 0; i < subsetValue.Len(); i++ { - element := subsetValue.Index(i).Interface() + subsetList := reflect.ValueOf(subset) + for i := 0; i < subsetList.Len(); i++ { + element := subsetList.Index(i).Interface() ok, found := containsElement(list, element) if !ok { - return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) + return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", list), msgAndArgs...) } if !found { - return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, element), msgAndArgs...) + return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, element), msgAndArgs...) } } @@ -879,34 +872,28 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...) } - defer func() { - if e := recover(); e != nil { - ok = false - } - }() - listKind := reflect.TypeOf(list).Kind() - subsetKind := reflect.TypeOf(subset).Kind() - if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) } + subsetKind := reflect.TypeOf(subset).Kind() if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } - subsetValue := reflect.ValueOf(subset) if subsetKind == reflect.Map && listKind == reflect.Map { - listValue := reflect.ValueOf(list) - subsetKeys := subsetValue.MapKeys() + subsetMap := reflect.ValueOf(subset) + actualMap := reflect.ValueOf(list) - for i := 0; i < len(subsetKeys); i++ { - subsetKey := subsetKeys[i] - subsetElement := subsetValue.MapIndex(subsetKey).Interface() - listElement := listValue.MapIndex(subsetKey).Interface() + for _, k := range subsetMap.MapKeys() { + ev := subsetMap.MapIndex(k) + av := actualMap.MapIndex(k) - if !ObjectsAreEqual(subsetElement, listElement) { + if !av.IsValid() { + return true + } + if !ObjectsAreEqual(ev.Interface(), av.Interface()) { return true } } @@ -914,8 +901,9 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...) } - for i := 0; i < subsetValue.Len(); i++ { - element := subsetValue.Index(i).Interface() + subsetList := reflect.ValueOf(subset) + for i := 0; i < subsetList.Len(); i++ { + element := subsetList.Index(i).Interface() ok, found := containsElement(list, element) if !ok { return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) diff --git a/assert/assertions_test.go b/assert/assertions_test.go index bdab184c1..cae11f81d 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -15,6 +15,7 @@ import ( "strings" "testing" "time" + "unsafe" ) var ( @@ -671,20 +672,18 @@ func TestContainsNotContainsOnNilValue(t *testing.T) { } func TestSubsetNotSubset(t *testing.T) { - - // MTestCase adds a custom message to the case cases := []struct { - expected interface{} - actual interface{} - result bool - message string + list interface{} + subset interface{} + result bool + message string }{ // cases that are expected to contain - {[]int{1, 2, 3}, nil, true, "given subset is nil"}, - {[]int{1, 2, 3}, []int{}, true, "any set contains the nil set"}, - {[]int{1, 2, 3}, []int{1, 2}, true, "[1, 2, 3] contains [1, 2]"}, - {[]int{1, 2, 3}, []int{1, 2, 3}, true, "[1, 2, 3] contains [1, 2, 3"}, - {[]string{"hello", "world"}, []string{"hello"}, true, "[\"hello\", \"world\"] contains [\"hello\"]"}, + {[]int{1, 2, 3}, nil, true, `nil is the empty set which is a subset of every set`}, + {[]int{1, 2, 3}, []int{}, true, `[] is a subset of ['\x01' '\x02' '\x03']`}, + {[]int{1, 2, 3}, []int{1, 2}, true, `['\x01' '\x02'] is a subset of ['\x01' '\x02' '\x03']`}, + {[]int{1, 2, 3}, []int{1, 2, 3}, true, `['\x01' '\x02' '\x03'] is a subset of ['\x01' '\x02' '\x03']`}, + {[]string{"hello", "world"}, []string{"hello"}, true, `["hello"] is a subset of ["hello" "world"]`}, {map[string]string{ "a": "x", "c": "z", @@ -692,12 +691,12 @@ func TestSubsetNotSubset(t *testing.T) { }, map[string]string{ "a": "x", "b": "y", - }, true, `{ "a": "x", "b": "y", "c": "z"} contains { "a": "x", "b": "y"}`}, + }, true, `map["a":"x" "b":"y"] is a subset of map["a":"x" "b":"y" "c":"z"]`}, // cases that are expected not to contain - {[]string{"hello", "world"}, []string{"hello", "testify"}, false, "[\"hello\", \"world\"] does not contain [\"hello\", \"testify\"]"}, - {[]int{1, 2, 3}, []int{4, 5}, false, "[1, 2, 3] does not contain [4, 5"}, - {[]int{1, 2, 3}, []int{1, 5}, false, "[1, 2, 3] does not contain [1, 5]"}, + {[]string{"hello", "world"}, []string{"hello", "testify"}, false, `[]string{"hello", "world"} does not contain "testify"`}, + {[]int{1, 2, 3}, []int{4, 5}, false, `[]int{1, 2, 3} does not contain 4`}, + {[]int{1, 2, 3}, []int{1, 5}, false, `[]int{1, 2, 3} does not contain 5`}, {map[string]string{ "a": "x", "c": "z", @@ -705,35 +704,51 @@ func TestSubsetNotSubset(t *testing.T) { }, map[string]string{ "a": "x", "b": "z", - }, false, `{ "a": "x", "b": "y", "c": "z"} does not contain { "a": "x", "b": "z"}`}, + }, false, `map[string]string{"a":"x", "b":"y", "c":"z"} does not contain map[string]string{"a":"x", "b":"z"}`}, + {map[string]string{ + "a": "x", + "b": "y", + }, map[string]string{ + "a": "x", + "b": "y", + "c": "z", + }, false, `map[string]string{"a":"x", "b":"y"} does not contain map[string]string{"a":"x", "b":"y", "c":"z"}`}, } for _, c := range cases { t.Run("SubSet: "+c.message, func(t *testing.T) { - mockT := new(testing.T) - res := Subset(mockT, c.expected, c.actual) + mockT := new(mockTestingT) + res := Subset(mockT, c.list, c.subset) if res != c.result { - if res { - t.Errorf("Subset should return true: %s", c.message) - } else { - t.Errorf("Subset should return false: %s", c.message) + t.Errorf("Subset should return %t: %s", c.result, c.message) + } + if !c.result { + expectedFail := c.message + actualFail := mockT.errorString() + if !strings.Contains(actualFail, expectedFail) { + t.Log(actualFail) + t.Errorf("Subset failure should contain %q but was %q", expectedFail, actualFail) } } }) } for _, c := range cases { t.Run("NotSubSet: "+c.message, func(t *testing.T) { - mockT := new(testing.T) - res := NotSubset(mockT, c.expected, c.actual) + mockT := new(mockTestingT) + res := NotSubset(mockT, c.list, c.subset) // NotSubset should match the inverse of Subset. If it doesn't, something is wrong - if res == Subset(mockT, c.expected, c.actual) { - if res { - t.Errorf("NotSubset should return true: %s", c.message) - } else { - t.Errorf("NotSubset should return false: %s", c.message) + if res == Subset(mockT, c.list, c.subset) { + t.Errorf("NotSubset should return %t: %s", !c.result, c.message) + } + if c.result { + expectedFail := c.message + actualFail := mockT.errorString() + if !strings.Contains(actualFail, expectedFail) { + t.Log(actualFail) + t.Errorf("NotSubset failure should contain %q but was %q", expectedFail, actualFail) } } }) @@ -2558,3 +2573,10 @@ func TestErrorAs(t *testing.T) { }) } } + +func TestIsNil(t *testing.T) { + var n unsafe.Pointer = nil + if !isNil(n) { + t.Fatal("fail") + } +} diff --git a/mock/mock.go b/mock/mock.go index f0af8246c..e6ff8dfeb 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -218,16 +218,22 @@ func (c *Call) Unset() *Call { foundMatchingCall := false - for i, call := range c.Parent.ExpectedCalls { + // in-place filter slice for calls to be removed - iterate from 0'th to last skipping unnecessary ones + var index int // write index + for _, call := range c.Parent.ExpectedCalls { if call.Method == c.Method { _, diffCount := call.Arguments.Diff(c.Arguments) if diffCount == 0 { foundMatchingCall = true - // Remove from ExpectedCalls - c.Parent.ExpectedCalls = append(c.Parent.ExpectedCalls[:i], c.Parent.ExpectedCalls[i+1:]...) + // Remove from ExpectedCalls - just skip it + continue } } + c.Parent.ExpectedCalls[index] = call + index++ } + // trim slice up to last copied index + c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index] if !foundMatchingCall { unlockOnce.Do(c.unlock) diff --git a/mock/mock_test.go b/mock/mock_test.go index f77dfe461..260bb9c4f 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -569,6 +569,80 @@ func Test_Mock_UnsetIfAlreadyUnsetFails(t *testing.T) { assert.Equal(t, 0, len(mockedService.ExpectedCalls)) } +func Test_Mock_UnsetByOnMethodSpec(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + mock1 := mockedService. + On("TheExampleMethod", 1, 2, 3). + Return(0, nil) + + assert.Equal(t, 1, len(mockedService.ExpectedCalls)) + mock1.On("TheExampleMethod", 1, 2, 3). + Return(0, nil).Unset() + + assert.Equal(t, 0, len(mockedService.ExpectedCalls)) + + assert.Panics(t, func() { + mock1.Unset() + }) + + assert.Equal(t, 0, len(mockedService.ExpectedCalls)) +} + +func Test_Mock_UnsetByOnMethodSpecAmongOthers(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + _, filename, line, _ := runtime.Caller(0) + mock1 := mockedService. + On("TheExampleMethod", 1, 2, 3). + Return(0, nil). + On("TheExampleMethodVariadic", 1, 2, 3, 4, 5).Once(). + Return(nil) + mock1. + On("TheExampleMethodFuncType", Anything). + Return(nil) + + assert.Equal(t, 3, len(mockedService.ExpectedCalls)) + mock1.On("TheExampleMethod", 1, 2, 3). + Return(0, nil).Unset() + + assert.Equal(t, 2, len(mockedService.ExpectedCalls)) + + expectedCalls := []*Call{ + { + Parent: &mockedService.Mock, + Method: "TheExampleMethodVariadic", + Repeatability: 1, + Arguments: []interface{}{1, 2, 3, 4, 5}, + ReturnArguments: []interface{}{nil}, + callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+4)}, + }, + { + Parent: &mockedService.Mock, + Method: "TheExampleMethodFuncType", + Arguments: []interface{}{Anything}, + ReturnArguments: []interface{}{nil}, + callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+7)}, + }, + } + + assert.Equal(t, 2, len(mockedService.ExpectedCalls)) + assert.Equal(t, expectedCalls, mockedService.ExpectedCalls) +} + +func Test_Mock_Unset_WithFuncPanics(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + mock1 := mockedService.On("TheExampleMethod", 1) + mock1.Arguments = append(mock1.Arguments, func(string) error { return nil }) + + assert.Panics(t, func() { + mock1.Unset() + }) +} + func Test_Mock_Return(t *testing.T) { // make a test impl object diff --git a/suite/interfaces.go b/suite/interfaces.go index 8b98a8af2..fed037d7f 100644 --- a/suite/interfaces.go +++ b/suite/interfaces.go @@ -7,6 +7,7 @@ import "testing" type TestingSuite interface { T() *testing.T SetT(*testing.T) + SetS(suite TestingSuite) } // SetupAllSuite has a SetupSuite method, which will run before the @@ -51,3 +52,15 @@ type AfterTest interface { type WithStats interface { HandleStats(suiteName string, stats *SuiteInformation) } + +// SetupSubTest has a SetupSubTest method, which will run before each +// subtest in the suite. +type SetupSubTest interface { + SetupSubTest() +} + +// TearDownSubTest has a TearDownSubTest method, which will run after +// each subtest in the suite have been run. +type TearDownSubTest interface { + TearDownSubTest() +} diff --git a/suite/suite.go b/suite/suite.go index 895591878..8b4202d89 100644 --- a/suite/suite.go +++ b/suite/suite.go @@ -22,9 +22,13 @@ var matchMethod = flag.String("testify.m", "", "regular expression to select tes // retrieving the current *testing.T context. type Suite struct { *assert.Assertions + mu sync.RWMutex require *require.Assertions t *testing.T + + // Parent suite to have access to the implemented methods of parent struct + s TestingSuite } // T retrieves the current *testing.T context. @@ -43,6 +47,12 @@ func (suite *Suite) SetT(t *testing.T) { suite.require = require.New(t) } +// SetS needs to set the current test suite as parent +// to get access to the parent methods +func (suite *Suite) SetS(s TestingSuite) { + suite.s = s +} + // Require returns a require context for suite. func (suite *Suite) Require() *require.Assertions { suite.mu.Lock() @@ -85,7 +95,18 @@ func failOnPanic(t *testing.T, r interface{}) { // Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName. func (suite *Suite) Run(name string, subtest func()) bool { oldT := suite.T() - defer suite.SetT(oldT) + + if setupSubTest, ok := suite.s.(SetupSubTest); ok { + setupSubTest.SetupSubTest() + } + + defer func() { + suite.SetT(oldT) + if tearDownSubTest, ok := suite.s.(TearDownSubTest); ok { + tearDownSubTest.TearDownSubTest() + } + }() + return oldT.Run(name, func(t *testing.T) { suite.SetT(t) subtest() @@ -98,6 +119,7 @@ func Run(t *testing.T, suite TestingSuite) { defer recoverAndFailOnPanic(t) suite.SetT(t) + suite.SetS(suite) var suiteSetupDone bool diff --git a/suite/suite_test.go b/suite/suite_test.go index 446029a43..d684f52b9 100644 --- a/suite/suite_test.go +++ b/suite/suite_test.go @@ -151,14 +151,16 @@ type SuiteTester struct { Suite // Keep counts of how many times each method is run. - SetupSuiteRunCount int - TearDownSuiteRunCount int - SetupTestRunCount int - TearDownTestRunCount int - TestOneRunCount int - TestTwoRunCount int - TestSubtestRunCount int - NonTestMethodRunCount int + SetupSuiteRunCount int + TearDownSuiteRunCount int + SetupTestRunCount int + TearDownTestRunCount int + TestOneRunCount int + TestTwoRunCount int + TestSubtestRunCount int + NonTestMethodRunCount int + SetupSubTestRunCount int + TearDownSubTestRunCount int SuiteNameBefore []string TestNameBefore []string @@ -255,6 +257,14 @@ func (suite *SuiteTester) TestSubtest() { } } +func (suite *SuiteTester) TearDownSubTest() { + suite.TearDownSubTestRunCount++ +} + +func (suite *SuiteTester) SetupSubTest() { + suite.SetupSubTestRunCount++ +} + type SuiteSkipTester struct { // Include our basic suite logic. Suite @@ -336,6 +346,9 @@ func TestRunSuite(t *testing.T) { assert.Equal(t, suiteTester.TestTwoRunCount, 1) assert.Equal(t, suiteTester.TestSubtestRunCount, 1) + assert.Equal(t, suiteTester.TearDownSubTestRunCount, 2) + assert.Equal(t, suiteTester.SetupSubTestRunCount, 2) + // Methods that don't match the test method identifier shouldn't // have been run at all. assert.Equal(t, suiteTester.NonTestMethodRunCount, 0)