diff --git a/buf.go b/buf.go index fd966c39..e7ff5777 100644 --- a/buf.go +++ b/buf.go @@ -47,28 +47,44 @@ func (b *readBuf) byte() byte { return b.next(1)[0] } -type writeBuf []byte +type writeBuf struct { + buf []byte + pos int +} func (b *writeBuf) int32(n int) { x := make([]byte, 4) binary.BigEndian.PutUint32(x, uint32(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) int16(n int) { x := make([]byte, 2) binary.BigEndian.PutUint16(x, uint16(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) string(s string) { - *b = append(*b, (s + "\000")...) + b.buf = append(b.buf, (s + "\000")...) } func (b *writeBuf) byte(c byte) { - *b = append(*b, c) + b.buf = append(b.buf, c) } func (b *writeBuf) bytes(v []byte) { - *b = append(*b, v...) + b.buf = append(b.buf, v...) +} + +func (b *writeBuf) wrap() []byte { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + return b.buf +} + +func (b *writeBuf) next(c byte) { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + b.pos = len(b.buf) + 1 + b.buf = append(b.buf, c, 0, 0, 0, 0) } diff --git a/conn.go b/conn.go index e92798ed..d188c9b7 100644 --- a/conn.go +++ b/conn.go @@ -110,8 +110,10 @@ type conn struct { func (c *conn) writeBuf(b byte) *writeBuf { c.scratch[0] = b - w := writeBuf(c.scratch[:5]) - return &w + return &writeBuf{ + buf: c.scratch[:5], + pos: 1, + } } func Open(name string) (_ driver.Conn, err error) { @@ -553,14 +555,13 @@ func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { b.string(st.name) b.string(q) b.int16(0) - cn.send(b) - b = cn.writeBuf('D') + b.next('D') b.byte('S') b.string(st.name) - cn.send(b) - cn.send(cn.writeBuf('S')) + b.next('S') + cn.send(b) for { t, r := cn.recv1() @@ -670,16 +671,20 @@ func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err er return r, err } -// Assumes len(*m) is > 5 func (cn *conn) send(m *writeBuf) { - b := (*m)[1:] - binary.BigEndian.PutUint32(b, uint32(len(b))) + _, err := cn.c.Write(m.wrap()) + if err != nil { + panic(err) + } +} - if (*m)[0] == 0 { - *m = b +func (cn *conn) sendStartupPacket(m *writeBuf) { + // sanity check + if m.buf[0] != 0 { + panic("oops") } - _, err := cn.c.Write(*m) + _, err := cn.c.Write((m.wrap())[1:]) if err != nil { panic(err) } @@ -819,7 +824,7 @@ func (cn *conn) ssl(o values) { w := cn.writeBuf(0) w.int32(80877103) - cn.send(w) + cn.sendStartupPacket(w) b := cn.scratch[:1] _, err := io.ReadFull(cn.c, b) @@ -983,7 +988,7 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.send(w) + cn.sendStartupPacket(w) for { t, r := cn.recv() @@ -1127,7 +1132,7 @@ func (st *stmt) exec(v []driver.Value) { } w := st.cn.writeBuf('B') - w.string("") + w.byte(0) w.string(st.name) w.int16(0) w.int16(len(v)) @@ -1141,14 +1146,13 @@ func (st *stmt) exec(v []driver.Value) { } } w.int16(0) - st.cn.send(w) - w = st.cn.writeBuf('E') - w.string("") + w.next('E') + w.byte(0) w.int32(0) - st.cn.send(w) - st.cn.send(st.cn.writeBuf('S')) + w.next('S') + st.cn.send(w) var err error for { diff --git a/notify.go b/notify.go index c756af8b..8cad5781 100644 --- a/notify.go +++ b/notify.go @@ -253,8 +253,10 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { // Can't use l.cn.writeBuf here because it uses the scratch buffer which // might get overwritten by listenerConnLoop. - data := writeBuf([]byte("Q\x00\x00\x00\x00")) - b := &data + b := &writeBuf{ + buf: []byte("Q\x00\x00\x00\x00"), + pos: 1, + } b.string(q) l.cn.send(b)