From ded75c34d471cc570fb204a24b405cd400392156 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Apr 2023 17:57:58 +0900 Subject: [PATCH 1/8] nat: remove unused NAT method from the Mapping interface --- p2p/net/nat/mapping.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/p2p/net/nat/mapping.go b/p2p/net/nat/mapping.go index f9b508e4e2..0641e8fc4e 100644 --- a/p2p/net/nat/mapping.go +++ b/p2p/net/nat/mapping.go @@ -9,9 +9,6 @@ import ( // Mapping represents a port mapping in a NAT. type Mapping interface { - // NAT returns the NAT object this Mapping belongs to. - NAT() *NAT - // Protocol returns the protocol of this port mapping. This is either // "tcp" or "udp" as no other protocols are likely to be NAT-supported. Protocol() string @@ -46,12 +43,6 @@ type mapping struct { cacheLk sync.Mutex } -func (m *mapping) NAT() *NAT { - m.Lock() - defer m.Unlock() - return m.nat -} - func (m *mapping) Protocol() string { m.Lock() defer m.Unlock() From f5cbaf172190f786836c4fb9d277cfce0790a983 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Apr 2023 18:06:17 +0900 Subject: [PATCH 2/8] nat: rename NewMapping to AddMapping, remove unused Mapping return value --- p2p/host/basic/natmgr.go | 3 +-- p2p/net/nat/nat.go | 14 ++++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index 782c116d44..4bbec10a41 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -203,8 +203,7 @@ func (nmgr *natManager) doSync() { wg.Add(1) go func(proto string, port int) { defer wg.Done() - _, err := nmgr.nat.NewMapping(proto, port) - if err != nil { + if err := nmgr.nat.AddMapping(proto, port); err != nil { log.Errorf("failed to port-map %s port %d: %s", proto, port, err) } }(proto, port) diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index e2656f8bcc..201c887f8f 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -92,23 +92,21 @@ func (nat *NAT) Mappings() []Mapping { return maps2 } -// NewMapping attempts to construct a mapping on protocol and internal port +// AddMapping attempts to construct a mapping on protocol and internal port // It will also periodically renew the mapping until the returned Mapping // -- or its parent NAT -- is Closed. // // May not succeed, and mappings may change over time; // NAT devices may not respect our port requests, and even lie. -// Clients should not store the mapped results, but rather always -// poll our object for the latest mappings. -func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) { +func (nat *NAT) AddMapping(protocol string, port int) error { if nat == nil { - return nil, fmt.Errorf("no nat available") + return fmt.Errorf("no nat available") } switch protocol { case "tcp", "udp": default: - return nil, fmt.Errorf("invalid protocol: %s", protocol) + return fmt.Errorf("invalid protocol: %s", protocol) } m := &mapping{ @@ -120,7 +118,7 @@ func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) { nat.mappingmu.Lock() if nat.closed { nat.mappingmu.Unlock() - return nil, errors.New("closed") + return errors.New("closed") } nat.mappings[m] = struct{}{} nat.refCount.Add(1) @@ -130,7 +128,7 @@ func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) { // do it once synchronously, so first mapping is done right away, and before exiting, // allowing users -- in the optimistic case -- to use results right after. nat.establishMapping(m) - return m, nil + return nil } func (nat *NAT) removeMapping(m *mapping) { From c62be4081c25fce38c0f64d2900c8366e1347ca2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Apr 2023 22:27:04 +0900 Subject: [PATCH 3/8] nat: use a single Go routine to renew NAT mappings --- p2p/host/basic/natmgr.go | 56 ++++++++-------- p2p/net/nat/mapping.go | 8 --- p2p/net/nat/nat.go | 138 +++++++++++++++++++++++---------------- 3 files changed, 107 insertions(+), 95 deletions(-) diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index 4bbec10a41..522d3e827a 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -30,6 +30,11 @@ func NewNATManager(net network.Network) NATManager { return newNatManager(net) } +type entry struct { + protocol string + port int +} + // natManager takes care of adding + removing port mappings to the nat. // Initialized with the host if it has a NATPortMap option enabled. // natManager receives signals from the network, and check on nat mappings: @@ -42,7 +47,9 @@ type natManager struct { nat *inat.NAT ready chan struct{} // closed once the nat is ready to process port mappings - syncFlag chan struct{} + syncFlag chan struct{} // cap: 1 + + tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function refCount sync.WaitGroup ctxCancel context.CancelFunc @@ -55,6 +62,7 @@ func newNatManager(net network.Network) *natManager { ready: make(chan struct{}), syncFlag: make(chan struct{}, 1), ctxCancel: cancel, + tracked: make(map[entry]bool), } nmgr.refCount.Add(1) go nmgr.background(ctx) @@ -127,10 +135,10 @@ func (nmgr *natManager) sync() { // doSync syncs the current NAT mappings, removing any outdated mappings and adding any // new mappings. func (nmgr *natManager) doSync() { - ports := map[string]map[int]bool{ - "tcp": {}, - "udp": {}, + for e := range nmgr.tracked { + nmgr.tracked[e] = false } + var newAddresses []entry for _, maddr := range nmgr.net.ListenAddresses() { // Strip the IP maIP, rest := ma.SplitFirst(maddr) @@ -166,48 +174,36 @@ func (nmgr *natManager) doSync() { default: continue } - port, err := strconv.ParseUint(proto.Value(), 10, 16) if err != nil { // bug in multiaddr panic(err) } - ports[protocol][int(port)] = false + e := entry{protocol: protocol, port: int(port)} + if _, ok := nmgr.tracked[e]; ok { + nmgr.tracked[e] = true + } else { + newAddresses = append(newAddresses, e) + } } var wg sync.WaitGroup defer wg.Wait() // Close old mappings - for _, m := range nmgr.nat.Mappings() { - mappedPort := m.InternalPort() - if _, ok := ports[m.Protocol()][mappedPort]; !ok { - // No longer need this mapping. - wg.Add(1) - go func(m inat.Mapping) { - defer wg.Done() - m.Close() - }(m) - } else { - // already mapped - ports[m.Protocol()][mappedPort] = true + for e, v := range nmgr.tracked { + if !v { + nmgr.nat.RemoveMapping(e.protocol, e.port) + delete(nmgr.tracked, e) } } // Create new mappings. - for proto, pports := range ports { - for port, mapped := range pports { - if mapped { - continue - } - wg.Add(1) - go func(proto string, port int) { - defer wg.Done() - if err := nmgr.nat.AddMapping(proto, port); err != nil { - log.Errorf("failed to port-map %s port %d: %s", proto, port, err) - } - }(proto, port) + for _, e := range newAddresses { + if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil { + log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err) } + nmgr.tracked[e] = false } } diff --git a/p2p/net/nat/mapping.go b/p2p/net/nat/mapping.go index 0641e8fc4e..4ba507e04d 100644 --- a/p2p/net/nat/mapping.go +++ b/p2p/net/nat/mapping.go @@ -24,9 +24,6 @@ type Mapping interface { // ExternalAddr returns the external facing address. If the mapping is not // established, addr will be nil, and and ErrNoMapping will be returned. ExternalAddr() (addr net.Addr, err error) - - // Close closes the port mapping - Close() error } // keeps republishing @@ -103,8 +100,3 @@ func (m *mapping) ExternalAddr() (net.Addr, error) { panic(fmt.Sprintf("invalid protocol %q", m.Protocol())) } } - -func (m *mapping) Close() error { - m.nat.removeMapping(m) - return nil -} diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index 201c887f8f..ca07082bb4 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -24,6 +24,11 @@ const MappingDuration = time.Second * 60 // CacheTime is the time a mapping will cache an external address for const CacheTime = time.Second * 15 +type entry struct { + protocol string + port int +} + // DiscoverNAT looks for a NAT device in the network and // returns an object that can manage port mappings. func DiscoverNAT(ctx context.Context) (*NAT, error) { @@ -40,7 +45,19 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { log.Debug("DiscoverGateway address:", addr) } - return newNAT(natInstance), nil + ctx, cancel := context.WithCancel(context.Background()) + nat := &NAT{ + nat: natInstance, + mappings: make(map[entry]int), + ctx: ctx, + ctxCancel: cancel, + } + nat.refCount.Add(1) + go func() { + defer nat.refCount.Done() + nat.background() + }() + return nat, nil } // NAT is an object that manages address port mappings in @@ -57,17 +74,7 @@ type NAT struct { mappingmu sync.RWMutex // guards mappings closed bool - mappings map[*mapping]struct{} -} - -func newNAT(realNAT nat.NAT) *NAT { - ctx, cancel := context.WithCancel(context.Background()) - return &NAT{ - nat: realNAT, - mappings: make(map[*mapping]struct{}), - ctx: ctx, - ctxCancel: cancel, - } + mappings map[entry]int } // Close shuts down all port mappings. NAT can no longer be used. @@ -84,94 +91,114 @@ func (nat *NAT) Close() error { // Mappings returns a slice of all NAT mappings func (nat *NAT) Mappings() []Mapping { nat.mappingmu.Lock() + defer nat.mappingmu.Unlock() maps2 := make([]Mapping, 0, len(nat.mappings)) - for m := range nat.mappings { - maps2 = append(maps2, m) + for e, extPort := range nat.mappings { + maps2 = append(maps2, &mapping{ + nat: nat, + proto: e.protocol, + intport: e.port, + extport: extPort, + }) } - nat.mappingmu.Unlock() return maps2 } // AddMapping attempts to construct a mapping on protocol and internal port -// It will also periodically renew the mapping until the returned Mapping -// -- or its parent NAT -- is Closed. +// It will also periodically renew the mapping. // // May not succeed, and mappings may change over time; // NAT devices may not respect our port requests, and even lie. func (nat *NAT) AddMapping(protocol string, port int) error { - if nat == nil { - return fmt.Errorf("no nat available") - } - switch protocol { case "tcp", "udp": default: return fmt.Errorf("invalid protocol: %s", protocol) } - m := &mapping{ - intport: port, - nat: nat, - proto: protocol, - } - nat.mappingmu.Lock() if nat.closed { nat.mappingmu.Unlock() return errors.New("closed") } - nat.mappings[m] = struct{}{} - nat.refCount.Add(1) - nat.mappingmu.Unlock() - go nat.refreshMappings(m) // do it once synchronously, so first mapping is done right away, and before exiting, // allowing users -- in the optimistic case -- to use results right after. - nat.establishMapping(m) + extPort := nat.establishMapping(protocol, port) + nat.mappings[entry{protocol: protocol, port: port}] = extPort + nat.mappingmu.Unlock() + return nil } -func (nat *NAT) removeMapping(m *mapping) { +func (nat *NAT) RemoveMapping(protocol string, port int) error { nat.mappingmu.Lock() - delete(nat.mappings, m) - nat.mappingmu.Unlock() - nat.natmu.Lock() - nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort()) - nat.natmu.Unlock() + defer nat.mappingmu.Unlock() + switch protocol { + case "tcp", "udp": + delete(nat.mappings, entry{protocol: protocol, port: port}) + default: + return fmt.Errorf("invalid protocol: %s", protocol) + } + return nil } -func (nat *NAT) refreshMappings(m *mapping) { - defer nat.refCount.Done() - t := time.NewTicker(MappingDuration / 3) +func (nat *NAT) background() { + const tick = MappingDuration / 3 + t := time.NewTimer(tick) // don't use a ticker here. We don't know how long establishing the mappings takes. defer t.Stop() + var in []entry + var out []int // port numbers for { select { case <-t.C: - nat.establishMapping(m) + in = in[:0] + out = out[:0] + nat.mappingmu.Lock() + for e := range nat.mappings { + in = append(in, e) + } + nat.mappingmu.Unlock() + // Establishing the mapping involves network requests. + // Don't hold the mutex, just save the ports. + for _, e := range in { + out = append(out, nat.establishMapping(e.protocol, e.port)) + } + nat.mappingmu.Lock() + for i, p := range in { + if _, ok := nat.mappings[p]; !ok { + continue // entry might have been deleted + } + nat.mappings[p] = out[i] + } + nat.mappingmu.Unlock() + t.Reset(tick) case <-nat.ctx.Done(): - m.Close() + nat.mappingmu.Lock() + for e := range nat.mappings { + delete(nat.mappings, e) + } + nat.mappingmu.Unlock() return } } } -func (nat *NAT) establishMapping(m *mapping) { - oldport := m.ExternalPort() - - log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort()) +func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPort int) { + log.Debugf("Attempting port map: %s/%d", protocol, internalPort) const comment = "libp2p" nat.natmu.Lock() - newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration) + var err error + externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, MappingDuration) if err != nil { // Some hardware does not support mappings with timeout, so try that - newport, err = nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, 0) + externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, 0) } nat.natmu.Unlock() - if err != nil || newport == 0 { - m.setExternalPort(0) // clear mapping + if err != nil || externalPort == 0 { // TODO: log.Event if err != nil { log.Warnf("failed to establish port mapping: %s", err) @@ -180,12 +207,9 @@ func (nat *NAT) establishMapping(m *mapping) { } // we do not close if the mapping failed, // because it may work again next time. - return + return 0 } - m.setExternalPort(newport) - log.Debugf("NAT Mapping: %d --> %d (%s)", m.ExternalPort(), m.InternalPort(), m.Protocol()) - if oldport != 0 && newport != oldport { - log.Debugf("failed to renew same port mapping: ch %d -> %d", oldport, newport) - } + log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol) + return externalPort } From c4fb97455cb7b41c8cb9913496ee34d340d8543b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Apr 2023 09:35:12 +0400 Subject: [PATCH 4/8] natmgr: remove unused Ready method --- p2p/host/basic/natmgr.go | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index 522d3e827a..e271b75e2a 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -15,13 +15,12 @@ import ( ) // NATManager is a simple interface to manage NAT devices. +// It listens Listen and ListenClose notifications from the network.Network, +// and tries to obtain port mappings for those. type NATManager interface { // NAT gets the NAT device managed by the NAT manager. NAT() *inat.NAT - // Ready receives a notification when the NAT device is ready for use. - Ready() <-chan struct{} - io.Closer } @@ -46,7 +45,6 @@ type natManager struct { natMx sync.RWMutex nat *inat.NAT - ready chan struct{} // closed once the nat is ready to process port mappings syncFlag chan struct{} // cap: 1 tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function @@ -59,7 +57,6 @@ func newNatManager(net network.Network) *natManager { ctx, cancel := context.WithCancel(context.Background()) nmgr := &natManager{ net: net, - ready: make(chan struct{}), syncFlag: make(chan struct{}, 1), ctxCancel: cancel, tracked: make(map[entry]bool), @@ -77,21 +74,16 @@ func (nmgr *natManager) Close() error { return nil } -// Ready returns a channel which will be closed when the NAT has been found -// and is ready to be used, or the search process is done. -func (nmgr *natManager) Ready() <-chan struct{} { - return nmgr.ready -} - func (nmgr *natManager) background(ctx context.Context) { defer nmgr.refCount.Done() defer func() { nmgr.natMx.Lock() + defer nmgr.natMx.Unlock() + if nmgr.nat != nil { nmgr.nat.Close() } - nmgr.natMx.Unlock() }() discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second) @@ -99,14 +91,12 @@ func (nmgr *natManager) background(ctx context.Context) { natInstance, err := inat.DiscoverNAT(discoverCtx) if err != nil { log.Info("DiscoverNAT error:", err) - close(nmgr.ready) return } nmgr.natMx.Lock() nmgr.nat = natInstance nmgr.natMx.Unlock() - close(nmgr.ready) // sign natManager up for network notifications // we need to sign up here to avoid missing some notifs @@ -152,10 +142,9 @@ func (nmgr *natManager) doSync() { continue } - // Only bother if we're listening on a - // unicast/unspecified IP. + // Only bother if we're listening on an unicast / unspecified IP. ip := net.IP(maIP.RawValue()) - if !(ip.IsGlobalUnicast() || ip.IsUnspecified()) { + if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { continue } @@ -222,11 +211,11 @@ func (nn *nmgrNetNotifiee) natManager() *natManager { return (*natManager)(nn) } -func (nn *nmgrNetNotifiee) Listen(n network.Network, addr ma.Multiaddr) { +func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { nn.natManager().sync() } -func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) { +func (nn *nmgrNetNotifiee) ListenClose(network.Network, ma.Multiaddr) { nn.natManager().sync() } From bdeab19c66fc084c61687d1f80628d5a98c788d6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Apr 2023 10:03:44 +0400 Subject: [PATCH 5/8] nat: properly remove port mapping --- p2p/net/nat/nat.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index ca07082bb4..3757d6d4d8 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -29,8 +29,7 @@ type entry struct { port int } -// DiscoverNAT looks for a NAT device in the network and -// returns an object that can manage port mappings. +// DiscoverNAT looks for a NAT device in the network and returns an object that can manage port mappings. func DiscoverNAT(ctx context.Context) (*NAT, error) { natInstance, err := nat.DiscoverGateway(ctx) if err != nil { @@ -92,6 +91,7 @@ func (nat *NAT) Close() error { func (nat *NAT) Mappings() []Mapping { nat.mappingmu.Lock() defer nat.mappingmu.Unlock() + maps2 := make([]Mapping, 0, len(nat.mappings)) for e, extPort := range nat.mappings { maps2 = append(maps2, &mapping{ @@ -104,8 +104,8 @@ func (nat *NAT) Mappings() []Mapping { return maps2 } -// AddMapping attempts to construct a mapping on protocol and internal port -// It will also periodically renew the mapping. +// AddMapping attempts to construct a mapping on protocol and internal port. +// It blocks until a mapping was established. Once added, it periodically renews the mapping. // // May not succeed, and mappings may change over time; // NAT devices may not respect our port requests, and even lie. @@ -117,8 +117,9 @@ func (nat *NAT) AddMapping(protocol string, port int) error { } nat.mappingmu.Lock() + defer nat.mappingmu.Unlock() + if nat.closed { - nat.mappingmu.Unlock() return errors.New("closed") } @@ -126,21 +127,22 @@ func (nat *NAT) AddMapping(protocol string, port int) error { // allowing users -- in the optimistic case -- to use results right after. extPort := nat.establishMapping(protocol, port) nat.mappings[entry{protocol: protocol, port: port}] = extPort - nat.mappingmu.Unlock() - return nil } +// RemoveMapping removes a port mapping. +// It blocks until the NAT has removed the mapping. func (nat *NAT) RemoveMapping(protocol string, port int) error { nat.mappingmu.Lock() defer nat.mappingmu.Unlock() + switch protocol { case "tcp", "udp": delete(nat.mappings, entry{protocol: protocol, port: port}) default: return fmt.Errorf("invalid protocol: %s", protocol) } - return nil + return nat.nat.DeletePortMapping(protocol, port) } func (nat *NAT) background() { @@ -178,6 +180,7 @@ func (nat *NAT) background() { nat.mappingmu.Lock() for e := range nat.mappings { delete(nat.mappings, e) + nat.nat.DeletePortMapping(e.protocol, e.port) } nat.mappingmu.Unlock() return From 1ad214650c6dfb33fd26830932df2047189b2976 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Apr 2023 11:30:33 +0400 Subject: [PATCH 6/8] nat: replace alloc-heavy nat.Mapping with explicit GetMapping method --- p2p/host/basic/basic_host.go | 39 +++-------- p2p/net/nat/mapping.go | 102 ---------------------------- p2p/net/nat/mock_nat_test.go | 124 +++++++++++++++++++++++++++++++++++ p2p/net/nat/nat.go | 111 ++++++++++++++++++++----------- p2p/net/nat/nat_test.go | 69 +++++++++++++++++++ 5 files changed, 277 insertions(+), 168 deletions(-) delete mode 100644 p2p/net/nat/mapping.go create mode 100644 p2p/net/nat/mock_nat_test.go create mode 100644 p2p/net/nat/nat_test.go diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index ce82ad331d..4c629d2b0c 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -25,7 +25,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" - inat "github.com/libp2p/go-libp2p/p2p/net/nat" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" @@ -858,35 +857,12 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { finalAddrs = dedupAddrs(finalAddrs) - var natMappings []inat.Mapping - // natmgr is nil if we do not use nat option; // h.natmgr.NAT() is nil if not ready, or no nat is available. if h.natmgr != nil && h.natmgr.NAT() != nil { - natMappings = h.natmgr.NAT().Mappings() - } - - if len(natMappings) > 0 { // We have successfully mapped ports on our NAT. Use those // instead of observed addresses (mostly). - // First, generate a mapping table. - // protocol -> internal port -> external addr - ports := make(map[string]map[int]net.Addr) - for _, m := range natMappings { - addr, err := m.ExternalAddr() - if err != nil { - // mapping not ready yet. - continue - } - protoPorts, ok := ports[m.Protocol()] - if !ok { - protoPorts = make(map[int]net.Addr) - ports[m.Protocol()] = protoPorts - } - protoPorts[m.InternalPort()] = addr - } - // Next, apply this mapping to our addresses. for _, listen := range listenAddrs { found := false @@ -929,23 +905,28 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { } if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { - // We only map global unicast & unspecified addresses ports. - // Not broadcast, multicast, etc. + // We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc. continue } - mappedAddr, ok := ports[protocol][iport] + extAddr, ok := h.natmgr.NAT().GetMapping(protocol, iport) if !ok { - // Not mapped. + // not mapped continue } + var mappedAddr net.Addr + switch naddr.(type) { + case *net.TCPAddr: + mappedAddr = net.TCPAddrFromAddrPort(extAddr) + case *net.UDPAddr: + mappedAddr = net.UDPAddrFromAddrPort(extAddr) + } mappedMaddr, err := manet.FromNetAddr(mappedAddr) if err != nil { log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err) continue } - extMaddr := mappedMaddr if rest != nil { extMaddr = ma.Join(extMaddr, rest) diff --git a/p2p/net/nat/mapping.go b/p2p/net/nat/mapping.go deleted file mode 100644 index 4ba507e04d..0000000000 --- a/p2p/net/nat/mapping.go +++ /dev/null @@ -1,102 +0,0 @@ -package nat - -import ( - "fmt" - "net" - "sync" - "time" -) - -// Mapping represents a port mapping in a NAT. -type Mapping interface { - // Protocol returns the protocol of this port mapping. This is either - // "tcp" or "udp" as no other protocols are likely to be NAT-supported. - Protocol() string - - // InternalPort returns the internal device port. Mapping will continue to - // try to map InternalPort() to an external facing port. - InternalPort() int - - // ExternalPort returns the external facing port. If the mapping is not - // established, port will be 0 - ExternalPort() int - - // ExternalAddr returns the external facing address. If the mapping is not - // established, addr will be nil, and and ErrNoMapping will be returned. - ExternalAddr() (addr net.Addr, err error) -} - -// keeps republishing -type mapping struct { - sync.Mutex // guards all fields - - nat *NAT - proto string - intport int - extport int - - cached net.IP - cacheTime time.Time - cacheLk sync.Mutex -} - -func (m *mapping) Protocol() string { - m.Lock() - defer m.Unlock() - return m.proto -} - -func (m *mapping) InternalPort() int { - m.Lock() - defer m.Unlock() - return m.intport -} - -func (m *mapping) ExternalPort() int { - m.Lock() - defer m.Unlock() - return m.extport -} - -func (m *mapping) setExternalPort(p int) { - m.Lock() - defer m.Unlock() - m.extport = p -} - -func (m *mapping) ExternalAddr() (net.Addr, error) { - m.cacheLk.Lock() - defer m.cacheLk.Unlock() - oport := m.ExternalPort() - if oport == 0 { - // dont even try right now. - return nil, ErrNoMapping - } - - if time.Since(m.cacheTime) >= CacheTime { - m.nat.natmu.Lock() - cval, err := m.nat.nat.GetExternalAddress() - m.nat.natmu.Unlock() - - if err != nil { - return nil, err - } - - m.cached = cval - m.cacheTime = time.Now() - } - switch m.Protocol() { - case "tcp": - return &net.TCPAddr{ - IP: m.cached, - Port: oport, - }, nil - case "udp": - return &net.UDPAddr{ - IP: m.cached, - Port: oport, - }, nil - default: - panic(fmt.Sprintf("invalid protocol %q", m.Protocol())) - } -} diff --git a/p2p/net/nat/mock_nat_test.go b/p2p/net/nat/mock_nat_test.go new file mode 100644 index 0000000000..bb91bac247 --- /dev/null +++ b/p2p/net/nat/mock_nat_test.go @@ -0,0 +1,124 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-nat (interfaces: NAT) + +// Package nat is a generated GoMock package. +package nat + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" +) + +// MockNAT is a mock of NAT interface. +type MockNAT struct { + ctrl *gomock.Controller + recorder *MockNATMockRecorder +} + +// MockNATMockRecorder is the mock recorder for MockNAT. +type MockNATMockRecorder struct { + mock *MockNAT +} + +// NewMockNAT creates a new mock instance. +func NewMockNAT(ctrl *gomock.Controller) *MockNAT { + mock := &MockNAT{ctrl: ctrl} + mock.recorder = &MockNATMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNAT) EXPECT() *MockNATMockRecorder { + return m.recorder +} + +// AddPortMapping mocks base method. +func (m *MockNAT) AddPortMapping(arg0 string, arg1 int, arg2 string, arg3 time.Duration) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddPortMapping", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddPortMapping indicates an expected call of AddPortMapping. +func (mr *MockNATMockRecorder) AddPortMapping(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPortMapping", reflect.TypeOf((*MockNAT)(nil).AddPortMapping), arg0, arg1, arg2, arg3) +} + +// DeletePortMapping mocks base method. +func (m *MockNAT) DeletePortMapping(arg0 string, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePortMapping", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePortMapping indicates an expected call of DeletePortMapping. +func (mr *MockNATMockRecorder) DeletePortMapping(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePortMapping", reflect.TypeOf((*MockNAT)(nil).DeletePortMapping), arg0, arg1) +} + +// GetDeviceAddress mocks base method. +func (m *MockNAT) GetDeviceAddress() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDeviceAddress") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDeviceAddress indicates an expected call of GetDeviceAddress. +func (mr *MockNATMockRecorder) GetDeviceAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceAddress", reflect.TypeOf((*MockNAT)(nil).GetDeviceAddress)) +} + +// GetExternalAddress mocks base method. +func (m *MockNAT) GetExternalAddress() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExternalAddress") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExternalAddress indicates an expected call of GetExternalAddress. +func (mr *MockNATMockRecorder) GetExternalAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExternalAddress", reflect.TypeOf((*MockNAT)(nil).GetExternalAddress)) +} + +// GetInternalAddress mocks base method. +func (m *MockNAT) GetInternalAddress() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInternalAddress") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInternalAddress indicates an expected call of GetInternalAddress. +func (mr *MockNATMockRecorder) GetInternalAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInternalAddress", reflect.TypeOf((*MockNAT)(nil).GetInternalAddress)) +} + +// Type mocks base method. +func (m *MockNAT) Type() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Type") + ret0, _ := ret[0].(string) + return ret0 +} + +// Type indicates an expected call of Type. +func (mr *MockNATMockRecorder) Type() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*MockNAT)(nil).Type)) +} diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index 3757d6d4d8..68834ac877 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/netip" "sync" "time" @@ -19,22 +20,30 @@ var log = logging.Logger("nat") // MappingDuration is a default port mapping duration. // Port mappings are renewed every (MappingDuration / 3) -const MappingDuration = time.Second * 60 +const MappingDuration = time.Minute // CacheTime is the time a mapping will cache an external address for -const CacheTime = time.Second * 15 +const CacheTime = 15 * time.Second type entry struct { protocol string port int } +// so we can mock it in tests +var discoverGateway = nat.DiscoverGateway + // DiscoverNAT looks for a NAT device in the network and returns an object that can manage port mappings. func DiscoverNAT(ctx context.Context) (*NAT, error) { - natInstance, err := nat.DiscoverGateway(ctx) + natInstance, err := discoverGateway(ctx) if err != nil { return nil, err } + var extAddr netip.Addr + extIP, err := natInstance.GetExternalAddress() + if err == nil { + extAddr, _ = netip.AddrFromSlice(extIP) + } // Log the device addr. addr, err := natInstance.GetDeviceAddress() @@ -47,6 +56,7 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { ctx, cancel := context.WithCancel(context.Background()) nat := &NAT{ nat: natInstance, + extAddr: extAddr, mappings: make(map[entry]int), ctx: ctx, ctxCancel: cancel, @@ -66,6 +76,8 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { type NAT struct { natmu sync.Mutex nat nat.NAT + // External IP of the NAT. Will be renewed periodically (every CacheTime). + extAddr netip.Addr refCount sync.WaitGroup ctx context.Context @@ -87,21 +99,18 @@ func (nat *NAT) Close() error { return nil } -// Mappings returns a slice of all NAT mappings -func (nat *NAT) Mappings() []Mapping { +func (nat *NAT) GetMapping(protocol string, port int) (addr netip.AddrPort, found bool) { nat.mappingmu.Lock() defer nat.mappingmu.Unlock() - maps2 := make([]Mapping, 0, len(nat.mappings)) - for e, extPort := range nat.mappings { - maps2 = append(maps2, &mapping{ - nat: nat, - proto: e.protocol, - intport: e.port, - extport: extPort, - }) + if !nat.extAddr.IsValid() { + return netip.AddrPort{}, false + } + extPort, found := nat.mappings[entry{protocol: protocol, port: port}] + if !found { + return netip.AddrPort{}, false } - return maps2 + return netip.AddrPortFrom(nat.extAddr, uint16(extPort)), true } // AddMapping attempts to construct a mapping on protocol and internal port. @@ -138,44 +147,65 @@ func (nat *NAT) RemoveMapping(protocol string, port int) error { switch protocol { case "tcp", "udp": - delete(nat.mappings, entry{protocol: protocol, port: port}) + e := entry{protocol: protocol, port: port} + if _, ok := nat.mappings[e]; ok { + delete(nat.mappings, e) + return nat.nat.DeletePortMapping(protocol, port) + } + return errors.New("unknown mapping") default: return fmt.Errorf("invalid protocol: %s", protocol) } - return nat.nat.DeletePortMapping(protocol, port) } func (nat *NAT) background() { - const tick = MappingDuration / 3 - t := time.NewTimer(tick) // don't use a ticker here. We don't know how long establishing the mappings takes. + const mappingUpdate = MappingDuration / 3 + + now := time.Now() + nextMappingUpdate := now.Add(mappingUpdate) + nextAddrUpdate := now.Add(CacheTime) + + t := time.NewTimer(minTime(nextMappingUpdate, nextAddrUpdate).Sub(now)) // don't use a ticker here. We don't know how long establishing the mappings takes. defer t.Stop() var in []entry var out []int // port numbers for { select { - case <-t.C: - in = in[:0] - out = out[:0] - nat.mappingmu.Lock() - for e := range nat.mappings { - in = append(in, e) - } - nat.mappingmu.Unlock() - // Establishing the mapping involves network requests. - // Don't hold the mutex, just save the ports. - for _, e := range in { - out = append(out, nat.establishMapping(e.protocol, e.port)) + case now := <-t.C: + if now.After(nextMappingUpdate) { + in = in[:0] + out = out[:0] + nat.mappingmu.Lock() + for e := range nat.mappings { + in = append(in, e) + } + nat.mappingmu.Unlock() + // Establishing the mapping involves network requests. + // Don't hold the mutex, just save the ports. + for _, e := range in { + out = append(out, nat.establishMapping(e.protocol, e.port)) + } + nat.mappingmu.Lock() + for i, p := range in { + if _, ok := nat.mappings[p]; !ok { + continue // entry might have been deleted + } + nat.mappings[p] = out[i] + } + nat.mappingmu.Unlock() + nextMappingUpdate = time.Now().Add(mappingUpdate) } - nat.mappingmu.Lock() - for i, p := range in { - if _, ok := nat.mappings[p]; !ok { - continue // entry might have been deleted + if now.After(nextAddrUpdate) { + var extAddr netip.Addr + extIP, err := nat.nat.GetExternalAddress() + if err == nil { + extAddr, _ = netip.AddrFromSlice(extIP) } - nat.mappings[p] = out[i] + nat.extAddr = extAddr + nextAddrUpdate = time.Now().Add(CacheTime) } - nat.mappingmu.Unlock() - t.Reset(tick) + t.Reset(time.Until(minTime(nextAddrUpdate, nextMappingUpdate))) case <-nat.ctx.Done(): nat.mappingmu.Lock() for e := range nat.mappings { @@ -216,3 +246,10 @@ func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPor log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol) return externalPort } + +func minTime(a, b time.Time) time.Time { + if a.Before(b) { + return a + } + return b +} diff --git a/p2p/net/nat/nat_test.go b/p2p/net/nat/nat_test.go new file mode 100644 index 0000000000..8fffb512c7 --- /dev/null +++ b/p2p/net/nat/nat_test.go @@ -0,0 +1,69 @@ +package nat + +import ( + "context" + "errors" + "net" + "net/netip" + "testing" + + "github.com/libp2p/go-nat" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +//go:generate sh -c "go run github.com/golang/mock/mockgen -package nat -destination mock_nat_test.go github.com/libp2p/go-nat NAT" + +func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) { + t.Helper() + ctrl := gomock.NewController(t) + mockNAT = NewMockNAT(ctrl) + mockNAT.EXPECT().GetDeviceAddress().Return(nil, errors.New("nope")) // is only used for logging + origDiscoverGateway := discoverGateway + discoverGateway = func(ctx context.Context) (nat.NAT, error) { return mockNAT, nil } + return mockNAT, func() { + discoverGateway = origDiscoverGateway + ctrl.Finish() + } +} + +func TestAddMapping(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil) + nat, err := DiscoverNAT(context.Background()) + require.NoError(t, err) + + mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil) + require.NoError(t, nat.AddMapping("tcp", 10000)) + + _, found := nat.GetMapping("tcp", 9999) + require.False(t, found, "didn't expect a port mapping for unmapped port") + _, found = nat.GetMapping("udp", 10000) + require.False(t, found, "didn't expect a port mapping for unmapped protocol") + mapped, found := nat.GetMapping("tcp", 10000) + require.True(t, found, "expected port mapping") + require.Equal(t, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 1234), mapped) +} + +func TestRemoveMapping(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil) + nat, err := DiscoverNAT(context.Background()) + require.NoError(t, err) + mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil) + require.NoError(t, nat.AddMapping("tcp", 10000)) + _, found := nat.GetMapping("tcp", 10000) + require.True(t, found, "expected port mapping") + + require.Error(t, nat.RemoveMapping("tcp", 9999), "expected error for unknown mapping") + mockNAT.EXPECT().DeletePortMapping("tcp", 10000) + require.NoError(t, nat.RemoveMapping("tcp", 10000)) + + _, found = nat.GetMapping("tcp", 10000) + require.False(t, found, "didn't expect port mapping for deleted mapping") +} From 4ee60e39d0a11723a4c459c0f86566f9d5474238 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Apr 2023 12:23:38 +0400 Subject: [PATCH 7/8] move more of the NAT mapping logging to the NAT manager --- p2p/host/basic/basic_host.go | 68 +----------------- p2p/host/basic/mock_nat_test.go | 92 ++++++++++++++++++++++++ p2p/host/basic/mockgen_private.sh | 49 +++++++++++++ p2p/host/basic/natmgr.go | 115 +++++++++++++++++++++++------- p2p/host/basic/natmgr_test.go | 110 ++++++++++++++++++++++++++++ 5 files changed, 345 insertions(+), 89 deletions(-) create mode 100644 p2p/host/basic/mock_nat_test.go create mode 100755 p2p/host/basic/mockgen_private.sh create mode 100644 p2p/host/basic/natmgr_test.go diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 4c629d2b0c..dbefefa481 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -858,80 +858,18 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { finalAddrs = dedupAddrs(finalAddrs) // natmgr is nil if we do not use nat option; - // h.natmgr.NAT() is nil if not ready, or no nat is available. - if h.natmgr != nil && h.natmgr.NAT() != nil { + if h.natmgr != nil { // We have successfully mapped ports on our NAT. Use those // instead of observed addresses (mostly). // Next, apply this mapping to our addresses. for _, listen := range listenAddrs { - found := false - transport, rest := ma.SplitFunc(listen, func(c ma.Component) bool { - if found { - return true - } - switch c.Protocol().Code { - case ma.P_TCP, ma.P_UDP: - found = true - } - return false - }) - if !manet.IsThinWaist(transport) { - continue - } - - naddr, err := manet.ToNetAddr(transport) - if err != nil { - log.Error("error parsing net multiaddr %q: %s", transport, err) - continue - } - - var ( - ip net.IP - iport int - protocol string - ) - switch naddr := naddr.(type) { - case *net.TCPAddr: - ip = naddr.IP - iport = naddr.Port - protocol = "tcp" - case *net.UDPAddr: - ip = naddr.IP - iport = naddr.Port - protocol = "udp" - default: - continue - } - - if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { - // We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc. - continue - } - - extAddr, ok := h.natmgr.NAT().GetMapping(protocol, iport) - if !ok { + extMaddr := h.natmgr.GetMapping(listen) + if extMaddr == nil { // not mapped continue } - var mappedAddr net.Addr - switch naddr.(type) { - case *net.TCPAddr: - mappedAddr = net.TCPAddrFromAddrPort(extAddr) - case *net.UDPAddr: - mappedAddr = net.UDPAddrFromAddrPort(extAddr) - } - mappedMaddr, err := manet.FromNetAddr(mappedAddr) - if err != nil { - log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err) - continue - } - extMaddr := mappedMaddr - if rest != nil { - extMaddr = ma.Join(extMaddr, rest) - } - // if the router reported a sane address if !manet.IsIPUnspecified(extMaddr) { // Add in the mapped addr. diff --git a/p2p/host/basic/mock_nat_test.go b/p2p/host/basic/mock_nat_test.go new file mode 100644 index 0000000000..b6d7e9c526 --- /dev/null +++ b/p2p/host/basic/mock_nat_test.go @@ -0,0 +1,92 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: natmgr.go + +// Package basichost is a generated GoMock package. +package basichost + +import ( + netip "net/netip" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockNat is a mock of Nat interface. +type MockNat struct { + ctrl *gomock.Controller + recorder *MockNatMockRecorder +} + +// MockNatMockRecorder is the mock recorder for MockNat. +type MockNatMockRecorder struct { + mock *MockNat +} + +// NewMockNat creates a new mock instance. +func NewMockNat(ctrl *gomock.Controller) *MockNat { + mock := &MockNat{ctrl: ctrl} + mock.recorder = &MockNatMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNat) EXPECT() *MockNatMockRecorder { + return m.recorder +} + +// AddMapping mocks base method. +func (m *MockNat) AddMapping(protocol string, port int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddMapping", protocol, port) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddMapping indicates an expected call of AddMapping. +func (mr *MockNatMockRecorder) AddMapping(protocol, port interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMapping", reflect.TypeOf((*MockNat)(nil).AddMapping), protocol, port) +} + +// Close mocks base method. +func (m *MockNat) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockNatMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNat)(nil).Close)) +} + +// GetMapping mocks base method. +func (m *MockNat) GetMapping(protocol string, port int) (netip.AddrPort, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMapping", protocol, port) + ret0, _ := ret[0].(netip.AddrPort) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetMapping indicates an expected call of GetMapping. +func (mr *MockNatMockRecorder) GetMapping(protocol, port interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMapping", reflect.TypeOf((*MockNat)(nil).GetMapping), protocol, port) +} + +// RemoveMapping mocks base method. +func (m *MockNat) RemoveMapping(protocol string, port int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveMapping", protocol, port) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveMapping indicates an expected call of RemoveMapping. +func (mr *MockNatMockRecorder) RemoveMapping(protocol, port interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMapping", reflect.TypeOf((*MockNat)(nil).RemoveMapping), protocol, port) +} diff --git a/p2p/host/basic/mockgen_private.sh b/p2p/host/basic/mockgen_private.sh new file mode 100755 index 0000000000..79f63eee3e --- /dev/null +++ b/p2p/host/basic/mockgen_private.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +DEST=$2 +PACKAGE=$3 +TMPFILE="mockgen_tmp.go" +# uppercase the name of the interface +ORIG_INTERFACE_NAME=$4 +INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${ORIG_INTERFACE_NAME:0:1})${ORIG_INTERFACE_NAME:1}" + +# Gather all files that contain interface definitions. +# These interfaces might be used as embedded interfaces, +# so we need to pass them to mockgen as aux_files. +AUX=() +for f in *.go; do + if [[ -z ${f##*_test.go} ]]; then + # skip test files + continue; + fi + if $(egrep -qe "type (.*) interface" $f); then + AUX+=("github.com/quic-go/quic-go=$f") + fi +done + +# Find the file that defines the interface we're mocking. +for f in *.go; do + if [[ -z ${f##*_test.go} ]]; then + # skip test files + continue; + fi + INTERFACE=$(sed -n "/^type $ORIG_INTERFACE_NAME interface/,/^}/p" $f) + if [[ -n "$INTERFACE" ]]; then + SRC=$f + break + fi +done + +if [[ -z "$INTERFACE" ]]; then + echo "Interface $ORIG_INTERFACE_NAME not found." + exit 1 +fi + +AUX_FILES=$(IFS=, ; echo "${AUX[*]}") + +## create a public alias for the interface, so that mockgen can process it +echo -e "package $1\n" > $TMPFILE +echo "$INTERFACE" | sed "s/$ORIG_INTERFACE_NAME/$INTERFACE_NAME/" >> $TMPFILE +go run github.com/golang/mock/mockgen -package $1 -self_package $3 -destination $DEST -source=$TMPFILE -aux_files $AUX_FILES +sed "s/$TMPFILE/$SRC/" "$DEST" > "$DEST.new" && mv "$DEST.new" "$DEST" +rm "$TMPFILE" diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index e271b75e2a..6ebe37b9b7 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "net/netip" "strconv" "sync" "time" @@ -12,21 +13,20 @@ import ( inat "github.com/libp2p/go-libp2p/p2p/net/nat" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" ) // NATManager is a simple interface to manage NAT devices. // It listens Listen and ListenClose notifications from the network.Network, // and tries to obtain port mappings for those. type NATManager interface { - // NAT gets the NAT device managed by the NAT manager. - NAT() *inat.NAT - + GetMapping(ma.Multiaddr) ma.Multiaddr io.Closer } // NewNATManager creates a NAT manager. func NewNATManager(net network.Network) NATManager { - return newNatManager(net) + return newNATManager(net) } type entry struct { @@ -34,6 +34,16 @@ type entry struct { port int } +type nat interface { + AddMapping(protocol string, port int) error + RemoveMapping(protocol string, port int) error + GetMapping(protocol string, port int) (netip.AddrPort, bool) + io.Closer +} + +// so we can mock it in tests +var discoverNAT = func(ctx context.Context) (nat, error) { return inat.DiscoverNAT(ctx) } + // natManager takes care of adding + removing port mappings to the nat. // Initialized with the host if it has a NATPortMap option enabled. // natManager receives signals from the network, and check on nat mappings: @@ -43,7 +53,7 @@ type entry struct { type natManager struct { net network.Network natMx sync.RWMutex - nat *inat.NAT + nat nat syncFlag chan struct{} // cap: 1 @@ -53,7 +63,7 @@ type natManager struct { ctxCancel context.CancelFunc } -func newNatManager(net network.Network) *natManager { +func newNATManager(net network.Network) *natManager { ctx, cancel := context.WithCancel(context.Background()) nmgr := &natManager{ net: net, @@ -88,7 +98,7 @@ func (nmgr *natManager) background(ctx context.Context) { discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - natInstance, err := inat.DiscoverNAT(discoverCtx) + natInstance, err := discoverNAT(discoverCtx) if err != nil { log.Info("DiscoverNAT error:", err) return @@ -196,28 +206,85 @@ func (nmgr *natManager) doSync() { } } -// NAT returns the natManager's nat object. this may be nil, if -// (a) the search process is still ongoing, or (b) the search process -// found no nat. Clients must check whether the return value is nil. -func (nmgr *natManager) NAT() *inat.NAT { +func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { nmgr.natMx.Lock() defer nmgr.natMx.Unlock() - return nmgr.nat -} -type nmgrNetNotifiee natManager + if nmgr.nat == nil { // NAT not yet initialized + return nil + } -func (nn *nmgrNetNotifiee) natManager() *natManager { - return (*natManager)(nn) -} + var found bool + var proto int // ma.P_TCP or ma.P_UDP + transport, rest := ma.SplitFunc(addr, func(c ma.Component) bool { + if found { + return true + } + proto = c.Protocol().Code + found = proto == ma.P_TCP || proto == ma.P_UDP + return false + }) + if !manet.IsThinWaist(transport) { + return nil + } -func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { - nn.natManager().sync() -} + naddr, err := manet.ToNetAddr(transport) + if err != nil { + log.Error("error parsing net multiaddr %q: %s", transport, err) + return nil + } -func (nn *nmgrNetNotifiee) ListenClose(network.Network, ma.Multiaddr) { - nn.natManager().sync() + var ( + ip net.IP + port int + protocol string + ) + switch naddr := naddr.(type) { + case *net.TCPAddr: + ip = naddr.IP + port = naddr.Port + protocol = "tcp" + case *net.UDPAddr: + ip = naddr.IP + port = naddr.Port + protocol = "udp" + default: + return nil + } + + if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { + // We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc. + return nil + } + + extAddr, ok := nmgr.nat.GetMapping(protocol, port) + if !ok { + return nil + } + + var mappedAddr net.Addr + switch naddr.(type) { + case *net.TCPAddr: + mappedAddr = net.TCPAddrFromAddrPort(extAddr) + case *net.UDPAddr: + mappedAddr = net.UDPAddrFromAddrPort(extAddr) + } + mappedMaddr, err := manet.FromNetAddr(mappedAddr) + if err != nil { + log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err) + return nil + } + extMaddr := mappedMaddr + if rest != nil { + extMaddr = ma.Join(extMaddr, rest) + } + return extMaddr } -func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {} -func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {} +type nmgrNetNotifiee natManager + +func (nn *nmgrNetNotifiee) natManager() *natManager { return (*natManager)(nn) } +func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { nn.natManager().sync() } +func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) { nn.natManager().sync() } +func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {} +func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {} diff --git a/p2p/host/basic/natmgr_test.go b/p2p/host/basic/natmgr_test.go new file mode 100644 index 0000000000..8ee772e62a --- /dev/null +++ b/p2p/host/basic/natmgr_test.go @@ -0,0 +1,110 @@ +package basichost + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" + + ma "github.com/multiformats/go-multiaddr" + + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/golang/mock/gomock" +) + +//go:generate sh -c "./mockgen_private.sh basichost mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic nat" + +func setupMockNAT(t *testing.T) (mockNAT *MockNat, reset func()) { + t.Helper() + ctrl := gomock.NewController(t) + mockNAT = NewMockNat(ctrl) + origDiscoverNAT := discoverNAT + discoverNAT = func(ctx context.Context) (nat, error) { return mockNAT, nil } + return mockNAT, func() { + discoverNAT = origDiscoverNAT + ctrl.Finish() + } +} + +func TestMapping(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + sw := swarmt.GenSwarm(t) + defer sw.Close() + m := newNATManager(sw) + require.Eventually(t, func() bool { + m.natMx.Lock() + defer m.natMx.Unlock() + return m.nat != nil + }, time.Second, time.Millisecond) + externalAddr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 4321) + // pretend that we have a TCP mapping + mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) + + // pretend that we have a QUIC mapping + mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))) + + // pretend that there's no mapping + mockNAT.EXPECT().GetMapping("tcp", 1234).Return(netip.AddrPort{}, false) + require.Nil(t, m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) + + // make sure this works for WebSocket addresses as well + mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321/ws"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234/ws"))) + + // make sure this works for WebTransport addresses as well + mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1/webtransport"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1/webtransport"))) +} + +func TestAddAndRemoveListeners(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + sw := swarmt.GenSwarm(t) + defer sw.Close() + m := newNATManager(sw) + require.Eventually(t, func() bool { + m.natMx.Lock() + defer m.natMx.Unlock() + return m.nat != nil + }, time.Second, time.Millisecond) + + added := make(chan struct{}, 1) + // add a TCP listener + mockNAT.EXPECT().AddMapping("tcp", 1234).Do(func(string, int) { added <- struct{}{} }) + require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) + select { + case <-added: + case <-time.After(time.Second): + t.Fatal("didn't receive call to AddMapping") + } + + // add a QUIC listener + mockNAT.EXPECT().AddMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} }) + require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))) + select { + case <-added: + case <-time.After(time.Second): + t.Fatal("didn't receive call to AddMapping") + } + + // remove the QUIC listener + mockNAT.EXPECT().RemoveMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} }) + sw.ListenClose(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1")) + select { + case <-added: + case <-time.After(time.Second): + t.Fatal("didn't receive call to RemoveMapping") + } + + // test shutdown + mockNAT.EXPECT().RemoveMapping("tcp", 1234).MaxTimes(1) + mockNAT.EXPECT().Close().MaxTimes(1) +} From 60e9adedf6f884fd2fda2f804391fd6d14e50253 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 12 Apr 2023 16:59:18 +0200 Subject: [PATCH 8/8] remove ugly private mockgen workaround --- p2p/host/basic/mock_nat_test.go | 54 +++++++++++++++---------------- p2p/host/basic/mockgen_private.sh | 49 ---------------------------- p2p/host/basic/mocks.go | 6 ++++ p2p/host/basic/natmgr_test.go | 6 ++-- 4 files changed, 35 insertions(+), 80 deletions(-) delete mode 100755 p2p/host/basic/mockgen_private.sh create mode 100644 p2p/host/basic/mocks.go diff --git a/p2p/host/basic/mock_nat_test.go b/p2p/host/basic/mock_nat_test.go index b6d7e9c526..7714b25853 100644 --- a/p2p/host/basic/mock_nat_test.go +++ b/p2p/host/basic/mock_nat_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: natmgr.go +// Source: github.com/libp2p/go-libp2p/p2p/host/basic (interfaces: NAT) // Package basichost is a generated GoMock package. package basichost @@ -11,45 +11,45 @@ import ( gomock "github.com/golang/mock/gomock" ) -// MockNat is a mock of Nat interface. -type MockNat struct { +// MockNAT is a mock of NAT interface. +type MockNAT struct { ctrl *gomock.Controller - recorder *MockNatMockRecorder + recorder *MockNATMockRecorder } -// MockNatMockRecorder is the mock recorder for MockNat. -type MockNatMockRecorder struct { - mock *MockNat +// MockNATMockRecorder is the mock recorder for MockNAT. +type MockNATMockRecorder struct { + mock *MockNAT } -// NewMockNat creates a new mock instance. -func NewMockNat(ctrl *gomock.Controller) *MockNat { - mock := &MockNat{ctrl: ctrl} - mock.recorder = &MockNatMockRecorder{mock} +// NewMockNAT creates a new mock instance. +func NewMockNAT(ctrl *gomock.Controller) *MockNAT { + mock := &MockNAT{ctrl: ctrl} + mock.recorder = &MockNATMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockNat) EXPECT() *MockNatMockRecorder { +func (m *MockNAT) EXPECT() *MockNATMockRecorder { return m.recorder } // AddMapping mocks base method. -func (m *MockNat) AddMapping(protocol string, port int) error { +func (m *MockNAT) AddMapping(arg0 string, arg1 int) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddMapping", protocol, port) + ret := m.ctrl.Call(m, "AddMapping", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // AddMapping indicates an expected call of AddMapping. -func (mr *MockNatMockRecorder) AddMapping(protocol, port interface{}) *gomock.Call { +func (mr *MockNATMockRecorder) AddMapping(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMapping", reflect.TypeOf((*MockNat)(nil).AddMapping), protocol, port) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMapping", reflect.TypeOf((*MockNAT)(nil).AddMapping), arg0, arg1) } // Close mocks base method. -func (m *MockNat) Close() error { +func (m *MockNAT) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) @@ -57,36 +57,36 @@ func (m *MockNat) Close() error { } // Close indicates an expected call of Close. -func (mr *MockNatMockRecorder) Close() *gomock.Call { +func (mr *MockNATMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNat)(nil).Close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNAT)(nil).Close)) } // GetMapping mocks base method. -func (m *MockNat) GetMapping(protocol string, port int) (netip.AddrPort, bool) { +func (m *MockNAT) GetMapping(arg0 string, arg1 int) (netip.AddrPort, bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMapping", protocol, port) + ret := m.ctrl.Call(m, "GetMapping", arg0, arg1) ret0, _ := ret[0].(netip.AddrPort) ret1, _ := ret[1].(bool) return ret0, ret1 } // GetMapping indicates an expected call of GetMapping. -func (mr *MockNatMockRecorder) GetMapping(protocol, port interface{}) *gomock.Call { +func (mr *MockNATMockRecorder) GetMapping(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMapping", reflect.TypeOf((*MockNat)(nil).GetMapping), protocol, port) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMapping", reflect.TypeOf((*MockNAT)(nil).GetMapping), arg0, arg1) } // RemoveMapping mocks base method. -func (m *MockNat) RemoveMapping(protocol string, port int) error { +func (m *MockNAT) RemoveMapping(arg0 string, arg1 int) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveMapping", protocol, port) + ret := m.ctrl.Call(m, "RemoveMapping", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // RemoveMapping indicates an expected call of RemoveMapping. -func (mr *MockNatMockRecorder) RemoveMapping(protocol, port interface{}) *gomock.Call { +func (mr *MockNATMockRecorder) RemoveMapping(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMapping", reflect.TypeOf((*MockNat)(nil).RemoveMapping), protocol, port) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMapping", reflect.TypeOf((*MockNAT)(nil).RemoveMapping), arg0, arg1) } diff --git a/p2p/host/basic/mockgen_private.sh b/p2p/host/basic/mockgen_private.sh deleted file mode 100755 index 79f63eee3e..0000000000 --- a/p2p/host/basic/mockgen_private.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash - -DEST=$2 -PACKAGE=$3 -TMPFILE="mockgen_tmp.go" -# uppercase the name of the interface -ORIG_INTERFACE_NAME=$4 -INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${ORIG_INTERFACE_NAME:0:1})${ORIG_INTERFACE_NAME:1}" - -# Gather all files that contain interface definitions. -# These interfaces might be used as embedded interfaces, -# so we need to pass them to mockgen as aux_files. -AUX=() -for f in *.go; do - if [[ -z ${f##*_test.go} ]]; then - # skip test files - continue; - fi - if $(egrep -qe "type (.*) interface" $f); then - AUX+=("github.com/quic-go/quic-go=$f") - fi -done - -# Find the file that defines the interface we're mocking. -for f in *.go; do - if [[ -z ${f##*_test.go} ]]; then - # skip test files - continue; - fi - INTERFACE=$(sed -n "/^type $ORIG_INTERFACE_NAME interface/,/^}/p" $f) - if [[ -n "$INTERFACE" ]]; then - SRC=$f - break - fi -done - -if [[ -z "$INTERFACE" ]]; then - echo "Interface $ORIG_INTERFACE_NAME not found." - exit 1 -fi - -AUX_FILES=$(IFS=, ; echo "${AUX[*]}") - -## create a public alias for the interface, so that mockgen can process it -echo -e "package $1\n" > $TMPFILE -echo "$INTERFACE" | sed "s/$ORIG_INTERFACE_NAME/$INTERFACE_NAME/" >> $TMPFILE -go run github.com/golang/mock/mockgen -package $1 -self_package $3 -destination $DEST -source=$TMPFILE -aux_files $AUX_FILES -sed "s/$TMPFILE/$SRC/" "$DEST" > "$DEST.new" && mv "$DEST.new" "$DEST" -rm "$TMPFILE" diff --git a/p2p/host/basic/mocks.go b/p2p/host/basic/mocks.go new file mode 100644 index 0000000000..3ad4d4e90b --- /dev/null +++ b/p2p/host/basic/mocks.go @@ -0,0 +1,6 @@ +//go:build gomock || generate + +package basichost + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT" +type NAT nat diff --git a/p2p/host/basic/natmgr_test.go b/p2p/host/basic/natmgr_test.go index 8ee772e62a..e507b45c82 100644 --- a/p2p/host/basic/natmgr_test.go +++ b/p2p/host/basic/natmgr_test.go @@ -15,12 +15,10 @@ import ( "github.com/golang/mock/gomock" ) -//go:generate sh -c "./mockgen_private.sh basichost mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic nat" - -func setupMockNAT(t *testing.T) (mockNAT *MockNat, reset func()) { +func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) { t.Helper() ctrl := gomock.NewController(t) - mockNAT = NewMockNat(ctrl) + mockNAT = NewMockNAT(ctrl) origDiscoverNAT := discoverNAT discoverNAT = func(ctx context.Context) (nat, error) { return mockNAT, nil } return mockNAT, func() {