diff --git a/p2p/transport/webrtc/datachannel.go b/p2p/transport/webrtc/datachannel.go index 5a9aa45ef7..da95942dbd 100644 --- a/p2p/transport/webrtc/datachannel.go +++ b/p2p/transport/webrtc/datachannel.go @@ -12,7 +12,6 @@ import ( "time" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-msgio/protoio" "github.com/pion/datachannel" "github.com/pion/webrtc/v3" @@ -44,6 +43,7 @@ const ( type dataChannel struct { // TODO: Are these circular references okay? channel *webrtc.DataChannel + rwc datachannel.ReadWriteCloser laddr net.Addr raddr net.Addr readDeadline *deadline @@ -59,10 +59,8 @@ type dataChannel struct { cancel context.CancelFunc m sync.Mutex readBuf bytes.Buffer + buf []byte writeAvailable chan struct{} - - writer protoio.Writer - reader protoio.Reader } func newDataChannel( @@ -74,15 +72,15 @@ func newDataChannel( result := &dataChannel{ channel: channel, + rwc: rwc, laddr: laddr, raddr: raddr, readDeadline: newDeadline(), writeDeadline: newDeadline(), ctx: ctx, cancel: cancel, - writer: protoio.NewDelimitedWriter(rwc), - reader: protoio.NewDelimitedReader(rwc, 1500), writeAvailable: make(chan struct{}), + buf: make([]byte, 1500), } // channel.OnMessage(result.handleMessage) @@ -100,6 +98,9 @@ func (d *dataChannel) processControlMessage(msg pb.Message) { if d.state == stateClosed { return } + if msg.Flag == nil { + return + } switch msg.GetFlag() { case pb.Message_FIN: if d.state == stateWriteClosed { @@ -144,22 +145,30 @@ func (d *dataChannel) Read(b []byte) (int, error) { // read in a separate goroutine to enable read deadlines go func() { - err = d.reader.ReadMsg(&msg) + read, err = d.rwc.Read(d.buf) + if err != nil { + log.Warnf("error reading from datachannel: %v", err) + signal <- struct { + error + }{err} + return + } + err = msg.Unmarshal(d.buf[:read]) if err != nil { + log.Warnf("could not unmarshal message: read: %d, err: %v", read, err) signal <- struct { error }{err} return } - if state := d.getState(); state != stateClosed && state != stateReadClosed { + + if state := d.getState(); state != stateClosed && state != stateReadClosed && msg.Message != nil { d.m.Lock() - d.readBuf.Write(msg.GetMessage()) + d.readBuf.Write(msg.Message) d.m.Unlock() } - // process control message - if msg.Flag != nil { - d.processControlMessage(msg) - } + d.processControlMessage(msg) + signal <- struct{ error }{nil} }() @@ -183,9 +192,8 @@ func (d *dataChannel) Write(b []byte) (int, error) { var err error var ( start int = 0 - end = 0 - written = 0 - chunkSize = 1024*1024 - 10 + end = len(b) + chunkSize = 65525 n = 0 ) @@ -199,11 +207,9 @@ func (d *dataChannel) Write(b []byte) (int, error) { if err != nil { break } - written += n - start = end + start += n } - return written, err - + return start, err } func (d *dataChannel) partialWrite(b []byte) (int, error) { @@ -224,7 +230,20 @@ func (d *dataChannel) partialWrite(b []byte) (int, error) { return 0, os.ErrDeadlineExceeded } } - return len(b), d.writer.WriteMsg(msg) + return d.writeMessage(msg) +} + +func (d *dataChannel) writeMessage(msg *pb.Message) (int, error) { + data, err := msg.Marshal() + if err != nil { + return 0, err + } + _, err = d.rwc.Write(data) + if err != nil { + return 0, err + } + return len(msg.GetMessage()), err + } func (d *dataChannel) Close() error { @@ -255,7 +274,7 @@ func (d *dataChannel) CloseRead() error { msg := &pb.Message{ Flag: pb.Message_STOP_SENDING.Enum(), } - err = d.writer.WriteMsg(msg) + _, err = d.writeMessage(msg) }) return err @@ -280,7 +299,7 @@ func (d *dataChannel) CloseWrite() error { msg := &pb.Message{ Flag: pb.Message_FIN.Enum(), } - err = d.writer.WriteMsg(msg) + _, err = d.writeMessage(msg) }) return err } @@ -297,7 +316,7 @@ func (d *dataChannel) Reset() error { var err error d.resetOnce.Do(func() { msg := &pb.Message{Flag: pb.Message_RESET.Enum()} - err = d.writer.WriteMsg(msg) + _, err = d.writeMessage(msg) d.Close() }) return err diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 5c5ee9c452..6e6876db6a 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -186,11 +186,13 @@ func (l *listener) accept(ctx context.Context, addr candidateAddr) (tpt.CapableC // signaling channel wraps an error in a struct to make // the error nullable. signalChan := make(chan struct{ error }) + var wrappedChannel *dataChannel + var handshakeOnce sync.Once // this enforces that the correct data channel label is used // for the handshake - handshakeChannel, err := pc.CreateDataChannel("data", &webrtc.DataChannelInit{ + handshakeChannel, err := pc.CreateDataChannel("handshake", &webrtc.DataChannelInit{ Negotiated: func(v bool) *bool { return &v }(true), - ID: func(v uint16) *uint16 { return &v }(1), + ID: func(v uint16) *uint16 { return &v }(0), }) if err != nil { defer cleanup() @@ -208,9 +210,6 @@ func (l *listener) accept(ctx context.Context, addr candidateAddr) (tpt.CapableC // Therefore, we wrap the datachannel before performing the // offer-answer exchange, so any messages sent from the remote get // buffered. - var wrappedChannel *dataChannel - - var handshakeOnce sync.Once handshakeChannel.OnOpen(func() { rwc, err := handshakeChannel.Detach() if err != nil { diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index 422b872ad5..5a8105dc16 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -210,20 +210,9 @@ func (t *WebRTCTransport) Dial( return nil, errInternal("could not instantiate peerconnection", err) } - // We need to set negotiated = true for this channel on both - // the client and server to avoid DCEP errors. - handshakeChannel, err := pc.CreateDataChannel("data", &webrtc.DataChannelInit{ - Negotiated: func(v bool) *bool { return &v }(true), - ID: func(v uint16) *uint16 { return &v }(1), - }) - - if err != nil { - defer cleanup() - return nil, errDatachannel("could not create", err) - } - signalChan := make(chan struct{ error }) var connectedOnce sync.Once + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { switch state { case webrtc.PeerConnectionStateConnected: @@ -238,6 +227,16 @@ func (t *WebRTCTransport) Dial( } }) + // We need to set negotiated = true for this channel on both + // the client and server to avoid DCEP errors. + handshakeChannel, err := pc.CreateDataChannel("handshake", &webrtc.DataChannelInit{ + Negotiated: func(v bool) *bool { return &v }(true), + ID: func(v uint16) *uint16 { return &v }(0), + }) + if err != nil { + defer cleanup() + return nil, errDatachannel("could not create", err) + } handshakeChannel.OnOpen(func() { rwc, err := handshakeChannel.Detach() if err != nil {