diff --git a/function_call.go b/function_call.go index 74d3c3c..11cccb3 100644 --- a/function_call.go +++ b/function_call.go @@ -6,10 +6,10 @@ import ( "github.com/jackc/pgio" ) -type FunctionCall struct{ - Function uint32 - ArgFormatCodes []uint16 - Arguments [][]byte +type FunctionCall struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte ResultFormatCode uint16 } @@ -24,9 +24,9 @@ func (dst *FunctionCall) Decode(src []byte) error { // Specifies the object ID of the function to call. dst.Function = binary.BigEndian.Uint32(src[rp:]) rp += 4 - // The number of argument format codes that follow (denoted C below). - // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); - // or one, in which case the specified format code is applied to all arguments; + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; // or it can equal the actual number of arguments. nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 @@ -37,36 +37,36 @@ func (dst *FunctionCall) Decode(src []byte) error { if ac != 0 && ac != 1 { return &invalidMessageFormatErr{messageType: "FunctionCall"} } - argumentCodes[i] = ac - rp += 2 - } + argumentCodes[i] = ac + rp += 2 + } dst.ArgFormatCodes = argumentCodes - + // Specifies the number of arguments being supplied to the function. nArguments := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 arguments := make([][]byte, nArguments) for i := 0; i < nArguments; i++ { - // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // The length of the argument value, in bytes (this count does not include itself). Can be zero. // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. - argumentLength := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - if argumentLength == -1 { + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { arguments[i] = nil } else { // The value of the argument, in the format indicated by the associated format code. n is the above length. - argumentValue := src[rp:rp+argumentLength] + argumentValue := src[rp : rp+argumentLength] rp += argumentLength arguments[i] = argumentValue - } - } + } + } dst.Arguments = arguments // The format code for the function result. Must presently be zero (text) or one (binary). resultFormatCode := binary.BigEndian.Uint16(src[rp:]) if resultFormatCode != 0 && resultFormatCode != 1 { - return &invalidMessageFormatErr{messageType: "FunctionCall"} - } - dst.ResultFormatCode = resultFormatCode + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode return nil } @@ -78,17 +78,17 @@ func (src *FunctionCall) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, src.Function) dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { - dst = pgio.AppendUint16(dst, argFormatCode) - } + dst = pgio.AppendUint16(dst, argFormatCode) + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { - if argument == nil { - dst = pgio.AppendInt32(dst, -1) - } else { - dst = pgio.AppendInt32(dst, int32(len(argument))) - dst = append(dst, argument...) - } - } + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } dst = pgio.AppendUint16(dst, src.ResultFormatCode) pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return dst diff --git a/function_call_test.go b/function_call_test.go index f158656..8c08bb2 100644 --- a/function_call_test.go +++ b/function_call_test.go @@ -1,7 +1,8 @@ package pgproto3 import ( - "github.com/go-test/deep" + "encoding/binary" + "reflect" "testing" ) @@ -17,7 +18,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { fields fields wantErr bool }{ - {"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false}, + {"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false}, {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, } @@ -30,15 +31,32 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { ResultFormatCode: tt.fields.ResultFormatCode, } encoded := src.Encode([]byte{}) - decoded := &FunctionCall{} - err := decoded.Decode(encoded[5:]) - if (err != nil) != tt.wantErr { - t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := deep.Equal(src, decoded); diff != nil { - t.Error(diff) + dst := &FunctionCall{} + // Check the header + msgTypeCode := encoded[0] + if msgTypeCode != 'F' { + t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode) + return + } + // Check length, does not include type code character + l := binary.BigEndian.Uint32(encoded[1:5]) + if int(l) != (len(encoded) - 1) { + t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) + } + // Check decoding works as expected + err := dst.Decode(encoded[5:]) + if err != nil { + if !tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + if !reflect.DeepEqual(src, dst) { + t.Error("difference after encode / decode cycle") + t.Errorf("src = %v", src) + t.Errorf("dst = %v", dst) } }) } -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index 030953b..36041a9 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgproto3/v2 go 1.12 require ( - github.com/go-test/deep v1.0.8 github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 765190c..dd9cd04 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= -github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= @@ -11,7 +9,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=