Skip to content
This repository has been archived by the owner on Sep 9, 2022. It is now read-only.

add a Close method, remove the context from the constructor #141

Merged
merged 1 commit into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ var (

// Relay is the relay transport and service.
type Relay struct {
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
self peer.ID
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
ctxCancel context.CancelFunc
self peer.ID

active bool
hop bool
Expand Down Expand Up @@ -93,15 +94,15 @@ func (e RelayError) Error() string {
}

// NewRelay constructs a new relay.
func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) {
func NewRelay(h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) {
r := &Relay{
upgrader: upgrader,
host: h,
ctx: ctx,
self: h.ID(),
incoming: make(chan *Conn),
hopCount: make(map[peer.ID]int),
}
r.ctx, r.ctxCancel = context.WithCancel(context.Background())

for _, opt := range opts {
switch opt {
Expand Down
78 changes: 28 additions & 50 deletions relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ func getNetHosts(t *testing.T, n int) []host.Host {
netw := swarmt.GenSwarm(t)
h := bhost.NewBlankHost(netw)
out = append(out, h)
t.Cleanup(func() { h.Close() })
}

return out
}

func newTestRelay(t *testing.T, ctx context.Context, host host.Host, opts ...RelayOpt) *Relay {
r, err := NewRelay(ctx, host, swarmt.GenUpgrader(host.Network().(*swarm.Swarm)), opts...)
func newTestRelay(t *testing.T, host host.Host, opts ...RelayOpt) *Relay {
r, err := NewRelay(host, swarmt.GenUpgrader(host.Network().(*swarm.Swarm)), opts...)
if err != nil {
t.Fatal(err)
}
Expand All @@ -71,11 +70,11 @@ func TestBasicRelay(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

r3 := newTestRelay(t, ctx, hosts[2])
r3 := newTestRelay(t, hosts[2])

var (
conn1, conn2 net.Conn
Expand Down Expand Up @@ -145,11 +144,11 @@ func TestRelayReset(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

r3 := newTestRelay(t, ctx, hosts[2])
r3 := newTestRelay(t, hosts[2])

ready := make(chan struct{})

Expand Down Expand Up @@ -203,10 +202,10 @@ func TestBasicRelayDial(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

_ = newTestRelay(t, ctx, hosts[1], OptHop)
r3 := newTestRelay(t, ctx, hosts[2])
_ = newTestRelay(t, hosts[1], OptHop)
r3 := newTestRelay(t, hosts[2])

var (
conn1, conn2 net.Conn
Expand Down Expand Up @@ -266,49 +265,28 @@ func TestBasicRelayDial(t *testing.T) {
}

func TestUnspecificRelayDialFails(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

hosts := getNetHosts(t, 3)

r1 := newTestRelay(t, ctx, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)

r3 := newTestRelay(t, ctx, hosts[2])
r1 := newTestRelay(t, hosts[0])
newTestRelay(t, hosts[1], OptHop)
r3 := newTestRelay(t, hosts[2])

connect(t, hosts[0], hosts[1])
connect(t, hosts[1], hosts[2])

time.Sleep(100 * time.Millisecond)

var (
done = make(chan struct{})
)

defer func() {
cancel()
<-done
}()

go func() {
defer close(done)
list := r3.Listener()

var err error
_, err = list.Accept()
if err == nil {
if _, err := r3.Listener().Accept(); err == nil {
t.Error("should not have received relay connection")
}
}()

addr := ma.StringCast("/p2p-circuit")

rctx, rcancel := context.WithTimeout(ctx, time.Second)
defer rcancel()

var err error
_, err = r1.Dial(rctx, addr, hosts[2].ID())
if err == nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if _, err := r1.Dial(ctx, addr, hosts[2].ID()); err == nil {
t.Fatal("expected dial with unspecified relay address to fail, even if we're connected to a relay")
}
}
Expand All @@ -324,11 +302,11 @@ func TestRelayThroughNonHop(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1])
newTestRelay(t, hosts[1])

newTestRelay(t, ctx, hosts[2])
newTestRelay(t, hosts[2])

rinfo := hosts[1].Peerstore().PeerInfo(hosts[1].ID())
dinfo := hosts[2].Peerstore().PeerInfo(hosts[2].ID())
Expand Down Expand Up @@ -361,9 +339,9 @@ func TestRelayNoDestConnection(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

rinfo := hosts[1].Peerstore().PeerInfo(hosts[1].ID())
dinfo := hosts[2].Peerstore().PeerInfo(hosts[2].ID())
Expand Down Expand Up @@ -396,9 +374,9 @@ func TestActiveRelay(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
newTestRelay(t, ctx, hosts[1], OptHop, OptActive)
r3 := newTestRelay(t, ctx, hosts[2])
r1 := newTestRelay(t, hosts[0])
newTestRelay(t, hosts[1], OptHop, OptActive)
r3 := newTestRelay(t, hosts[2])

connChan := make(chan manet.Conn)

Expand Down Expand Up @@ -458,9 +436,9 @@ func TestRelayCanHop(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

canhop, err := r1.CanHop(ctx, hosts[1].ID())
if err != nil {
Expand Down
12 changes: 9 additions & 3 deletions transport.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package relay

import (
"context"
"fmt"
"io"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/transport"
Expand All @@ -13,6 +13,7 @@ import (
var circuitAddr = ma.Cast(ma.ProtocolWithCode(ma.P_CIRCUIT).VCode)

var _ transport.Transport = (*RelayTransport)(nil)
var _ io.Closer = (*RelayTransport)(nil)

type RelayTransport Relay

Expand Down Expand Up @@ -45,14 +46,19 @@ func (t *RelayTransport) Protocols() []int {
return []int{ma.P_CIRCUIT}
}

func (r *RelayTransport) Close() error {
r.ctxCancel()
return nil
}

// AddRelayTransport constructs a relay and adds it as a transport to the host network.
func AddRelayTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) error {
func AddRelayTransport(h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) error {
n, ok := h.Network().(transport.TransportNetwork)
if !ok {
return fmt.Errorf("%v is not a transport network", h.Network())
}

r, err := NewRelay(ctx, h, upgrader, opts...)
r, err := NewRelay(h, upgrader, opts...)
if err != nil {
return err
}
Expand Down
24 changes: 10 additions & 14 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ const TestProto = "test/relay-transport"

var msg = []byte("relay works!")

func testSetupRelay(t *testing.T, ctx context.Context) []host.Host {
func testSetupRelay(t *testing.T) []host.Host {
hosts := getNetHosts(t, 3)

err := AddRelayTransport(ctx, hosts[0], swarmt.GenUpgrader(hosts[0].Network().(*swarm.Swarm)))
err := AddRelayTransport(hosts[0], swarmt.GenUpgrader(hosts[0].Network().(*swarm.Swarm)))
if err != nil {
t.Fatal(err)
}

err = AddRelayTransport(ctx, hosts[1], swarmt.GenUpgrader(hosts[1].Network().(*swarm.Swarm)), OptHop)
err = AddRelayTransport(hosts[1], swarmt.GenUpgrader(hosts[1].Network().(*swarm.Swarm)), OptHop)
if err != nil {
t.Fatal(err)
}

err = AddRelayTransport(ctx, hosts[2], swarmt.GenUpgrader(hosts[2].Network().(*swarm.Swarm)))
err = AddRelayTransport(hosts[2], swarmt.GenUpgrader(hosts[2].Network().(*swarm.Swarm)))
if err != nil {
t.Fatal(err)
}
Expand All @@ -60,10 +60,7 @@ func testSetupRelay(t *testing.T, ctx context.Context) []host.Host {
}

func TestFullAddressTransportDial(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := testSetupRelay(t, ctx)
hosts := testSetupRelay(t)

var relayAddr ma.Multiaddr
for _, addr := range hosts[1].Addrs() {
Expand All @@ -78,12 +75,11 @@ func TestFullAddressTransportDial(t *testing.T) {
t.Fatal(err)
}

rctx, rcancel := context.WithTimeout(ctx, time.Second)
defer rcancel()

hosts[0].Peerstore().AddAddrs(hosts[2].ID(), []ma.Multiaddr{addr}, peerstore.TempAddrTTL)

s, err := hosts[0].NewStream(rctx, hosts[2].ID(), TestProto)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
s, err := hosts[0].NewStream(ctx, hosts[2].ID(), TestProto)
if err != nil {
t.Fatal(err)
}
Expand All @@ -102,7 +98,7 @@ func TestSpecificRelayTransportDial(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := testSetupRelay(t, ctx)
hosts := testSetupRelay(t)

addr, err := ma.NewMultiaddr(fmt.Sprintf("/ipfs/%s/p2p-circuit/ipfs/%s", hosts[1].ID().Pretty(), hosts[2].ID().Pretty()))
if err != nil {
Expand Down Expand Up @@ -133,7 +129,7 @@ func TestUnspecificRelayTransportDialFails(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := testSetupRelay(t, ctx)
hosts := testSetupRelay(t)

addr, err := ma.NewMultiaddr(fmt.Sprintf("/p2p-circuit/ipfs/%s", hosts[2].ID().Pretty()))
if err != nil {
Expand Down