Skip to content

Commit

Permalink
provider: prioritize roots and introduce NewPrioritizedProvider (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias authored Apr 9, 2024
1 parent 5fc25a4 commit ec20793
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ The following emojis are used to highlight certain changes:

* `routing/http/server` now adds a `Cache-Control` HTTP header to responses to GET requests: 15 seconds for empty responses, or 5 minutes for responses with providers.
* `routing/http/server`: the `/ipns` endpoint is now friendlier to users opening the URL in a web browser: it returns a `Content-Disposition` header and defaults to an `application/vnd.ipfs.ipns-record` response when `Accept` is missing.
* `provider`:
  * Exports a `NewPrioritizedProvider`, which can be used to prioritize certain providers while ignoring duplicates.
  * 🛠️ `NewPinnedProvider` now prioritizes root blocks, even if `onlyRoots` is set to `false`.

### Changed

Expand Down
109 changes: 88 additions & 21 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,45 +71,112 @@ func NewPinnedProvider(onlyRoots bool, pinning pin.Pinner, fetchConfig fetcher.F
}

// pinSet returns a cidutil.StreamingSet that a background goroutine fills
// with the keys reported by pinning, closing set.New when it finishes.
// The set is returned immediately with a nil error; failures while listing
// or fetching pins are only logged through logR and end the goroutine early.
//
// NOTE(review): this function is shown as a rendered diff hunk, so replaced
// and replacement lines appear side by side below. Confirm against the
// committed file before relying on the exact statement order.
func pinSet(ctx context.Context, pinning pin.Pinner, fetchConfig fetcher.Factory, onlyRoots bool) (*cidutil.StreamingSet, error) {
// FIXME: Listing all pins code is duplicated thrice, twice in Kubo and here, maybe more.
// If this were a method of the [pin.Pinner] life would be easier.
set := cidutil.NewStreamingSet()
recursivePins := cidutil.NewSet()

go func() {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer close(set.New)

for sc := range pinning.DirectKeys(ctx, false) {
// 1. Recursive keys
for sc := range pinning.RecursiveKeys(ctx, false) {
if sc.Err != nil {
logR.Errorf("reprovide direct pins: %s", sc.Err)
logR.Errorf("reprovide recursive pins: %s", sc.Err)
return
}
set.Visitor(ctx)(sc.Pin.Key)
if !onlyRoots {
// Save some bytes.
_ = recursivePins.Visit(sc.Pin.Key)
}
_ = set.Visitor(ctx)(sc.Pin.Key)
}

session := fetchConfig.NewSession(ctx)
for sc := range pinning.RecursiveKeys(ctx, false) {
// 2. Direct pins
for sc := range pinning.DirectKeys(ctx, false) {
if sc.Err != nil {
logR.Errorf("reprovide recursive pins: %s", sc.Err)
logR.Errorf("reprovide direct pins: %s", sc.Err)
return
}
set.Visitor(ctx)(sc.Pin.Key)
if !onlyRoots {
err := fetcherhelpers.BlockAll(ctx, session, cidlink.Link{Cid: sc.Pin.Key}, func(res fetcher.FetchResult) error {
clink, ok := res.LastBlockLink.(cidlink.Link)
if ok {
set.Visitor(ctx)(clink.Cid)
}
return nil
})
if err != nil {
logR.Errorf("reprovide indirect pins: %s", err)
return
_ = set.Visitor(ctx)(sc.Pin.Key)
}

if onlyRoots {
return
}

// 3. Go through recursive pins to fetch remaining blocks if we want more
// than just roots.
session := fetchConfig.NewSession(ctx)
err := recursivePins.ForEach(func(c cid.Cid) error {
return fetcherhelpers.BlockAll(ctx, session, cidlink.Link{Cid: c}, func(res fetcher.FetchResult) error {
clink, ok := res.LastBlockLink.(cidlink.Link)
if ok {
_ = set.Visitor(ctx)(clink.Cid)
}
}
return nil
})
})
if err != nil {
logR.Errorf("reprovide indirect pins: %s", err)
return
}
}()

return set, nil
}

// NewPrioritizedProvider returns a KeyChanFunc that first emits every CID
// produced by priorityCids and then the CIDs produced by otherCids.
//
// CIDs that were already emitted from priorityCids are skipped when they show
// up again, in either stream. CIDs coming from otherCids are not recorded, so
// duplicates that occur only within otherCids are passed through unchanged.
//
// The returned channel is closed once both streams are exhausted, when ctx is
// cancelled, or when a stream fails to start. Such a failure is logged as a
// warning rather than returned; if priorityCids fails, otherCids is never
// started.
func NewPrioritizedProvider(priorityCids KeyChanFunc, otherCids KeyChanFunc) KeyChanFunc {
	return func(ctx context.Context) (<-chan cid.Cid, error) {
		outCh := make(chan cid.Cid)

		go func() {
			defer close(outCh)
			visited := cidutil.NewSet()

			// handleStream forwards the CIDs of one stream to outCh, dropping
			// those already in visited. It returns a non-nil error only when
			// the stream itself cannot be started.
			handleStream := func(stream KeyChanFunc, markVisited bool) error {
				ch, err := stream(ctx)
				if err != nil {
					return err
				}

				for {
					select {
					case <-ctx.Done():
						return nil
					case c, ok := <-ch:
						if !ok {
							return nil
						}

						if visited.Has(c) {
							continue
						}

						select {
						case <-ctx.Done():
							return nil
						case outCh <- c:
							// Record the CID only after it has actually been
							// delivered to the consumer.
							if markVisited {
								_ = visited.Visit(c)
							}
						}
					}
				}
			}

			// %w is only understood by fmt.Errorf; loggers format with
			// Sprintf-style verbs, so the error is rendered with %s.
			if err := handleStream(priorityCids, true); err != nil {
				log.Warnf("error in prioritized strategy while handling priority CIDs: %s", err)
				return
			}

			// Do not start the second stream if the consumer already gave up.
			if ctx.Err() != nil {
				return
			}

			if err := handleStream(otherCids, false); err != nil {
				log.Warnf("error in prioritized strategy while handling other CIDs: %s", err)
			}
		}()

		return outCh, nil
	}
}
92 changes: 92 additions & 0 deletions provider/reprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package provider
import (
"bytes"
"context"
"crypto/rand"
"runtime"
"strconv"
"sync"
Expand All @@ -15,6 +16,7 @@ import (
dssync "github.com/ipfs/go-datastore/sync"
mh "github.com/multiformats/go-multihash"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type allFeatures interface {
Expand Down Expand Up @@ -221,3 +223,93 @@ func TestOfflineRecordsThenOnlineRepublish(t *testing.T) {
t.Fatalf("keys are not equal expected %v, got %v", someHash, prov.keys[0])
}
}

// newMockKeyChanFunc builds a KeyChanFunc that replays cids, in order, on a
// fresh unbuffered channel and closes that channel afterwards. The replay
// stops early when the context handed to the returned function is cancelled.
func newMockKeyChanFunc(cids []cid.Cid) KeyChanFunc {
	return func(ctx context.Context) (<-chan cid.Cid, error) {
		stream := make(chan cid.Cid)

		go func() {
			defer close(stream)
			for i := 0; i < len(cids); i++ {
				select {
				case stream <- cids[i]:
				case <-ctx.Done():
					return
				}
			}
		}()

		return stream, nil
	}
}

// makeCIDs returns n CIDv1 values, each wrapping a multihash encoded from 63
// freshly generated random bytes. It panics if random data cannot be read or
// the multihash cannot be encoded.
func makeCIDs(n int) []cid.Cid {
	out := make([]cid.Cid, n)
	for i := range out {
		digest := make([]byte, 63)
		if _, err := rand.Read(digest); err != nil {
			panic(err)
		}

		encoded, err := mh.Encode(digest, mh.SHA2_256)
		if err != nil {
			panic(err)
		}

		out[i] = cid.NewCidV1(0, encoded)
	}

	return out
}

// TestNewPrioritizedProvider checks that the prioritized provider emits the
// priority CIDs first, then the remaining ones, without repeating any CID
// that was already emitted from the priority stream.
func TestNewPrioritizedProvider(t *testing.T) {
	cids := makeCIDs(6)
	firstHalf, secondHalf := cids[:3], cids[3:]

	type scenario struct {
		name     string
		priority []cid.Cid
		all      []cid.Cid
		expected []cid.Cid
	}

	scenarios := []scenario{
		{
			name:     "basic test",
			priority: firstHalf,
			all:      secondHalf,
			expected: cids,
		},
		{
			name:     "basic test inverted",
			priority: secondHalf,
			all:      firstHalf,
			expected: append(secondHalf, firstHalf...),
		},
		{
			name:     "no repeated",
			priority: secondHalf,
			all:      secondHalf,
			expected: secondHalf,
		},
		{
			name:     "no repeated if duplicates in the prioritized channel",
			priority: []cid.Cid{cids[0], cids[1], cids[0]},
			all:      []cid.Cid{cids[2], cids[4], cids[5]},
			expected: []cid.Cid{cids[0], cids[1], cids[2], cids[4], cids[5]},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			provide := NewPrioritizedProvider(
				newMockKeyChanFunc(sc.priority),
				newMockKeyChanFunc(sc.all),
			)
			out, err := provide(ctx)
			require.NoError(t, err)

			// Drain the channel until the provider closes it.
			got := make([]cid.Cid, 0)
			for {
				c, open := <-out
				if !open {
					break
				}
				got = append(got, c)
			}
			require.Equal(t, sc.expected, got)
		})
	}
}

0 comments on commit ec20793

Please sign in to comment.