Skip to content

Commit

Permalink
fix keepalive for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Mar 21, 2023
1 parent 661953e commit fb7b827
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
11 changes: 7 additions & 4 deletions dialer/mws/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,11 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)

cc := ws_util.Conn(c)

if d.md.keepAlive > 0 {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
if d.md.keepaliveInterval > 0 {
d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
return nil
})
go d.keepAlive(cc)
Expand Down Expand Up @@ -203,13 +204,15 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
}

func (d *mwsDialer) keepAlive(conn ws_util.WebsocketConn) {
ticker := time.NewTicker(d.md.keepAlive)
ticker := time.NewTicker(d.md.keepaliveInterval)
defer ticker.Stop()

for range ticker.C {
d.options.Logger.Debug("send ping")
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
conn.SetWriteDeadline(time.Time{})
}
}
18 changes: 12 additions & 6 deletions dialer/mws/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
)

const (
defaultPath = "/ws"
defaultPath = "/ws"
defaultKeepalivePeriod = 15 * time.Second
)

type metadata struct {
Expand All @@ -29,8 +30,8 @@ type metadata struct {
muxMaxReceiveBuffer int
muxMaxStreamBuffer int

header http.Header
keepAlive time.Duration
header http.Header
keepaliveInterval time.Duration
}

func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
Expand All @@ -44,8 +45,7 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"

header = "header"
keepAlive = "keepAlive"
header = "header"

muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAliveInterval = "muxKeepAliveInterval"
Expand Down Expand Up @@ -82,7 +82,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
}
d.md.header = h
}
d.md.keepAlive = mdutil.GetDuration(md, keepAlive)

if mdutil.GetBool(md, "keepalive") {
d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval")
if d.md.keepaliveInterval <= 0 {
d.md.keepaliveInterval = defaultKeepalivePeriod
}
}

return
}
5 changes: 4 additions & 1 deletion dialer/ws/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
cc := ws_util.Conn(c)

if d.md.keepaliveInterval > 0 {
d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
Expand All @@ -123,10 +124,12 @@ func (d *wsDialer) keepalive(conn ws_util.WebsocketConn) {
defer ticker.Stop()

for range ticker.C {
d.options.Logger.Debug("send ping")
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
// d.options.Logger.Error(err)
return
}
d.options.Logger.Debug("send ping")
conn.SetWriteDeadline(time.Time{})
}
}
4 changes: 2 additions & 2 deletions dialer/ws/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

const (
defaultPath = "/ws"
defaultKeepAlivePeriod = 15 * time.Second
defaultKeepalivePeriod = 15 * time.Second
)

type metadata struct {
Expand Down Expand Up @@ -65,7 +65,7 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
if mdutil.GetBool(md, "keepalive") {
d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval")
if d.md.keepaliveInterval <= 0 {
d.md.keepaliveInterval = defaultKeepAlivePeriod
d.md.keepaliveInterval = defaultKeepalivePeriod
}
}

Expand Down

0 comments on commit fb7b827

Please sign in to comment.