From 710f9723bb60022d03dbd6f265ae1e7f29b7213f Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Fri, 12 May 2023 17:55:59 +0200 Subject: [PATCH] refactor: use streaming bool instead of count int --- routing/http/client/client_test.go | 6 ++--- routing/http/server/server.go | 42 ++++++++---------------------- routing/http/server/server_test.go | 6 ++--- 3 files changed, 17 insertions(+), 37 deletions(-) diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index f55bdf38e..7551350d3 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -27,8 +27,8 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) { - args := m.Called(ctx, key, count) +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) { + args := m.Called(ctx, key, stream) return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { @@ -302,7 +302,7 @@ func TestClient_FindProviders(t *testing.T) { findProvsIter := iter.FromSlice(c.routerProvs) - router.On("FindProviders", mock.Anything, cid, 20). + router.On("FindProviders", mock.Anything, cid, c.expStreamingResponse). Return(findProvsIter, c.routerErr) provsIter, err := client.FindProviders(ctx, cid) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 627c0eb62..64eb2dee5 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -41,10 +41,9 @@ type FindProvidersAsyncResponse struct { } type ContentRouter interface { - // FindProviders searches for peers who are able to provide a given key. Count - // indicates the maximum amount of providers we are looking for. If count is 0, - // the implementer can return an unbounded number of results. - FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) + // FindProviders searches for peers who are able to provide a given key. Stream + // indicates whether or not this request will be responded as a stream. + FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -72,27 +71,9 @@ func WithStreamingResultsDisabled() Option { } } -// WithRecordsCount changes the amount of records asked for non-streaming requests. -// Default is 20. -func WithRecordsCount(count int) Option { - return func(s *server) { - s.recordsCount = count - } -} - -// WithStreamingRecordsCount changes the amount of records asked for streaming requests. -// Default is 0 (unbounded). -func WithStreamingRecordsCount(count int) Option { - return func(s *server) { - s.streamingRecordsCount = count - } -} - func Handler(svc ContentRouter, opts ...Option) http.Handler { server := &server{ - svc: svc, - recordsCount: 20, - streamingRecordsCount: 0, + svc: svc, } for _, opt := range opts { @@ -107,10 +88,8 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler { } type server struct { - svc ContentRouter - disableNDJSON bool - recordsCount int - streamingRecordsCount int + svc ContentRouter + disableNDJSON bool } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { @@ -193,10 +172,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { var supportsNDJSON bool var supportsJSON bool - var count int + var streaming bool acceptHeaders := httpReq.Header.Values("Accept") if len(acceptHeaders) == 0 { handlerFunc = s.findProvidersJSON + streaming = false } else { for _, acceptHeader := range acceptHeaders { for _, accept := range strings.Split(acceptHeader, ",") { @@ -209,25 +189,25 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { switch mediaType { case mediaTypeJSON, mediaTypeWildcard: supportsJSON = true - count = s.recordsCount case mediaTypeNDJSON: supportsNDJSON = true - count = s.streamingRecordsCount } } } if supportsNDJSON && !s.disableNDJSON { handlerFunc = s.findProvidersNDJSON + streaming = true } else if supportsJSON { handlerFunc = s.findProvidersJSON + streaming = false } else { writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types")) return } } - provIter, err := s.svc.FindProviders(httpReq.Context(), cid, count) + provIter, err := s.svc.FindProviders(httpReq.Context(), cid, streaming) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index f5af9c590..cf104c654 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) { cb, err := cid.Decode(c) require.NoError(t, err) - router.On("FindProviders", mock.Anything, cb, 0). + router.On("FindProviders", mock.Anything, cb, false). Return(results, nil) resp, err := http.Get(serverAddr + ProvidePath + c) @@ -118,8 +118,8 @@ func TestResponse(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) { - args := m.Called(ctx, key, count) +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) { + args := m.Called(ctx, key, stream) return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {