Skip to content

Commit

Permalink
Add variadic support to mg.F (magefile#402)
Browse files Browse the repository at this point in the history
Allows to pass sh.Run to mg.F as such:

	mg.Deps(
		mg.F(sh.Run, "go", "test", "./..."),
	)

This improves the magefile by removing some of the one-liner functions
that you might end up with that are only used through mg.Deps.

Resolves magefile#401.
  • Loading branch information
perj authored Mar 23, 2022
1 parent 30b9022 commit 051a55c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
19 changes: 15 additions & 4 deletions mg/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo
return false, false, fmt.Errorf("target's return value is not an error")
}

// more inputs than slots is always an error
if len(args) > t.NumIn() {
// more inputs than slots is an error if not variadic
if len(args) > t.NumIn() && !t.IsVariadic() {
return false, false, fmt.Errorf("too many arguments for target, got %d for %T", len(args), target)
}

Expand Down Expand Up @@ -142,20 +142,31 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo
x++
}

if len(args) != inputs {
if t.IsVariadic() {
if len(args) < inputs-1 {
return false, false, fmt.Errorf("too few arguments for target, got %d for %T", len(args), target)

}
} else if len(args) != inputs {
return false, false, fmt.Errorf("wrong number of arguments for target, got %d for %T", len(args), target)
}

for _, arg := range args {
argT := t.In(x)
if t.IsVariadic() && x == t.NumIn()-1 {
// For the variadic argument, use the slice element type.
argT = argT.Elem()
}
if !argTypes[argT] {
return false, false, fmt.Errorf("argument %d (%s), is not a supported argument type", x, argT)
}
passedT := reflect.TypeOf(arg)
if argT != passedT {
return false, false, fmt.Errorf("argument %d expected to be %s, but is %s", x, argT, passedT)
}
x++
if x < t.NumIn()-1 {
x++
}
}
return hasContext, isNamespace, nil
}
Expand Down
35 changes: 35 additions & 0 deletions mg/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mg
import (
"context"
"fmt"
"reflect"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -213,6 +214,40 @@ func TestFNilError(t *testing.T) {
}
}

func TestFVariadic(t *testing.T) {
fn := F(func(args ...string) {
if !reflect.DeepEqual(args, []string{"a", "b"}) {
t.Errorf("Wrong args, got %v, want [a b]", args)
}
}, "a", "b")
err := fn.Run(context.Background())
if err != nil {
t.Fatal(err)
}

fn = F(func(a string, b ...string) {}, "a", "b1", "b2")
err = fn.Run(context.Background())
if err != nil {
t.Fatal(err)
}

fn = F(func(a ...string) {})
err = fn.Run(context.Background())
if err != nil {
t.Fatal(err)
}

func() {
defer func() {
err, _ := recover().(error)
if err == nil || err.Error() != "too few arguments for target, got 0 for func(string, ...string)" {
t.Fatal(err)
}
}()
F(func(a string, b ...string) {})
}()
}

type Foo Namespace

func (Foo) Bare() {}
Expand Down

0 comments on commit 051a55c

Please sign in to comment.