From c70c2b53da7c3381fbf6c7e857e7eaecd03393e6 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Thu, 30 Mar 2023 13:26:28 -0400 Subject: [PATCH] feat: zero timeout on composed routers should disable timeout (#72) This will let consumers disable timeouts instead of using a timeout of 0s which isn't otherwise useful since it will always fail anyway. --- compparallel.go | 15 ++++- compparallel_test.go | 22 +++++++ compsequential.go | 17 +++--- compsequential_test.go | 132 +++++++++++++++++------------------------ dummy_test.go | 22 +++++++ 5 files changed, 118 insertions(+), 90 deletions(-) diff --git a/compparallel.go b/compparallel.go index e0b37cc..ccdb7f4 100644 --- a/compparallel.go +++ b/compparallel.go @@ -152,6 +152,13 @@ func (r *composableParallel) Bootstrap(ctx context.Context) error { }) } +func withCancelAndOptionalTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout != 0 { + return context.WithTimeout(ctx, timeout) + } + return context.WithCancel(ctx) +} + func getValueOrErrorParallel[T any]( ctx context.Context, routers []*ParallelRouter, @@ -177,7 +184,7 @@ func getValueOrErrorParallel[T any]( select { case <-ctx.Done(): case <-tim.C: - ctx, cancel := context.WithTimeout(ctx, r.Timeout) + ctx, cancel := withCancelAndOptionalTimeout(ctx, r.Timeout) defer cancel() value, empty, err := f(ctx, r.Router) if err != nil && @@ -269,8 +276,9 @@ func executeParallel( errCh <- ctx.Err() } case <-tim.C: - ctx, cancel := context.WithTimeout(ctx, r.Timeout) + ctx, cancel := withCancelAndOptionalTimeout(ctx, r.Timeout) defer cancel() + log.Debug("executeParallel: calling router function for router ", r.Router, " with timeout ", r.Timeout, " and ignore errors ", r.IgnoreError, @@ -335,8 +343,9 @@ func getChannelOrErrorParallel[T any]( ) return case <-tim.C: - ctx, cancel := context.WithTimeout(ctx, r.Timeout) + ctx, cancel := withCancelAndOptionalTimeout(ctx, r.Timeout) defer cancel() + log.Debug("getChannelOrErrorParallel: calling router function for router ", r.Router, " with timeout ", r.Timeout, " and ignore errors ", r.IgnoreError, diff --git a/compparallel_test.go b/compparallel_test.go index 39953f0..b7068fa 100644 --- a/compparallel_test.go +++ b/compparallel_test.go @@ -254,6 +254,28 @@ func TestComposableParallelFixtures(t *testing.T) { }}, SearchValue: []searchValueFixture{{key: "a", ctx: canceledCtx, err: context.Canceled}}, }, + { + Name: "timeout=0 should disable the timeout, two routers with one disabled timeout should timeout on the other router", + routers: []*ParallelRouter{ + { + Timeout: 0, + IgnoreError: false, + Router: &Compose{ + ValueStore: newDummyValueStore(t, nil, nil), + }, + }, + { + Timeout: time.Second, + IgnoreError: false, + Router: &Compose{ + ValueStore: newDummyValueStore(t, []string{"a"}, []string{"av"}), + }, + }, + }, + GetValue: []getValueFixture{ + {key: "/wait/100ms/a", value: "av", searchValCount: 1}, + }, + }, } for _, f := range fixtures { diff --git a/compsequential.go b/compsequential.go index bb761d5..78939f3 100644 --- a/compsequential.go +++ b/compsequential.go @@ -156,8 +156,9 @@ func getValueOrErrorSequential[T any]( return value, ctxErr } - ctx, cancel := context.WithTimeout(ctx, router.Timeout) + ctx, cancel := withCancelAndOptionalTimeout(ctx, router.Timeout) defer cancel() + value, empty, err := f(ctx, router.Router) if err != nil && !errors.Is(err, routing.ErrNotFound) && @@ -184,14 +185,15 @@ func executeSequential( if ctxErr := ctx.Err(); ctxErr != nil { return ctxErr } - ctx, cancel := context.WithTimeout(ctx, router.Timeout) + + ctx, cancel := withCancelAndOptionalTimeout(ctx, router.Timeout) + defer cancel() + if err := f(ctx, router.Router); err != nil && !errors.Is(err, routing.ErrNotFound) && !router.IgnoreError { - cancel() return err } - cancel() } return nil @@ -211,13 +213,12 @@ func getChannelOrErrorSequential[T any]( close(chanOut) return } - - ctx, cancel := context.WithTimeout(ctx, router.Timeout) + ctx, cancel := withCancelAndOptionalTimeout(ctx, router.Timeout) + defer cancel() rch, err := f(ctx, router.Router) if err != nil && !errors.Is(err, routing.ErrNotFound) && !router.IgnoreError { - cancel() break } @@ -238,8 +239,6 @@ func getChannelOrErrorSequential[T any]( } } - - cancel() } close(chanOut) diff --git a/compsequential_test.go b/compsequential_test.go index ea75c81..81df0b2 100644 --- a/compsequential_test.go +++ b/compsequential_test.go @@ -41,27 +41,31 @@ func TestNoResultsSequential(t *testing.T) { } func TestComposableSequentialFixtures(t *testing.T) { + type getValueFixture struct { + err error + key string + value string + searchValCount int + } + type putValueFixture struct { + err error + key string + value string + } + type provideFixture struct { + err error + } + type findPeerFixture struct { + peerID string + err error + } fixtures := []struct { Name string routers []*SequentialRouter - GetValueFixtures []struct { - err error - key string - value string - searchValCount int - } - PutValueFixtures []struct { - err error - key string - value string - } - ProvideFixtures []struct { - err error - } - FindPeerFixtures []struct { - peerID string - err error - } + GetValueFixtures []getValueFixture + PutValueFixtures []putValueFixture + ProvideFixtures []provideFixture + FindPeerFixtures []findPeerFixture }{ { Name: "simple two routers", @@ -85,12 +89,7 @@ func TestComposableSequentialFixtures(t *testing.T) { }, }, }, - GetValueFixtures: []struct { - err error - key string - value string - searchValCount int - }{ + GetValueFixtures: []getValueFixture{ { key: "d", value: "dv", @@ -102,11 +101,7 @@ func TestComposableSequentialFixtures(t *testing.T) { searchValCount: 2, }, }, - PutValueFixtures: []struct { - err error - key string - value string - }{ + PutValueFixtures: []putValueFixture{ { err: errors.New("a"), key: "/error/a", @@ -117,17 +112,12 @@ func TestComposableSequentialFixtures(t *testing.T) { value: "a", }, }, - ProvideFixtures: []struct { - err error - }{ + ProvideFixtures: []provideFixture{ { err: routing.ErrNotSupported, }, }, - FindPeerFixtures: []struct { - peerID string - err error - }{ + FindPeerFixtures: []findPeerFixture{ { peerID: "pid1", }, @@ -158,12 +148,7 @@ func TestComposableSequentialFixtures(t *testing.T) { }, }, }, - GetValueFixtures: []struct { - err error - key string - value string - searchValCount int - }{ + GetValueFixtures: []getValueFixture{ { key: "d", value: "dv", @@ -174,11 +159,7 @@ func TestComposableSequentialFixtures(t *testing.T) { key: "a", }, }, - PutValueFixtures: []struct { - err error - key string - value string - }{ + PutValueFixtures: []putValueFixture{ { key: "/error/x", value: "xv", @@ -188,10 +169,7 @@ func TestComposableSequentialFixtures(t *testing.T) { value: "yv", }, }, - FindPeerFixtures: []struct { - peerID string - err error - }{ + FindPeerFixtures: []findPeerFixture{ { peerID: "pid1", }, @@ -223,12 +201,7 @@ func TestComposableSequentialFixtures(t *testing.T) { }, }, }, - GetValueFixtures: []struct { - err error - key string - value string - searchValCount int - }{ + GetValueFixtures: []getValueFixture{ { key: "d", value: "dv", @@ -248,11 +221,7 @@ func TestComposableSequentialFixtures(t *testing.T) { key: "/error/y", }, }, - PutValueFixtures: []struct { - err error - key string - value string - }{ + PutValueFixtures: []putValueFixture{ { key: "/error/x", value: "xv", @@ -262,10 +231,7 @@ func TestComposableSequentialFixtures(t *testing.T) { value: "yv", }, }, - FindPeerFixtures: []struct { - peerID string - err error - }{ + FindPeerFixtures: []findPeerFixture{ { peerID: "pid1", }, @@ -297,12 +263,7 @@ func TestComposableSequentialFixtures(t *testing.T) { }, }, }, - GetValueFixtures: []struct { - err error - key string - value string - searchValCount int - }{ + GetValueFixtures: []getValueFixture{ { err: errFailValue, key: "d", @@ -337,12 +298,7 @@ func TestComposableSequentialFixtures(t *testing.T) { }, }, }, - GetValueFixtures: []struct { - err error - key string - value string - searchValCount int - }{ + GetValueFixtures: []getValueFixture{ { key: "d", value: "dv", @@ -355,6 +311,26 @@ func TestComposableSequentialFixtures(t *testing.T) { }, }, }, + { + Name: "timeout=0 should disable the timeout, two routers with one disabled timeout should timeout on the other router", + routers: []*SequentialRouter{ + { + Timeout: 0, + IgnoreError: false, + Router: &Compose{ + ValueStore: newDummyValueStore(t, nil, nil), + }, + }, + { + Timeout: time.Minute, + IgnoreError: false, + Router: &Compose{ + ValueStore: newDummyValueStore(t, []string{"a"}, []string{"av"}), + }, + }, + }, + GetValueFixtures: []getValueFixture{{key: "/wait/100ms/a", value: "av", searchValCount: 1}}, + }, } for _, f := range fixtures { diff --git a/dummy_test.go b/dummy_test.go index dd21179..d7b4e2c 100644 --- a/dummy_test.go +++ b/dummy_test.go @@ -3,8 +3,10 @@ package routinghelpers import ( "context" "errors" + "fmt" "strings" "sync" + "time" "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p/core/peer" @@ -60,6 +62,26 @@ func (d *dummyValueStore) GetValue(ctx context.Context, key string, opts ...rout <-ctx.Done() return nil, ctx.Err() } + // format: /wait/10s/key + // this will wait for the given duration and then perform the lookup normally on key, + // short circuiting if the context closes + if strings.HasPrefix(key, "/wait/") { + durationAndKey := strings.TrimPrefix(key, "/wait/") + split := strings.Split(durationAndKey, "/") + durationStr, key := split[0], split[1] + duration, err := time.ParseDuration(durationStr) + if err != nil { + return nil, fmt.Errorf("parsing wait duration: %w", err) + } + timer := time.NewTimer(duration) + defer timer.Stop() + select { + case <-timer.C: + return d.GetValue(ctx, key, opts...) + case <-ctx.Done(): + return nil, ctx.Err() + } + } if v, ok := (*sync.Map)(d).Load(key); ok { return v.([]byte), nil }