diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index f1759665a6..d481b75498 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -145,7 +145,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error return nil, err } - streamId := *dc.ID() + streamID := *dc.ID() var stream *dataChannel dc.OnOpen(func() { rwc, err := dc.Detach() @@ -159,7 +159,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error return } stream = newDataChannel(dc, rwc, c.pc, nil, nil) - c.addStream(streamId, stream) + c.addStream(streamID, stream) result <- struct { network.MuxedStream error @@ -168,7 +168,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error dc.OnClose(func() { stream.remoteClosed() - c.removeStream(streamId) + c.removeStream(streamID) }) select { diff --git a/p2p/transport/webrtc/datachannel.go b/p2p/transport/webrtc/datachannel.go index 050e74aea7..5a9aa45ef7 100644 --- a/p2p/transport/webrtc/datachannel.go +++ b/p2p/transport/webrtc/datachannel.go @@ -6,14 +6,11 @@ import ( "io" "os" - // "io" "net" "sync" "time" - "sync/atomic" - "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-msgio/protoio" "github.com/pion/datachannel" @@ -24,6 +21,24 @@ import ( var _ network.MuxedStream = &dataChannel{} +const ( + // bufferedAmountLowThreshold and maxBufferedAmount are bound + // to a stream but congestion control is done on the whole + // SCTP association. This means that a single stream can monopolize + // the complete congestion control window (cwnd) if it does not + // read stream data and it's remote continues to send. + bufferedAmountLowThreshold uint64 = 1024 + // Max message size limit in Pion is 2^16 + maxBufferedAmount uint64 = 65536 +) + +const ( + stateOpen uint32 = iota + stateReadClosed + stateWriteClosed + stateClosed +) + // Package pion detached data channel into a net.Conn // and then a network.MuxedStream type dataChannel struct { @@ -38,18 +53,16 @@ type dataChannel struct { closeReadOnce sync.Once resetOnce sync.Once - remoteWriteClosed uint32 - localWriteClosed uint32 + state uint32 - remoteReadClosed uint32 - localReadClosed uint32 + ctx context.Context + cancel context.CancelFunc + m sync.Mutex + readBuf bytes.Buffer + writeAvailable chan struct{} - ctx context.Context - cancel context.CancelFunc - m sync.Mutex - readBuf bytes.Buffer - writer protoio.Writer - reader protoio.Reader + writer protoio.Writer + reader protoio.Reader } func newDataChannel( @@ -60,30 +73,48 @@ func newDataChannel( ctx, cancel := context.WithCancel(context.Background()) result := &dataChannel{ - channel: channel, - laddr: laddr, - raddr: raddr, - readDeadline: newDeadline(), - writeDeadline: newDeadline(), - ctx: ctx, - cancel: cancel, - writer: protoio.NewDelimitedWriter(rwc), - reader: protoio.NewDelimitedReader(rwc, 1500), + channel: channel, + laddr: laddr, + raddr: raddr, + readDeadline: newDeadline(), + writeDeadline: newDeadline(), + ctx: ctx, + cancel: cancel, + writer: protoio.NewDelimitedWriter(rwc), + reader: protoio.NewDelimitedReader(rwc, 1500), + writeAvailable: make(chan struct{}), } // channel.OnMessage(result.handleMessage) + channel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) + channel.OnBufferedAmountLow(func() { + result.writeAvailable <- struct{}{} + }) return result } func (d *dataChannel) processControlMessage(msg pb.Message) { + d.m.Lock() + defer d.m.Unlock() + if d.state == stateClosed { + return + } switch msg.GetFlag() { case pb.Message_FIN: - atomic.StoreUint32(&d.remoteWriteClosed, 1) + if d.state == stateWriteClosed { + d.Close() + return + } + d.state = stateReadClosed case pb.Message_STOP_SENDING: - atomic.StoreUint32(&d.remoteReadClosed, 1) + if d.state == stateReadClosed { + d.Close() + return + } + d.state = stateWriteClosed case pb.Message_RESET: - atomic.StoreUint32(&d.remoteWriteClosed, 1) + d.channel.Close() } } @@ -98,7 +129,7 @@ func (d *dataChannel) Read(b []byte) (int, error) { d.m.Lock() read, err := d.readBuf.Read(b) d.m.Unlock() - if err == io.EOF && d.isRemoteWriteClosed() { + if state := d.getState(); err == io.EOF && (state == stateReadClosed || state == stateClosed) { return read, io.EOF } if read > 0 { @@ -107,23 +138,76 @@ func (d *dataChannel) Read(b []byte) (int, error) { // read until data message var msg pb.Message - err = d.reader.ReadMsg(&msg) - if err != nil { - return 0, err - } - if !d.isRemoteWriteClosed() && !d.isLocalReadClosed() { - d.readBuf.Write(msg.GetMessage()) - } - // process control message - if msg.Flag != nil { - d.processControlMessage(msg) + signal := make(chan struct { + error + }) + + // read in a separate goroutine to enable read deadlines + go func() { + err = d.reader.ReadMsg(&msg) + if err != nil { + signal <- struct { + error + }{err} + return + } + if state := d.getState(); state != stateClosed && state != stateReadClosed { + d.m.Lock() + d.readBuf.Write(msg.GetMessage()) + d.m.Unlock() + } + // process control message + if msg.Flag != nil { + d.processControlMessage(msg) + } + signal <- struct{ error }{nil} + + }() + select { + case sig := <-signal: + if sig.error != nil { + return 0, sig.error + } + case <-d.readDeadline.wait(): + return 0, os.ErrDeadlineExceeded } } } func (d *dataChannel) Write(b []byte) (int, error) { - if d.isLocalWriteClosed() || d.isRemoteReadClosed() { + if s := d.getState(); s == stateWriteClosed || s == stateClosed { + return 0, io.ErrClosedPipe + } + + var err error + var ( + start int = 0 + end = 0 + written = 0 + chunkSize = 1024*1024 - 10 + n = 0 + ) + + for start < len(b) { + end = len(b) + if start+chunkSize < end { + end = start + chunkSize + } + chunk := b[start:end] + n, err = d.partialWrite(chunk) + if err != nil { + break + } + written += n + start = end + } + return written, err + +} + +func (d *dataChannel) partialWrite(b []byte) (int, error) { + if s := d.getState(); s == stateWriteClosed || s == stateClosed { return 0, io.ErrClosedPipe } select { @@ -131,8 +215,14 @@ func (d *dataChannel) Write(b []byte) (int, error) { return 0, os.ErrDeadlineExceeded default: } - msg := &pb.Message{ - Message: b, + msg := &pb.Message{Message: b} + // approximate overhead + if d.channel.BufferedAmount()+uint64(len(b))+10 > maxBufferedAmount { + select { + case <-d.writeAvailable: + case <-d.writeDeadline.wait(): + return 0, os.ErrDeadlineExceeded + } } return len(b), d.writer.WriteMsg(msg) } @@ -143,6 +233,11 @@ func (d *dataChannel) Close() error { return nil default: } + + d.m.Lock() + d.state = stateClosed + d.m.Unlock() + d.cancel() d.CloseWrite() _ = d.channel.Close() @@ -152,7 +247,11 @@ func (d *dataChannel) Close() error { func (d *dataChannel) CloseRead() error { var err error d.closeReadOnce.Do(func() { - atomic.StoreUint32(&d.localReadClosed, 1) + d.m.Lock() + if d.state != stateClosed { + d.state = stateReadClosed + } + d.m.Unlock() msg := &pb.Message{ Flag: pb.Message_STOP_SENDING.Enum(), } @@ -163,13 +262,21 @@ func (d *dataChannel) CloseRead() error { } func (d *dataChannel) remoteClosed() { + d.m.Lock() + defer d.m.Unlock() + d.state = stateClosed d.cancel() + } func (d *dataChannel) CloseWrite() error { var err error d.closeWriteOnce.Do(func() { - atomic.StoreUint32(&d.localWriteClosed, 1) + d.m.Lock() + if d.state != stateClosed { + d.state = stateWriteClosed + } + d.m.Unlock() msg := &pb.Message{ Flag: pb.Message_FIN.Enum(), } @@ -189,10 +296,9 @@ func (d *dataChannel) RemoteAddr() net.Addr { func (d *dataChannel) Reset() error { var err error d.resetOnce.Do(func() { - // does reset mean that no more data will be sent? - atomic.StoreUint32(&d.localWriteClosed, 1) msg := &pb.Message{Flag: pb.Message_RESET.Enum()} err = d.writer.WriteMsg(msg) + d.Close() }) return err } @@ -213,18 +319,8 @@ func (d *dataChannel) SetWriteDeadline(t time.Time) error { return nil } -func (d *dataChannel) isRemoteWriteClosed() bool { - return atomic.LoadUint32(&d.remoteWriteClosed) == 1 -} - -func (d *dataChannel) isLocalWriteClosed() bool { - return atomic.LoadUint32(&d.localWriteClosed) == 1 -} - -func (d *dataChannel) isRemoteReadClosed() bool { - return atomic.LoadUint32(&d.remoteReadClosed) == 1 -} - -func (d *dataChannel) isLocalReadClosed() bool { - return atomic.LoadUint32(&d.localReadClosed) == 1 +func (d *dataChannel) getState() uint32 { + d.m.Lock() + defer d.m.Unlock() + return d.state }