From 4f8c76a9752541058ee4b5bd59b2c27eadda7262 Mon Sep 17 00:00:00 2001 From: Mikio Hara Date: Tue, 3 Oct 2017 12:23:07 +0900 Subject: [PATCH] internal/socket: don't crash with empty Message.Buffers Fixes golang/go#22117. Change-Id: I0d9c3e126aaf97cd297c84e064e9a521ddac626f Reviewed-on: https://go-review.googlesource.com/67750 Reviewed-by: Ian Lance Taylor --- internal/socket/iovec_32bit.go | 6 +- internal/socket/iovec_64bit.go | 6 +- internal/socket/iovec_solaris_64bit.go | 6 +- internal/socket/msghdr_bsdvar.go | 6 +- internal/socket/msghdr_linux_32bit.go | 6 +- internal/socket/msghdr_linux_64bit.go | 6 +- internal/socket/msghdr_openbsd.go | 6 +- internal/socket/msghdr_solaris_64bit.go | 6 +- internal/socket/socket_go1_9_test.go | 125 ++++++++++++------------ 9 files changed, 103 insertions(+), 70 deletions(-) diff --git a/internal/socket/iovec_32bit.go b/internal/socket/iovec_32bit.go index d6a570c900..05d6082d14 100644 --- a/internal/socket/iovec_32bit.go +++ b/internal/socket/iovec_32bit.go @@ -10,6 +10,10 @@ package socket import "unsafe" func (v *iovec) set(b []byte) { + l := len(b) + if l == 0 { + return + } v.Base = (*byte)(unsafe.Pointer(&b[0])) - v.Len = uint32(len(b)) + v.Len = uint32(l) } diff --git a/internal/socket/iovec_64bit.go b/internal/socket/iovec_64bit.go index 2ae435e64c..afb34ad58e 100644 --- a/internal/socket/iovec_64bit.go +++ b/internal/socket/iovec_64bit.go @@ -10,6 +10,10 @@ package socket import "unsafe" func (v *iovec) set(b []byte) { + l := len(b) + if l == 0 { + return + } v.Base = (*byte)(unsafe.Pointer(&b[0])) - v.Len = uint64(len(b)) + v.Len = uint64(l) } diff --git a/internal/socket/iovec_solaris_64bit.go b/internal/socket/iovec_solaris_64bit.go index 100a62820f..8d17a40c40 100644 --- a/internal/socket/iovec_solaris_64bit.go +++ b/internal/socket/iovec_solaris_64bit.go @@ -10,6 +10,10 @@ package socket import "unsafe" func (v *iovec) set(b []byte) { + l := len(b) + if l == 0 { + return + } v.Base = (*int8)(unsafe.Pointer(&b[0])) - v.Len = uint64(len(b)) + v.Len = uint64(l) } diff --git a/internal/socket/msghdr_bsdvar.go b/internal/socket/msghdr_bsdvar.go index 3fcb042801..b8c87b72b9 100644 --- a/internal/socket/msghdr_bsdvar.go +++ b/internal/socket/msghdr_bsdvar.go @@ -7,6 +7,10 @@ package socket func (h *msghdr) setIov(vs []iovec) { + l := len(vs) + if l == 0 { + return + } h.Iov = &vs[0] - h.Iovlen = int32(len(vs)) + h.Iovlen = int32(l) } diff --git a/internal/socket/msghdr_linux_32bit.go b/internal/socket/msghdr_linux_32bit.go index 9f671aec01..a7a5987c88 100644 --- a/internal/socket/msghdr_linux_32bit.go +++ b/internal/socket/msghdr_linux_32bit.go @@ -10,8 +10,12 @@ package socket import "unsafe" func (h *msghdr) setIov(vs []iovec) { + l := len(vs) + if l == 0 { + return + } h.Iov = &vs[0] - h.Iovlen = uint32(len(vs)) + h.Iovlen = uint32(l) } func (h *msghdr) setControl(b []byte) { diff --git a/internal/socket/msghdr_linux_64bit.go b/internal/socket/msghdr_linux_64bit.go index 9f78706214..610fc4f3bb 100644 --- a/internal/socket/msghdr_linux_64bit.go +++ b/internal/socket/msghdr_linux_64bit.go @@ -10,8 +10,12 @@ package socket import "unsafe" func (h *msghdr) setIov(vs []iovec) { + l := len(vs) + if l == 0 { + return + } h.Iov = &vs[0] - h.Iovlen = uint64(len(vs)) + h.Iovlen = uint64(l) } func (h *msghdr) setControl(b []byte) { diff --git a/internal/socket/msghdr_openbsd.go b/internal/socket/msghdr_openbsd.go index be354ff847..71a69e2513 100644 --- a/internal/socket/msghdr_openbsd.go +++ b/internal/socket/msghdr_openbsd.go @@ -5,6 +5,10 @@ package socket func (h *msghdr) setIov(vs []iovec) { + l := len(vs) + if l == 0 { + return + } h.Iov = &vs[0] - h.Iovlen = uint32(len(vs)) + h.Iovlen = uint32(l) } diff --git a/internal/socket/msghdr_solaris_64bit.go b/internal/socket/msghdr_solaris_64bit.go index d1b0593973..6465b20732 100644 --- a/internal/socket/msghdr_solaris_64bit.go +++ b/internal/socket/msghdr_solaris_64bit.go @@ -13,8 +13,10 @@ func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) { for i := range vs { vs[i].set(bs[i]) } - h.Iov = &vs[0] - h.Iovlen = int32(len(vs)) + if len(vs) > 0 { + h.Iov = &vs[0] + h.Iovlen = int32(len(vs)) + } if len(oob) > 0 { h.Accrights = (*int8)(unsafe.Pointer(&oob[0])) h.Accrightslen = int32(len(oob)) diff --git a/internal/socket/socket_go1_9_test.go b/internal/socket/socket_go1_9_test.go index 109fed762e..522486a213 100644 --- a/internal/socket/socket_go1_9_test.go +++ b/internal/socket/socket_go1_9_test.go @@ -119,81 +119,84 @@ func TestUDP(t *testing.T) { t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) } defer c.Close() + cc, err := socket.NewConn(c.(net.Conn)) + if err != nil { + t.Fatal(err) + } t.Run("Message", func(t *testing.T) { - testUDPMessage(t, c.(net.Conn)) + data := []byte("HELLO-R-U-THERE") + wm := socket.Message{ + Buffers: bytes.SplitAfter(data, []byte("-")), + Addr: c.LocalAddr(), + } + if err := cc.SendMsg(&wm, 0); err != nil { + t.Fatal(err) + } + b := make([]byte, 32) + rm := socket.Message{ + Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]}, + } + if err := cc.RecvMsg(&rm, 0); err != nil { + t.Fatal(err) + } + if !bytes.Equal(b[:rm.N], data) { + t.Fatalf("got %#v; want %#v", b[:rm.N], data) + } }) switch runtime.GOOS { case "linux": t.Run("Messages", func(t *testing.T) { - testUDPMessages(t, c.(net.Conn)) + data := []byte("HELLO-R-U-THERE") + wmbs := bytes.SplitAfter(data, []byte("-")) + wms := []socket.Message{ + {Buffers: wmbs[:1], Addr: c.LocalAddr()}, + {Buffers: wmbs[1:], Addr: c.LocalAddr()}, + } + n, err := cc.SendMsgs(wms, 0) + if err != nil { + t.Fatal(err) + } + if n != len(wms) { + t.Fatalf("got %d; want %d", n, len(wms)) + } + b := make([]byte, 32) + rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}} + rms := []socket.Message{ + {Buffers: rmbs[0]}, + {Buffers: rmbs[1]}, + } + n, err = cc.RecvMsgs(rms, 0) + if err != nil { + t.Fatal(err) + } + if n != len(rms) { + t.Fatalf("got %d; want %d", n, len(rms)) + } + nn := 0 + for i := 0; i < n; i++ { + nn += rms[i].N + } + if !bytes.Equal(b[:nn], data) { + t.Fatalf("got %#v; want %#v", b[:nn], data) + } }) } -} -func testUDPMessage(t *testing.T, c net.Conn) { - cc, err := socket.NewConn(c) - if err != nil { - t.Fatal(err) - } - data := []byte("HELLO-R-U-THERE") + // The behavior of transmission for zero byte paylaod depends + // on each platform implementation. Some may transmit only + // protocol header and options, other may transmit nothing. + // We test only that SendMsg and SendMsgs will not crash with + // empty buffers. wm := socket.Message{ - Buffers: bytes.SplitAfter(data, []byte("-")), + Buffers: [][]byte{{}}, Addr: c.LocalAddr(), } - if err := cc.SendMsg(&wm, 0); err != nil { - t.Fatal(err) - } - b := make([]byte, 32) - rm := socket.Message{ - Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]}, - } - if err := cc.RecvMsg(&rm, 0); err != nil { - t.Fatal(err) - } - if !bytes.Equal(b[:rm.N], data) { - t.Fatalf("got %#v; want %#v", b[:rm.N], data) - } -} - -func testUDPMessages(t *testing.T, c net.Conn) { - cc, err := socket.NewConn(c) - if err != nil { - t.Fatal(err) - } - data := []byte("HELLO-R-U-THERE") - wmbs := bytes.SplitAfter(data, []byte("-")) + cc.SendMsg(&wm, 0) wms := []socket.Message{ - {Buffers: wmbs[:1], Addr: c.LocalAddr()}, - {Buffers: wmbs[1:], Addr: c.LocalAddr()}, - } - n, err := cc.SendMsgs(wms, 0) - if err != nil { - t.Fatal(err) - } - if n != len(wms) { - t.Fatalf("got %d; want %d", n, len(wms)) - } - b := make([]byte, 32) - rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}} - rms := []socket.Message{ - {Buffers: rmbs[0]}, - {Buffers: rmbs[1]}, - } - n, err = cc.RecvMsgs(rms, 0) - if err != nil { - t.Fatal(err) - } - if n != len(rms) { - t.Fatalf("got %d; want %d", n, len(rms)) - } - nn := 0 - for i := 0; i < n; i++ { - nn += rms[i].N - } - if !bytes.Equal(b[:nn], data) { - t.Fatalf("got %#v; want %#v", b[:nn], data) + {Buffers: [][]byte{{}}, Addr: c.LocalAddr()}, } + cc.SendMsgs(wms, 0) } func BenchmarkUDP(b *testing.B) {