Skip to content

Commit

Permalink
feat(dot/network): Add warp sync spam limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
dimartiro committed Oct 2, 2024
1 parent 6e7a351 commit 3cf6f7f
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 0 deletions.
8 changes: 8 additions & 0 deletions dot/network/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ type Config struct {

Telemetry Telemetry
Metrics metrics.IntervalConfig

// Spam limiters configuration
warpSyncSpamLimiter RateLimiter
}

// build checks the configuration, sets up the private key for the network service,
Expand Down Expand Up @@ -154,6 +157,11 @@ func (c *Config) build() error {
c.telemetryInterval = time.Second * 5
}

// set warp sync spam limiter to default
if c.warpSyncSpamLimiter == nil {
c.warpSyncSpamLimiter = NewSpamLimiter(MaxAllowedRequestsPerPeer, MaxTimeWindow)
}

return nil
}

Expand Down
8 changes: 8 additions & 0 deletions dot/network/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ package network
import (
"encoding/json"
"io"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/libp2p/go-libp2p/core/peer"
)

// Telemetry is the telemetry client to send telemetry messages.
Expand All @@ -27,3 +30,8 @@ type MDNS interface {
Start() error
io.Closer
}

type RateLimiter interface {
AddRequest(peer peer.ID, hashedRequest common.Hash)
IsLimitExceeded(peer peer.ID, hashedRequest common.Hash) bool
}
4 changes: 4 additions & 0 deletions dot/network/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ type Service struct {
closeCh chan struct{}

telemetry Telemetry

// Spam control
warpSyncSpamLimiter RateLimiter
}

// NewService creates a new network service from the configuration and message channels
Expand Down Expand Up @@ -226,6 +229,7 @@ func NewService(cfg *Config) (*Service, error) {
streamManager: newStreamManager(ctx),
telemetry: cfg.Telemetry,
Metrics: cfg.Metrics,
warpSyncSpamLimiter: cfg.warpSyncSpamLimiter,
}

return network, nil
Expand Down
88 changes: 88 additions & 0 deletions dot/network/spam_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package network

import (
"sync"
"time"

"github.com/ChainSafe/gossamer/lib/common"
lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache"
"github.com/libp2p/go-libp2p/core/peer"
)

const MaxTimeWindow = 10 * time.Second
const MaxCachedPeers = 100
const MaxCachedRequests = 100

type SpamLimiter struct {
mu sync.Mutex
limits *lrucache.LRUCache[peer.ID, *lrucache.LRUCache[common.Hash, []time.Time]]
maxReqs uint32
windowSize time.Duration
}

// NewSpamLimiter creates a new SpamLimiter with the given maximum number of requests
func NewSpamLimiter(maxReqs uint32, windowSize time.Duration) *SpamLimiter {
return &SpamLimiter{
limits: lrucache.NewLRUCache[peer.ID, *lrucache.LRUCache[common.Hash, []time.Time]](MaxCachedPeers),
maxReqs: maxReqs,
windowSize: windowSize,
}
}

// AddRequest adds a request to the SpamLimiter
func (rl *SpamLimiter) AddRequest(peer peer.ID, hashedRequest common.Hash) {
rl.mu.Lock()
defer rl.mu.Unlock()

// Get or create the internal cache for the peer
peerCache := rl.limits.Get(peer)
if peerCache == nil {
peerCache = lrucache.NewLRUCache[common.Hash, []time.Time](MaxCachedRequests)
rl.limits.Put(peer, peerCache)
}

// Get the timestamps for the hash
timestamps := peerCache.Get(hashedRequest)
now := time.Now()

// Filter requests that are within the time window
var recentRequests []time.Time
for _, t := range timestamps {
if now.Sub(t) <= rl.windowSize {
recentRequests = append(recentRequests, t)
}
}

// Add the current request and update the cache
recentRequests = append(recentRequests, now)
peerCache.Put(hashedRequest, recentRequests)
}

// IsLimitExceeded returns true if the limit is exceeded for the given peer and hash
func (rl *SpamLimiter) IsLimitExceeded(peer peer.ID, hashedRequest common.Hash) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

// Get the internal cache for the peer
peerCache := rl.limits.Get(peer)
if peerCache == nil {
return false
}

// Get the timestamps for the hash
timestamps := peerCache.Get(hashedRequest)
now := time.Now()

