Skip to content

Commit

Permalink
internal/socket: don't crash with empty Message.Buffers
Browse files Browse the repository at this point in the history
Fixes golang/go#22117.

Change-Id: I0d9c3e126aaf97cd297c84e064e9a521ddac626f
Reviewed-on: https://go-review.googlesource.com/67750
Reviewed-by: Ian Lance Taylor <iant@golang.org>
  • Loading branch information
cixtor committed Oct 4, 2017
1 parent 0a93976 commit 4f8c76a
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 70 deletions.
6 changes: 5 additions & 1 deletion internal/socket/iovec_32bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ package socket
import "unsafe"

func (v *iovec) set(b []byte) {
l := len(b)
if l == 0 {
return
}
v.Base = (*byte)(unsafe.Pointer(&b[0]))
v.Len = uint32(len(b))
v.Len = uint32(l)
}
6 changes: 5 additions & 1 deletion internal/socket/iovec_64bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ package socket
import "unsafe"

func (v *iovec) set(b []byte) {
l := len(b)
if l == 0 {
return
}
v.Base = (*byte)(unsafe.Pointer(&b[0]))
v.Len = uint64(len(b))
v.Len = uint64(l)
}
6 changes: 5 additions & 1 deletion internal/socket/iovec_solaris_64bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ package socket
import "unsafe"

func (v *iovec) set(b []byte) {
l := len(b)
if l == 0 {
return
}
v.Base = (*int8)(unsafe.Pointer(&b[0]))
v.Len = uint64(len(b))
v.Len = uint64(l)
}
6 changes: 5 additions & 1 deletion internal/socket/msghdr_bsdvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
package socket

func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0]
h.Iovlen = int32(len(vs))
h.Iovlen = int32(l)
}
6 changes: 5 additions & 1 deletion internal/socket/msghdr_linux_32bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ package socket
import "unsafe"

func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0]
h.Iovlen = uint32(len(vs))
h.Iovlen = uint32(l)
}

func (h *msghdr) setControl(b []byte) {
Expand Down
6 changes: 5 additions & 1 deletion internal/socket/msghdr_linux_64bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ package socket
import "unsafe"

func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0]
h.Iovlen = uint64(len(vs))
h.Iovlen = uint64(l)
}

func (h *msghdr) setControl(b []byte) {
Expand Down
6 changes: 5 additions & 1 deletion internal/socket/msghdr_openbsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
package socket

func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0]
h.Iovlen = uint32(len(vs))
h.Iovlen = uint32(l)
}
6 changes: 4 additions & 2 deletions internal/socket/msghdr_solaris_64bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
for i := range vs {
vs[i].set(bs[i])
}
h.Iov = &vs[0]
h.Iovlen = int32(len(vs))
if len(vs) > 0 {
h.Iov = &vs[0]
h.Iovlen = int32(len(vs))
}
if len(oob) > 0 {
h.Accrights = (*int8)(unsafe.Pointer(&oob[0]))
h.Accrightslen = int32(len(oob))
Expand Down
125 changes: 64 additions & 61 deletions internal/socket/socket_go1_9_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,81 +119,84 @@ func TestUDP(t *testing.T) {
t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
}
defer c.Close()
cc, err := socket.NewConn(c.(net.Conn))
if err != nil {
t.Fatal(err)
}

t.Run("Message", func(t *testing.T) {
testUDPMessage(t, c.(net.Conn))
data := []byte("HELLO-R-U-THERE")
wm := socket.Message{
Buffers: bytes.SplitAfter(data, []byte("-")),
Addr: c.LocalAddr(),
}
if err := cc.SendMsg(&wm, 0); err != nil {
t.Fatal(err)
}
b := make([]byte, 32)
rm := socket.Message{
Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
}
if err := cc.RecvMsg(&rm, 0); err != nil {
t.Fatal(err)
}
if !bytes.Equal(b[:rm.N], data) {
t.Fatalf("got %#v; want %#v", b[:rm.N], data)
}
})
switch runtime.GOOS {
case "linux":
t.Run("Messages", func(t *testing.T) {
testUDPMessages(t, c.(net.Conn))
data := []byte("HELLO-R-U-THERE")
wmbs := bytes.SplitAfter(data, []byte("-"))
wms := []socket.Message{
{Buffers: wmbs[:1], Addr: c.LocalAddr()},
{Buffers: wmbs[1:], Addr: c.LocalAddr()},
}
n, err := cc.SendMsgs(wms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(wms) {
t.Fatalf("got %d; want %d", n, len(wms))
}
b := make([]byte, 32)
rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
rms := []socket.Message{
{Buffers: rmbs[0]},
{Buffers: rmbs[1]},
}
n, err = cc.RecvMsgs(rms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(rms) {
t.Fatalf("got %d; want %d", n, len(rms))
}
nn := 0
for i := 0; i < n; i++ {
nn += rms[i].N
}
if !bytes.Equal(b[:nn], data) {
t.Fatalf("got %#v; want %#v", b[:nn], data)
}
})
}
}

func testUDPMessage(t *testing.T, c net.Conn) {
cc, err := socket.NewConn(c)
if err != nil {
t.Fatal(err)
}
data := []byte("HELLO-R-U-THERE")
// The behavior of transmission for zero byte paylaod depends
// on each platform implementation. Some may transmit only
// protocol header and options, other may transmit nothing.
// We test only that SendMsg and SendMsgs will not crash with
// empty buffers.
wm := socket.Message{
Buffers: bytes.SplitAfter(data, []byte("-")),
Buffers: [][]byte{{}},
Addr: c.LocalAddr(),
}
if err := cc.SendMsg(&wm, 0); err != nil {
t.Fatal(err)
}
b := make([]byte, 32)
rm := socket.Message{
Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
}
if err := cc.RecvMsg(&rm, 0); err != nil {
t.Fatal(err)
}
if !bytes.Equal(b[:rm.N], data) {
t.Fatalf("got %#v; want %#v", b[:rm.N], data)
}
}

func testUDPMessages(t *testing.T, c net.Conn) {
cc, err := socket.NewConn(c)
if err != nil {
t.Fatal(err)
}
data := []byte("HELLO-R-U-THERE")
wmbs := bytes.SplitAfter(data, []byte("-"))
cc.SendMsg(&wm, 0)
wms := []socket.Message{
{Buffers: wmbs[:1], Addr: c.LocalAddr()},
{Buffers: wmbs[1:], Addr: c.LocalAddr()},
}
n, err := cc.SendMsgs(wms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(wms) {
t.Fatalf("got %d; want %d", n, len(wms))
}
b := make([]byte, 32)
rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
rms := []socket.Message{
{Buffers: rmbs[0]},
{Buffers: rmbs[1]},
}
n, err = cc.RecvMsgs(rms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(rms) {
t.Fatalf("got %d; want %d", n, len(rms))
}
nn := 0
for i := 0; i < n; i++ {
nn += rms[i].N
}
if !bytes.Equal(b[:nn], data) {
t.Fatalf("got %#v; want %#v", b[:nn], data)
{Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
}
cc.SendMsgs(wms, 0)
}

func BenchmarkUDP(b *testing.B) {
Expand Down

0 comments on commit 4f8c76a

Please sign in to comment.