Skip to content

Commit

Permalink
fix handshake channel
Browse files Browse the repository at this point in the history
  • Loading branch information
ckousik committed Oct 17, 2022
1 parent 6c3bf8a commit 36f6ae4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 41 deletions.
67 changes: 43 additions & 24 deletions p2p/transport/webrtc/datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-msgio/protoio"
"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"

Expand Down Expand Up @@ -44,6 +43,7 @@ const (
type dataChannel struct {
// TODO: Are these circular references okay?
channel *webrtc.DataChannel
rwc datachannel.ReadWriteCloser
laddr net.Addr
raddr net.Addr
readDeadline *deadline
Expand All @@ -59,10 +59,8 @@ type dataChannel struct {
cancel context.CancelFunc
m sync.Mutex
readBuf bytes.Buffer
buf []byte
writeAvailable chan struct{}

writer protoio.Writer
reader protoio.Reader
}

func newDataChannel(
Expand All @@ -74,15 +72,15 @@ func newDataChannel(

result := &dataChannel{
channel: channel,
rwc: rwc,
laddr: laddr,
raddr: raddr,
readDeadline: newDeadline(),
writeDeadline: newDeadline(),
ctx: ctx,
cancel: cancel,
writer: protoio.NewDelimitedWriter(rwc),
reader: protoio.NewDelimitedReader(rwc, 1500),
writeAvailable: make(chan struct{}),
buf: make([]byte, 1500),
}

// channel.OnMessage(result.handleMessage)
Expand All @@ -100,6 +98,9 @@ func (d *dataChannel) processControlMessage(msg pb.Message) {
if d.state == stateClosed {
return
}
if msg.Flag == nil {
return
}
switch msg.GetFlag() {
case pb.Message_FIN:
if d.state == stateWriteClosed {
Expand Down Expand Up @@ -144,22 +145,30 @@ func (d *dataChannel) Read(b []byte) (int, error) {

// read in a separate goroutine to enable read deadlines
go func() {
err = d.reader.ReadMsg(&msg)
read, err = d.rwc.Read(d.buf)
if err != nil {
log.Warnf("error reading from datachannel: %v", err)
signal <- struct {
error
}{err}
return
}
err = msg.Unmarshal(d.buf[:read])
if err != nil {
log.Warnf("could not unmarshal message: read: %d, err: %v", read, err)
signal <- struct {
error
}{err}
return
}
if state := d.getState(); state != stateClosed && state != stateReadClosed {

if state := d.getState(); state != stateClosed && state != stateReadClosed && msg.Message != nil {
d.m.Lock()
d.readBuf.Write(msg.GetMessage())
d.readBuf.Write(msg.Message)
d.m.Unlock()
}
// process control message
if msg.Flag != nil {
d.processControlMessage(msg)
}
d.processControlMessage(msg)

signal <- struct{ error }{nil}

}()
Expand All @@ -183,9 +192,8 @@ func (d *dataChannel) Write(b []byte) (int, error) {
var err error
var (
start int = 0
end = 0
written = 0
chunkSize = 1024*1024 - 10
end = len(b)
chunkSize = 65525
n = 0
)

Expand All @@ -199,11 +207,9 @@ func (d *dataChannel) Write(b []byte) (int, error) {
if err != nil {
break
}
written += n
start = end
start += n
}
return written, err

return start, err
}

func (d *dataChannel) partialWrite(b []byte) (int, error) {
Expand All @@ -224,7 +230,20 @@ func (d *dataChannel) partialWrite(b []byte) (int, error) {
return 0, os.ErrDeadlineExceeded
}
}
return len(b), d.writer.WriteMsg(msg)
return d.writeMessage(msg)
}

func (d *dataChannel) writeMessage(msg *pb.Message) (int, error) {
data, err := msg.Marshal()
if err != nil {
return 0, err
}
_, err = d.rwc.Write(data)
if err != nil {
return 0, err
}
return len(msg.GetMessage()), err

}

func (d *dataChannel) Close() error {
Expand Down Expand Up @@ -255,7 +274,7 @@ func (d *dataChannel) CloseRead() error {
msg := &pb.Message{
Flag: pb.Message_STOP_SENDING.Enum(),
}
err = d.writer.WriteMsg(msg)
_, err = d.writeMessage(msg)
})
return err

Expand All @@ -280,7 +299,7 @@ func (d *dataChannel) CloseWrite() error {
msg := &pb.Message{
Flag: pb.Message_FIN.Enum(),
}
err = d.writer.WriteMsg(msg)
_, err = d.writeMessage(msg)
})
return err
}
Expand All @@ -297,7 +316,7 @@ func (d *dataChannel) Reset() error {
var err error
d.resetOnce.Do(func() {
msg := &pb.Message{Flag: pb.Message_RESET.Enum()}
err = d.writer.WriteMsg(msg)
_, err = d.writeMessage(msg)
d.Close()
})
return err
Expand Down
9 changes: 4 additions & 5 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,13 @@ func (l *listener) accept(ctx context.Context, addr candidateAddr) (tpt.CapableC
// signaling channel wraps an error in a struct to make
// the error nullable.
signalChan := make(chan struct{ error })
var wrappedChannel *dataChannel
var handshakeOnce sync.Once
// this enforces that the correct data channel label is used
// for the handshake
handshakeChannel, err := pc.CreateDataChannel("data", &webrtc.DataChannelInit{
handshakeChannel, err := pc.CreateDataChannel("handshake", &webrtc.DataChannelInit{
Negotiated: func(v bool) *bool { return &v }(true),
ID: func(v uint16) *uint16 { return &v }(1),
ID: func(v uint16) *uint16 { return &v }(0),
})
if err != nil {
defer cleanup()
Expand All @@ -208,9 +210,6 @@ func (l *listener) accept(ctx context.Context, addr candidateAddr) (tpt.CapableC
// Therefore, we wrap the datachannel before performing the
// offer-answer exchange, so any messages sent from the remote get
// buffered.
var wrappedChannel *dataChannel

var handshakeOnce sync.Once
handshakeChannel.OnOpen(func() {
rwc, err := handshakeChannel.Detach()
if err != nil {
Expand Down
23 changes: 11 additions & 12 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,9 @@ func (t *WebRTCTransport) Dial(
return nil, errInternal("could not instantiate peerconnection", err)
}

// We need to set negotiated = true for this channel on both
// the client and server to avoid DCEP errors.
handshakeChannel, err := pc.CreateDataChannel("data", &webrtc.DataChannelInit{
Negotiated: func(v bool) *bool { return &v }(true),
ID: func(v uint16) *uint16 { return &v }(1),
})

if err != nil {
defer cleanup()
return nil, errDatachannel("could not create", err)
}

signalChan := make(chan struct{ error })
var connectedOnce sync.Once

pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateConnected:
Expand All @@ -238,6 +227,16 @@ func (t *WebRTCTransport) Dial(
}
})

// We need to set negotiated = true for this channel on both
// the client and server to avoid DCEP errors.
handshakeChannel, err := pc.CreateDataChannel("handshake", &webrtc.DataChannelInit{
Negotiated: func(v bool) *bool { return &v }(true),
ID: func(v uint16) *uint16 { return &v }(0),
})
if err != nil {
defer cleanup()
return nil, errDatachannel("could not create", err)
}
handshakeChannel.OnOpen(func() {
rwc, err := handshakeChannel.Detach()
if err != nil {
Expand Down

0 comments on commit 36f6ae4

Please sign in to comment.