Skip to content

Commit

Permalink
state transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
ckousik committed Oct 12, 2022
1 parent 894e6a8 commit 6c3bf8a
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 60 deletions.
6 changes: 3 additions & 3 deletions p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
return nil, err
}

streamId := *dc.ID()
streamID := *dc.ID()
var stream *dataChannel
dc.OnOpen(func() {
rwc, err := dc.Detach()
Expand All @@ -159,7 +159,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
return
}
stream = newDataChannel(dc, rwc, c.pc, nil, nil)
c.addStream(streamId, stream)
c.addStream(streamID, stream)
result <- struct {
network.MuxedStream
error
Expand All @@ -168,7 +168,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error

dc.OnClose(func() {
stream.remoteClosed()
c.removeStream(streamId)
c.removeStream(streamID)
})

select {
Expand Down
210 changes: 153 additions & 57 deletions p2p/transport/webrtc/datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@ import (
"io"
"os"

// "io"
"net"

"sync"
"time"

"sync/atomic"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-msgio/protoio"
"github.com/pion/datachannel"
Expand All @@ -24,6 +21,24 @@ import (

var _ network.MuxedStream = &dataChannel{}

const (
// bufferedAmountLowThreshold and maxBufferedAmount are bound
// to a stream but congestion control is done on the whole
// SCTP association. This means that a single stream can monopolize
// the complete congestion control window (cwnd) if it does not
// read stream data and it's remote continues to send.
bufferedAmountLowThreshold uint64 = 1024
// Max message size limit in Pion is 2^16
maxBufferedAmount uint64 = 65536
)

const (
stateOpen uint32 = iota
stateReadClosed
stateWriteClosed
stateClosed
)

// Package pion detached data channel into a net.Conn
// and then a network.MuxedStream
type dataChannel struct {
Expand All @@ -38,18 +53,16 @@ type dataChannel struct {
closeReadOnce sync.Once
resetOnce sync.Once

remoteWriteClosed uint32
localWriteClosed uint32
state uint32

remoteReadClosed uint32
localReadClosed uint32
ctx context.Context
cancel context.CancelFunc
m sync.Mutex
readBuf bytes.Buffer
writeAvailable chan struct{}

ctx context.Context
cancel context.CancelFunc
m sync.Mutex
readBuf bytes.Buffer
writer protoio.Writer
reader protoio.Reader
writer protoio.Writer
reader protoio.Reader
}

func newDataChannel(
Expand All @@ -60,30 +73,48 @@ func newDataChannel(
ctx, cancel := context.WithCancel(context.Background())

result := &dataChannel{
channel: channel,
laddr: laddr,
raddr: raddr,
readDeadline: newDeadline(),
writeDeadline: newDeadline(),
ctx: ctx,
cancel: cancel,
writer: protoio.NewDelimitedWriter(rwc),
reader: protoio.NewDelimitedReader(rwc, 1500),
channel: channel,
laddr: laddr,
raddr: raddr,
readDeadline: newDeadline(),
writeDeadline: newDeadline(),
ctx: ctx,
cancel: cancel,
writer: protoio.NewDelimitedWriter(rwc),
reader: protoio.NewDelimitedReader(rwc, 1500),
writeAvailable: make(chan struct{}),
}

// channel.OnMessage(result.handleMessage)
channel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
channel.OnBufferedAmountLow(func() {
result.writeAvailable <- struct{}{}
})

return result
}

func (d *dataChannel) processControlMessage(msg pb.Message) {
d.m.Lock()
defer d.m.Unlock()
if d.state == stateClosed {
return
}
switch msg.GetFlag() {
case pb.Message_FIN:
atomic.StoreUint32(&d.remoteWriteClosed, 1)
if d.state == stateWriteClosed {
d.Close()
return
}
d.state = stateReadClosed
case pb.Message_STOP_SENDING:
atomic.StoreUint32(&d.remoteReadClosed, 1)
if d.state == stateReadClosed {
d.Close()
return
}
d.state = stateWriteClosed
case pb.Message_RESET:
atomic.StoreUint32(&d.remoteWriteClosed, 1)
d.channel.Close()
}
}

Expand All @@ -98,7 +129,7 @@ func (d *dataChannel) Read(b []byte) (int, error) {
d.m.Lock()
read, err := d.readBuf.Read(b)
d.m.Unlock()
if err == io.EOF && d.isRemoteWriteClosed() {
if state := d.getState(); err == io.EOF && (state == stateReadClosed || state == stateClosed) {
return read, io.EOF
}
if read > 0 {
Expand All @@ -107,32 +138,91 @@ func (d *dataChannel) Read(b []byte) (int, error) {

// read until data message
var msg pb.Message
err = d.reader.ReadMsg(&msg)
if err != nil {
return 0, err
}
if !d.isRemoteWriteClosed() && !d.isLocalReadClosed() {
d.readBuf.Write(msg.GetMessage())
}
// process control message
if msg.Flag != nil {
d.processControlMessage(msg)
signal := make(chan struct {
error
})

// read in a separate goroutine to enable read deadlines
go func() {
err = d.reader.ReadMsg(&msg)
if err != nil {
signal <- struct {
error
}{err}
return
}
if state := d.getState(); state != stateClosed && state != stateReadClosed {
d.m.Lock()
d.readBuf.Write(msg.GetMessage())
d.m.Unlock()
}
// process control message
if msg.Flag != nil {
d.processControlMessage(msg)
}
signal <- struct{ error }{nil}

}()
select {
case sig := <-signal:
if sig.error != nil {
return 0, sig.error
}
case <-d.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}

}
}

func (d *dataChannel) Write(b []byte) (int, error) {
if d.isLocalWriteClosed() || d.isRemoteReadClosed() {
if s := d.getState(); s == stateWriteClosed || s == stateClosed {
return 0, io.ErrClosedPipe
}

var err error
var (
start int = 0
end = 0
written = 0
chunkSize = 1024*1024 - 10
n = 0
)

for start < len(b) {
end = len(b)
if start+chunkSize < end {
end = start + chunkSize
}
chunk := b[start:end]
n, err = d.partialWrite(chunk)
if err != nil {
break
}
written += n
start = end
}
return written, err

}

func (d *dataChannel) partialWrite(b []byte) (int, error) {
if s := d.getState(); s == stateWriteClosed || s == stateClosed {
return 0, io.ErrClosedPipe
}
select {
case <-d.writeDeadline.wait():
return 0, os.ErrDeadlineExceeded
default:
}
msg := &pb.Message{
Message: b,
msg := &pb.Message{Message: b}
// approximate overhead
if d.channel.BufferedAmount()+uint64(len(b))+10 > maxBufferedAmount {
select {
case <-d.writeAvailable:
case <-d.writeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
return len(b), d.writer.WriteMsg(msg)
}
Expand All @@ -143,6 +233,11 @@ func (d *dataChannel) Close() error {
return nil
default:
}

d.m.Lock()
d.state = stateClosed
d.m.Unlock()

d.cancel()
d.CloseWrite()
_ = d.channel.Close()
Expand All @@ -152,7 +247,11 @@ func (d *dataChannel) Close() error {
func (d *dataChannel) CloseRead() error {
var err error
d.closeReadOnce.Do(func() {
atomic.StoreUint32(&d.localReadClosed, 1)
d.m.Lock()
if d.state != stateClosed {
d.state = stateReadClosed
}
d.m.Unlock()
msg := &pb.Message{
Flag: pb.Message_STOP_SENDING.Enum(),
}
Expand All @@ -163,13 +262,21 @@ func (d *dataChannel) CloseRead() error {
}

func (d *dataChannel) remoteClosed() {
d.m.Lock()
defer d.m.Unlock()
d.state = stateClosed
d.cancel()

}

func (d *dataChannel) CloseWrite() error {
var err error
d.closeWriteOnce.Do(func() {
atomic.StoreUint32(&d.localWriteClosed, 1)
d.m.Lock()
if d.state != stateClosed {
d.state = stateWriteClosed
}
d.m.Unlock()
msg := &pb.Message{
Flag: pb.Message_FIN.Enum(),
}
Expand All @@ -189,10 +296,9 @@ func (d *dataChannel) RemoteAddr() net.Addr {
func (d *dataChannel) Reset() error {
var err error
d.resetOnce.Do(func() {
// does reset mean that no more data will be sent?
atomic.StoreUint32(&d.localWriteClosed, 1)
msg := &pb.Message{Flag: pb.Message_RESET.Enum()}
err = d.writer.WriteMsg(msg)
d.Close()
})
return err
}
Expand All @@ -213,18 +319,8 @@ func (d *dataChannel) SetWriteDeadline(t time.Time) error {
return nil
}

func (d *dataChannel) isRemoteWriteClosed() bool {
return atomic.LoadUint32(&d.remoteWriteClosed) == 1
}

func (d *dataChannel) isLocalWriteClosed() bool {
return atomic.LoadUint32(&d.localWriteClosed) == 1
}

func (d *dataChannel) isRemoteReadClosed() bool {
return atomic.LoadUint32(&d.remoteReadClosed) == 1
}

func (d *dataChannel) isLocalReadClosed() bool {
return atomic.LoadUint32(&d.localReadClosed) == 1
func (d *dataChannel) getState() uint32 {
d.m.Lock()
defer d.m.Unlock()
return d.state
}

0 comments on commit 6c3bf8a

Please sign in to comment.