diff --git a/p2p/net/swarm/addrs.go b/p2p/net/swarm/addrs.go index ed510f2626..17c65ae7b2 100644 --- a/p2p/net/swarm/addrs.go +++ b/p2p/net/swarm/addrs.go @@ -1,12 +1,12 @@ package swarm import ( - mafilter "github.com/libp2p/go-maddr-filter" + ma "github.com/multiformats/go-multiaddr" mamask "github.com/whyrusleeping/multiaddr-filter" ) // http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml -var lowTimeoutFilters = mafilter.NewFilters() +var lowTimeoutFilters = ma.NewFilters() func init() { for _, p := range []string{ diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 198d88a72b..f5c0209c80 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/libp2p/go-libp2p-core/connmgr" "github.com/libp2p/go-libp2p-core/metrics" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -19,9 +20,7 @@ import ( "github.com/jbenet/goprocess" goprocessctx "github.com/jbenet/goprocess/context" - filter "github.com/libp2p/go-maddr-filter" ma "github.com/multiformats/go-multiaddr" - mafilter "github.com/whyrusleeping/multiaddr-filter" ) // DialTimeoutLocal is the maximum duration a Dial to local network address @@ -87,22 +86,24 @@ type Swarm struct { dsync *DialSync backf DialBackoff limiter *dialLimiter - - // filters for addresses that shouldnt be dialed (or accepted) - Filters *filter.Filters + gater connmgr.ConnectionGater proc goprocess.Process ctx context.Context bwc metrics.Reporter } -// NewSwarm constructs a Swarm -func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter) *Swarm { +// NewSwarm constructs a Swarm. +// +// NOTE: go-libp2p will be moving to dependency injection soon. The variadic +// `extra` interface{} parameter facilitates the future migration. Supported +// elements are: +// - connmgr.ConnectionGater +func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm { s := &Swarm{ - local: local, - peers: peers, - bwc: bwc, - Filters: filter.NewFilters(), + local: local, + peers: peers, + bwc: bwc, } s.conns.m = make(map[peer.ID][]*Conn) @@ -110,6 +111,13 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc s.transports.m = make(map[int]transport.Transport) s.notifs.m = make(map[network.Notifiee]struct{}) + for _, i := range extra { + switch v := i.(type) { + case connmgr.ConnectionGater: + s.gater = v + } + } + s.dsync = NewDialSync(s.doDial) s.limiter = newDialLimiter(s.dialAddr) s.proc = goprocessctx.WithContext(ctx) @@ -168,33 +176,46 @@ func (s *Swarm) teardown() error { return nil } -// AddAddrFilter adds a multiaddr filter to the set of filters the swarm will use to determine which -// addresses not to dial to. -func (s *Swarm) AddAddrFilter(f string) error { - m, err := mafilter.NewMask(f) - if err != nil { - return err - } - - s.Filters.AddDialFilter(m) - return nil -} - // Process returns the Process of the swarm func (s *Swarm) Process() goprocess.Process { return s.proc } func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { - // The underlying transport (or the dialer) *should* filter it's own - // connections but we should double check anyways. - raddr := tc.RemoteMultiaddr() - if s.Filters.AddrBlocked(raddr) { - tc.Close() - return nil, ErrAddrFiltered + var ( + p = tc.RemotePeer() + addr = tc.RemoteMultiaddr() + ) + + if s.gater != nil { + if allow := s.gater.InterceptAddrDial(p, addr); !allow { + err := tc.Close() + if err != nil { + log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p.Pretty(), addr, err) + } + return nil, ErrAddrFiltered + } } - p := tc.RemotePeer() + stat := network.Stat{Direction: dir} + c := &Conn{ + conn: tc, + swarm: s, + stat: stat, + } + + // we ONLY check upgraded connections here so we can send them a Disconnect message. + // If we do this in the Upgrader, we will not be able to do this. + if s.gater != nil { + if allow, _ := s.gater.InterceptUpgraded(c); !allow { + // TODO Send disconnect with reason here + err := tc.Close() + if err != nil { + log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p.Pretty(), addr, err) + } + return nil, ErrGaterDisallowedConnection + } + } // Add the public key. if pk := tc.RemotePublicKey(); pk != nil { @@ -214,12 +235,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, } // Wrap and register the connection. - stat := network.Stat{Direction: dir} - c := &Conn{ - conn: tc, - swarm: s, - stat: stat, - } c.streams.m = make(map[*Stream]struct{}) s.conns.m[p] = append(s.conns.m[p], c) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index cae4110c01..f35f2b6372 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -50,6 +50,10 @@ var ( // ErrNoGoodAddresses is returned when we find addresses for a peer but // can't use any of them. ErrNoGoodAddresses = errors.New("no good addresses") + + // ErrGaterDisallowedConnection is returned when the gater prevents us from + // forming a connection with a peer. + ErrGaterDisallowedConnection = errors.New("gater disallows connection to peer") ) // DialAttempts governs how many times a goroutine will try to dial a given peer. @@ -218,6 +222,11 @@ func (db *DialBackoff) cleanup() { // This allows us to use various transport protocols, do NAT traversal/relay, // etc. to achieve connection. func (s *Swarm) DialPeer(ctx context.Context, p peer.ID) (network.Conn, error) { + if s.gater != nil && !s.gater.InterceptPeerDial(p) { + log.Debugf("gater disallowed outbound connection to peer %s", p.Pretty()) + return nil, &DialError{Peer: p, Cause: ErrGaterDisallowedConnection} + } + return s.dialPeer(ctx, p) } @@ -339,7 +348,7 @@ func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) { if len(peerAddrs) == 0 { return nil, &DialError{Peer: p, Cause: ErrNoAddresses} } - goodAddrs := s.filterKnownUndialables(peerAddrs) + goodAddrs := s.filterKnownUndialables(p, peerAddrs) if len(goodAddrs) == 0 { return nil, &DialError{Peer: p, Cause: ErrNoGoodAddresses} } @@ -393,7 +402,7 @@ func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) { // IPv6 link-local addresses, addresses without a dial-capable transport, // and addresses that we know to be our own. // This is an optimization to avoid wasting time on dials that we know are going to fail. -func (s *Swarm) filterKnownUndialables(addrs []ma.Multiaddr) []ma.Multiaddr { +func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr { lisAddrs, _ := s.InterfaceListenAddresses() var ourAddrs []ma.Multiaddr for _, addr := range lisAddrs { @@ -409,7 +418,9 @@ func (s *Swarm) filterKnownUndialables(addrs []ma.Multiaddr) []ma.Multiaddr { s.canDial, // TODO: Consider allowing link-local addresses addrutil.AddrOverNonLocalIP, - addrutil.FilterNeg(s.Filters.AddrBlocked), + func(addr ma.Multiaddr) bool { + return s.gater == nil || s.gater.InterceptAddrDial(p, addr) + }, ) } diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go index 09d411dfd8..ab5c42a345 100644 --- a/p2p/net/swarm/swarm_listen.go +++ b/p2p/net/swarm/swarm_listen.go @@ -89,6 +89,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { } return } + log.Debugf("swarm listener accepted connection: %s", c) s.refs.Add(1) go func() { diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index b155373f67..4750b6a64d 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -5,20 +5,21 @@ import ( "context" "fmt" "io" - "net" "sync" "testing" "time" - logging "github.com/ipfs/go-log" + "github.com/libp2p/go-libp2p-core/control" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - ma "github.com/multiformats/go-multiaddr" - . "github.com/libp2p/go-libp2p-swarm" . "github.com/libp2p/go-libp2p-swarm/testing" + + logging "github.com/ipfs/go-log" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" ) var log = logging.Logger("swarm_test") @@ -280,60 +281,105 @@ func TestConnHandler(t *testing.T) { } } -func TestAddrBlocking(t *testing.T) { +func TestConnectionGating(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) - - swarms[0].SetConnHandler(func(conn network.Conn) { - t.Errorf("no connections should happen! -- %s", conn) - }) - - _, block, err := net.ParseCIDR("127.0.0.1/8") - if err != nil { - t.Fatal(err) - } - - swarms[1].Filters.AddDialFilter(block) - - swarms[1].Peerstore().AddAddr(swarms[0].LocalPeer(), swarms[0].ListenAddresses()[0], peerstore.PermanentAddrTTL) - _, err = swarms[1].DialPeer(ctx, swarms[0].LocalPeer()) - if err == nil { - t.Fatal("dial should have failed") - } - - swarms[0].Peerstore().AddAddr(swarms[1].LocalPeer(), swarms[1].ListenAddresses()[0], peerstore.PermanentAddrTTL) - _, err = swarms[0].DialPeer(ctx, swarms[1].LocalPeer()) - if err == nil { - t.Fatal("dial should have failed") + tcs := map[string]struct { + p1Gater func(gater *MockConnectionGater) *MockConnectionGater + p2Gater func(gater *MockConnectionGater) *MockConnectionGater + + p1ConnectednessToP2 network.Connectedness + p2ConnectednessToP1 network.Connectedness + isP1OutboundErr bool + }{ + "no gating": { + p1ConnectednessToP2: network.Connected, + p2ConnectednessToP1: network.Connected, + isP1OutboundErr: false, + }, + "p1 gates outbound peer dial": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.PeerDial = func(p peer.ID) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p1 gates outbound addr dialing": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Dial = func(p peer.ID, addr ma.Multiaddr) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 gates inbound peer dial before securing": { + p2Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Accept = func(c network.ConnMultiaddrs) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 gates inbound peer dial before multiplexing": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Secured = func(network.Direction, peer.ID, network.ConnMultiaddrs) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 gates inbound peer dial after upgrading": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Upgraded = func(c network.Conn) (bool, control.DisconnectReason) { return false, 0 } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 gates outbound dials": { + p2Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.PeerDial = func(p peer.ID) bool { return false } + return c + }, + p1ConnectednessToP2: network.Connected, + p2ConnectednessToP1: network.Connected, + isP1OutboundErr: false, + }, } -} -func TestFilterBounds(t *testing.T) { - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + for n, tc := range tcs { + t.Run(n, func(t *testing.T) { + p1Gater := DefaultMockConnectionGater() + p2Gater := DefaultMockConnectionGater() + if tc.p1Gater != nil { + p1Gater = tc.p1Gater(p1Gater) + } + if tc.p2Gater != nil { + p2Gater = tc.p2Gater(p2Gater) + } - conns := make(chan struct{}, 8) - swarms[0].SetConnHandler(func(conn network.Conn) { - conns <- struct{}{} - }) + sw1 := GenSwarm(t, ctx, OptConnGater(p1Gater)) + sw2 := GenSwarm(t, ctx, OptConnGater(p2Gater)) - // Address that we wont be dialing from - _, block, err := net.ParseCIDR("192.0.0.1/8") - if err != nil { - t.Fatal(err) - } + p1 := sw1.LocalPeer() + p2 := sw2.LocalPeer() + sw1.Peerstore().AddAddr(p2, sw2.ListenAddresses()[0], peerstore.PermanentAddrTTL) + // 1 -> 2 + _, err := sw1.DialPeer(ctx, p2) - // set filter on both sides, shouldnt matter - swarms[1].Filters.AddDialFilter(block) - swarms[0].Filters.AddDialFilter(block) + require.Equal(t, tc.isP1OutboundErr, err != nil, n) + require.Equal(t, tc.p1ConnectednessToP2, sw1.Connectedness(p2), n) - connectSwarms(t, ctx, swarms) + require.Eventually(t, func() bool { + return tc.p2ConnectednessToP1 == sw2.Connectedness(p1) + }, 2*time.Second, 100*time.Millisecond, n) + }) - select { - case <-time.After(time.Second): - t.Fatal("should have gotten connection") - case <-conns: - t.Log("got connect") } } diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 10de2ace09..0d252c3a1c 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -4,26 +4,31 @@ import ( "context" "testing" + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/control" "github.com/libp2p/go-libp2p-core/metrics" "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-testing/net" - "github.com/libp2p/go-tcp-transport" - goprocess "github.com/jbenet/goprocess" csms "github.com/libp2p/go-conn-security-multistream" pstoremem "github.com/libp2p/go-libp2p-peerstore/pstoremem" secio "github.com/libp2p/go-libp2p-secio" + swarm "github.com/libp2p/go-libp2p-swarm" + "github.com/libp2p/go-libp2p-testing/net" tptu "github.com/libp2p/go-libp2p-transport-upgrader" yamux "github.com/libp2p/go-libp2p-yamux" msmux "github.com/libp2p/go-stream-muxer-multistream" + "github.com/libp2p/go-tcp-transport" - swarm "github.com/libp2p/go-libp2p-swarm" + goprocess "github.com/jbenet/goprocess" + ma "github.com/multiformats/go-multiaddr" ) type config struct { disableReuseport bool dialOnly bool + connectionGater connmgr.ConnectionGater } // Option is an option that can be passed when constructing a test swarm. @@ -39,6 +44,13 @@ var OptDialOnly Option = func(_ *testing.T, c *config) { c.dialOnly = true } +// OptConnGater configures the given connection gater on the test +func OptConnGater(cg connmgr.ConnectionGater) Option { + return func(_ *testing.T, c *config) { + c.connectionGater = cg + } +} + // GenUpgrader creates a new connection upgrader for use with this swarm. func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { id := n.LocalPeer() @@ -53,9 +65,8 @@ func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) return &tptu.Upgrader{ - Secure: secMuxer, - Muxer: stMuxer, - Filters: n.Filters, + Secure: secMuxer, + Muxer: stMuxer, } } @@ -72,12 +83,16 @@ func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm { ps := pstoremem.NewPeerstore() ps.AddPubKey(p.ID, p.PubKey) ps.AddPrivKey(p.ID, p.PrivKey) - s := swarm.NewSwarm(ctx, p.ID, ps, metrics.NewBandwidthCounter()) + + s := swarm.NewSwarm(ctx, p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater) + // Call AddChildNoWait because we can't call AddChild after the process // may have been closed (e.g., if the context was canceled). s.Process().AddChildNoWait(goprocess.WithTeardown(ps.Close)) - tcpTransport := tcp.NewTCPTransport(GenUpgrader(s)) + upgrader := GenUpgrader(s) + upgrader.ConnGater = cfg.connectionGater + tcpTransport := tcp.NewTCPTransport(upgrader) tcpTransport.DisableReuseport = cfg.disableReuseport if err := s.AddTransport(tcpTransport); err != nil { @@ -101,3 +116,57 @@ func DivulgeAddresses(a, b network.Network) { addrs := a.Peerstore().Addrs(id) b.Peerstore().AddAddrs(id, addrs, peerstore.PermanentAddrTTL) } + +// MockConnectionGater is a mock connection gater to be used by the tests. +type MockConnectionGater struct { + Dial func(p peer.ID, addr ma.Multiaddr) bool + PeerDial func(p peer.ID) bool + Accept func(c network.ConnMultiaddrs) bool + Secured func(network.Direction, peer.ID, network.ConnMultiaddrs) bool + Upgraded func(c network.Conn) (bool, control.DisconnectReason) +} + +func DefaultMockConnectionGater() *MockConnectionGater { + m := &MockConnectionGater{} + m.Dial = func(p peer.ID, addr ma.Multiaddr) bool { + return true + } + + m.PeerDial = func(p peer.ID) bool { + return true + } + + m.Accept = func(c network.ConnMultiaddrs) bool { + return true + } + + m.Secured = func(network.Direction, peer.ID, network.ConnMultiaddrs) bool { + return true + } + + m.Upgraded = func(c network.Conn) (bool, control.DisconnectReason) { + return true, 0 + } + + return m +} + +func (m *MockConnectionGater) InterceptAddrDial(p peer.ID, addr ma.Multiaddr) (allow bool) { + return m.Dial(p, addr) +} + +func (m *MockConnectionGater) InterceptPeerDial(p peer.ID) (allow bool) { + return m.PeerDial(p) +} + +func (m *MockConnectionGater) InterceptAccept(c network.ConnMultiaddrs) (allow bool) { + return m.Accept(c) +} + +func (m *MockConnectionGater) InterceptSecured(d network.Direction, p peer.ID, c network.ConnMultiaddrs) (allow bool) { + return m.Secured(d, p, c) +} + +func (m *MockConnectionGater) InterceptUpgraded(tc network.Conn) (allow bool, reason control.DisconnectReason) { + return m.Upgraded(tc) +}