diff --git a/multiplex.go b/multiplex.go index fc25158..6fe3c61 100644 --- a/multiplex.go +++ b/multiplex.go @@ -39,7 +39,7 @@ var errTimeout = timeout{} var ( ResetStreamTimeout = 2 * time.Minute - + MaxIncomingStreams = 1000 WriteCoalesceDelay = 100 * time.Microsecond ) @@ -83,22 +83,26 @@ type Multiplex struct { nstreams chan *Stream - channels map[streamID]*Stream - chLock sync.Mutex + maxIncoming int + + channels map[streamID]*Stream + numIncoming int + chLock sync.Mutex } // NewMultiplex creates a new multiplexer session. func NewMultiplex(con net.Conn, initiator bool) *Multiplex { mp := &Multiplex{ - con: con, - initiator: initiator, - buf: bufio.NewReader(con), - channels: make(map[streamID]*Stream), - closed: make(chan struct{}), - shutdown: make(chan struct{}), - writeCh: make(chan []byte, 16), - writeTimer: time.NewTimer(0), - nstreams: make(chan *Stream, 16), + con: con, + initiator: initiator, + buf: bufio.NewReader(con), + channels: make(map[streamID]*Stream), + closed: make(chan struct{}), + shutdown: make(chan struct{}), + writeCh: make(chan []byte, 16), + writeTimer: time.NewTimer(0), + maxIncoming: MaxIncomingStreams, + nstreams: make(chan *Stream, 16), } go mp.handleIncoming() @@ -410,6 +414,15 @@ func (mp *Multiplex) handleIncoming() { msch = mp.newStream(ch, name) mp.chLock.Lock() + if remoteIsInitiator { + if mp.numIncoming >= mp.maxIncoming { + msch.mp.sendResetMsg(msch.id.header(resetTag), true) + mp.chLock.Unlock() + continue + } else { + mp.numIncoming++ + } + } mp.channels[ch] = msch mp.chLock.Unlock() select { @@ -436,6 +449,9 @@ func (mp *Multiplex) handleIncoming() { // unregister and throw away future data. mp.chLock.Lock() delete(mp.channels, ch) + if remoteIsInitiator { + mp.numIncoming-- + } mp.chLock.Unlock() // close data channel, there will be no more data. diff --git a/multiplex_test.go b/multiplex_test.go index 1f27f24..aa79541 100644 --- a/multiplex_test.go +++ b/multiplex_test.go @@ -11,6 +11,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/require" ) func init() { @@ -880,3 +882,48 @@ func arrComp(a, b []byte) error { } return nil } + +func TestMaxIncomingStreams(t *testing.T) { + a, b := net.Pipe() + client := NewMultiplex(a, true) + defer client.Close() + + server := NewMultiplex(b, false) + defer server.Close() + + go func() { + for { + str, err := server.Accept() + if err != nil { + return + } + _, err = str.Write([]byte("foobar")) + require.NoError(t, err) + } + }() + + var streams []*Stream + for i := 0; i < MaxIncomingStreams; i++ { + str, err := client.NewStream(context.Background()) + require.NoError(t, err) + _, err = str.Read(make([]byte, 6)) + require.NoError(t, err) + streams = append(streams, str) + } + // The server now has maxIncomingStreams incoming streams. + // It will now reset the next stream that is opened. + str, err := client.NewStream(context.Background()) + require.NoError(t, err) + str.SetDeadline(time.Now().Add(time.Second)) + _, err = str.Read([]byte{0}) + require.EqualError(t, err, "stream reset") + + // Now close one of the streams. + // This should then allow the client to open a new stream. + streams[0].Close() + str, err = client.NewStream(context.Background()) + require.NoError(t, err) + str.SetDeadline(time.Now().Add(time.Second)) + _, err = str.Read([]byte{0}) + require.NoError(t, err) +}