// Filter requests that are within the time window
var recentRequests []time.Time
for _, t := range timestamps {
if now.Sub(t) <= rl.windowSize {
recentRequests = append(recentRequests, t)
}
}

// Update the cache with the recent requests
peerCache.Put(hashedRequest, recentRequests)

return uint32(len(recentRequests)) > rl.maxReqs
}
86 changes: 86 additions & 0 deletions dot/network/spam_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package network

import (
"testing"
"time"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/assert"
)

func TestSpamLimiter_AddRequestAndIsLimitExceeded(t *testing.T) {
t.Parallel()

// Create a SpamLimiter with a limit of 5 requests and a time window of 10 seconds
limiter := NewSpamLimiter(5, 10*time.Second)

peerID := peer.ID("peer1")
hash := common.Hash{0x01}

// Add 5 requests for the same peer and hash
for i := 0; i < 5; i++ {
limiter.AddRequest(peerID, hash)
}

// Limit should not be exceeded after 5 requests
assert.False(t, limiter.IsLimitExceeded(peerID, hash))

// Add one more request and check that the limit is exceeded
limiter.AddRequest(peerID, hash)
assert.True(t, limiter.IsLimitExceeded(peerID, hash))
}

func TestSpamLimiter_WindowExpiry(t *testing.T) {
t.Parallel()

// Create a SpamLimiter with a limit of 3 requests and a time window of 2 seconds
limiter := NewSpamLimiter(3, 1*time.Second)

peerID := peer.ID("peer2")
hash := common.Hash{0x02}

// Add 3 requests
for i := 0; i < 3; i++ {
limiter.AddRequest(peerID, hash)
}

// Limit should not be exceeded
assert.False(t, limiter.IsLimitExceeded(peerID, hash))

// Wait for the time window to expire
time.Sleep(2 * time.Second)

// Add another request, should be considered as the first in a new window
limiter.AddRequest(peerID, hash)
assert.False(t, limiter.IsLimitExceeded(peerID, hash))
}

func TestSpamLimiter_DifferentPeersAndHashes(t *testing.T) {
// Create a SpamLimiter with a limit of 2 requests and a time window of 5 seconds
limiter := NewSpamLimiter(2, 5*time.Second)

peerID1 := peer.ID("peer1")
peerID2 := peer.ID("peer2")
hash1 := common.Hash{0x01}
hash2 := common.Hash{0x02}

// Add requests for peerID1 and hash1
limiter.AddRequest(peerID1, hash1)
limiter.AddRequest(peerID1, hash1)

// Add requests for peerID2 and hash2
limiter.AddRequest(peerID2, hash2)
limiter.AddRequest(peerID2, hash2)

// No limit should be exceeded yet
assert.False(t, limiter.IsLimitExceeded(peerID1, hash1))
assert.False(t, limiter.IsLimitExceeded(peerID2, hash2))

// Add another request for each and check that the limit is exceeded
limiter.AddRequest(peerID1, hash1)
assert.True(t, limiter.IsLimitExceeded(peerID1, hash1))

limiter.AddRequest(peerID2, hash2)
assert.True(t, limiter.IsLimitExceeded(peerID2, hash2))
}
15 changes: 15 additions & 0 deletions dot/network/warp_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
)

const MaxAllowedRequestsPerPeer = 10

// WarpSyncProvider is an interface for generating warp sync proofs
type WarpSyncProvider interface {
// Generate proof starting at given block hash. The proof is accumulated until maximum proof
Expand Down Expand Up @@ -55,7 +57,20 @@ func (s *Service) handleWarpSyncMessage(stream libp2pnetwork.Stream, msg message
}
}()

peerId := stream.Conn().RemotePeer()
hashedReq := common.MustBlake2bHash([]byte(msg.String()))

if req, ok := msg.(*messages.WarpProofRequest); ok {
// Check if this peer has exceeded the limit of requests
if !s.warpSyncSpamLimiter.IsLimitExceeded(peerId, hashedReq) {
logger.Debugf("same warp sync request exceeded for peer: %s", peerId)
return nil
}

// Add the request to the spam limiter
s.warpSyncSpamLimiter.AddRequest(peerId, hashedReq)

// Handle request
resp, err := s.handleWarpSyncRequest(*req)
if err != nil {
logger.Debugf("cannot create response for request: %s", err)
Expand Down

0 comments on commit 3cf6f7f

Please sign in to comment.