diff --git a/README.md b/README.md index 72480c3..21b53fb 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ import ( func TestClient(t *testing.T) { ctx := context.Background() - ts := grpcstub.NewServer(t, []string{}, "path/to/route_guide.proto") + ts := grpcstub.NewServer(t, "path/to/route_guide.proto") t.Cleanup(func() { ts.Close() }) diff --git a/grpcstub.go b/grpcstub.go index 7903f23..6c65b3b 100644 --- a/grpcstub.go +++ b/grpcstub.go @@ -105,47 +105,47 @@ type matchFunc func(r *Request) bool type handlerFunc func(r *Request) *Response // NewServer returns a new server with registered *grpc.Server -func NewServer(t *testing.T, importPaths []string, protos ...string) *Server { +func NewServer(t *testing.T, proto string, opts ...Option) *Server { t.Helper() - fds, err := descriptorFromFiles(importPaths, protos...) + c := &config{} + opts = append(opts, Proto(proto)) + for _, opt := range opts { + if err := opt(c); err != nil { + t.Fatal(err) + } + } + fds, err := descriptorFromFiles(c.importPaths, c.protos...) if err != nil { t.Error(err) return nil } s := &Server{ - fds: fds, - server: grpc.NewServer(), - t: t, + fds: fds, + t: t, + } + if c.useTLS { + certificate, err := tls.X509KeyPair(c.cert, c.key) + if err != nil { + t.Fatal(err) + } + tlsc := &tls.Config{ + Certificates: []tls.Certificate{certificate}, + } + creds := credentials.NewTLS(tlsc) + s.tlsc = tlsc + s.cacert = c.cacert + s.server = grpc.NewServer(grpc.Creds(creds)) + } else { + s.server = grpc.NewServer() } s.startServer() return s } // NewTLSServer returns a new server with registered secure *grpc.Server -func NewTLSServer(t *testing.T, cacert, cert, key []byte, importPaths []string, protos ...string) *Server { - t.Helper() - fds, err := descriptorFromFiles(importPaths, protos...) - if err != nil { - t.Error(err) - return nil - } - certificate, err := tls.X509KeyPair(cert, key) - if err != nil { - t.Fatal(err) - } - tlsc := &tls.Config{ - Certificates: []tls.Certificate{certificate}, - } - creds := credentials.NewTLS(tlsc) - s := &Server{ - fds: fds, - tlsc: tlsc, - cacert: cacert, - server: grpc.NewServer(grpc.Creds(creds)), - t: t, - } - s.startServer() - return s +func NewTLSServer(t *testing.T, proto string, cacert, cert, key []byte, opts ...Option) *Server { + opts = append(opts, UseTLS(cacert, cert, key)) + return NewServer(t, proto, opts...) } // Close shuts down *grpc.Server diff --git a/grpcstub_test.go b/grpcstub_test.go index 1b4030b..28afb00 100644 --- a/grpcstub_test.go +++ b/grpcstub_test.go @@ -21,7 +21,7 @@ import ( func TestUnary(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -58,7 +58,7 @@ func TestUnary(t *testing.T) { func TestServerStreaming(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -115,7 +115,7 @@ func TestServerStreaming(t *testing.T) { func TestClientStreaming(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -157,7 +157,7 @@ func TestClientStreaming(t *testing.T) { func TestBiStreaming(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -238,7 +238,7 @@ func TestBiStreaming(t *testing.T) { } func TestAddr(t *testing.T) { - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -250,7 +250,7 @@ func TestAddr(t *testing.T) { func TestServerMatch(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -274,7 +274,7 @@ func TestServerMatch(t *testing.T) { func TestMatcherMatch(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -298,7 +298,7 @@ func TestMatcherMatch(t *testing.T) { func TestServerService(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -320,7 +320,7 @@ func TestServerService(t *testing.T) { func TestMatcherService(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -342,7 +342,7 @@ func TestMatcherService(t *testing.T) { func TestMatcherMethod(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -364,7 +364,7 @@ func TestMatcherMethod(t *testing.T) { func TestHeader(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -391,7 +391,7 @@ func TestHeader(t *testing.T) { func TestTrailer(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -418,7 +418,7 @@ func TestTrailer(t *testing.T) { func TestResponseHeader(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -438,7 +438,7 @@ func TestResponseHeader(t *testing.T) { func TestStatusUnary(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -472,7 +472,7 @@ func TestStatusUnary(t *testing.T) { func TestStatusServerStreaming(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -518,7 +518,7 @@ func TestStatusServerStreaming(t *testing.T) { func TestStatusClientStreaming(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -566,7 +566,7 @@ func TestStatusClientStreaming(t *testing.T) { func TestStatusBiStreaming(t *testing.T) { ctx := context.Background() - ts := NewServer(t, []string{}, "testdata/route_guide.proto") + ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { ts.Close() }) @@ -635,7 +635,7 @@ func TestLoadProto(t *testing.T) { ctx := context.Background() for _, tt := range tests { t.Run(tt.proto, func(t *testing.T) { - ts := NewServer(t, []string{}, tt.proto) + ts := NewServer(t, tt.proto) t.Cleanup(func() { ts.Close() }) @@ -676,7 +676,7 @@ func TestTLSServer(t *testing.T) { if err != nil { t.Fatal(err) } - ts := NewTLSServer(t, cacert, cert, key, []string{}, "testdata/route_guide.proto") + ts := NewTLSServer(t, "testdata/route_guide.proto", cacert, cert, key) t.Cleanup(func() { ts.Close() }) diff --git a/option.go b/option.go new file mode 100644 index 0000000..c976bc6 --- /dev/null +++ b/option.go @@ -0,0 +1,41 @@ +package grpcstub + +type config struct { + protos []string + importPaths []string + useTLS bool + cacert, cert, key []byte +} + +type Option func(*config) error + +func Proto(proto string) Option { + return func(c *config) error { + c.protos = append(c.protos, proto) + return nil + } +} + +func Protos(protos []string) Option { + return func(c *config) error { + c.protos = append(c.protos, protos...) + return nil + } +} + +func ImportPaths(paths []string) Option { + return func(c *config) error { + c.importPaths = append(c.importPaths, paths...) + return nil + } +} + +func UseTLS(cacert, cert, key []byte) Option { + return func(c *config) error { + c.useTLS = true + c.cacert = cacert + c.cert = cert + c.key = key + return nil + } +}