From 3d9a54f092879f0356034cea00157ef08ebbac71 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Sat, 6 Nov 2021 16:17:26 +0000 Subject: [PATCH] Fix unit test, it should return after any error is returned from Decode function whether expected or not, rather than continue and try to compare invalid decoded results. Extend the unit test slightly to check the header. Remove go-test/deep dependency in favour of standard library reflect package. --- function_call.go | 60 +++++++++++++++++++++---------------------- function_call_test.go | 40 +++++++++++++++++++++-------- go.mod | 1 - go.sum | 3 --- 4 files changed, 59 insertions(+), 45 deletions(-) diff --git a/function_call.go b/function_call.go index 74d3c3c..11cccb3 100644 --- a/function_call.go +++ b/function_call.go @@ -6,10 +6,10 @@ import ( "github.com/jackc/pgio" ) -type FunctionCall struct{ - Function uint32 - ArgFormatCodes []uint16 - Arguments [][]byte +type FunctionCall struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte ResultFormatCode uint16 } @@ -24,9 +24,9 @@ func (dst *FunctionCall) Decode(src []byte) error { // Specifies the object ID of the function to call. dst.Function = binary.BigEndian.Uint32(src[rp:]) rp += 4 - // The number of argument format codes that follow (denoted C below). - // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); - // or one, in which case the specified format code is applied to all arguments; + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; // or it can equal the actual number of arguments. nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 @@ -37,36 +37,36 @@ func (dst *FunctionCall) Decode(src []byte) error { if ac != 0 && ac != 1 { return &invalidMessageFormatErr{messageType: "FunctionCall"} } - argumentCodes[i] = ac - rp += 2 - } + argumentCodes[i] = ac + rp += 2 + } dst.ArgFormatCodes = argumentCodes - + // Specifies the number of arguments being supplied to the function. nArguments := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 arguments := make([][]byte, nArguments) for i := 0; i < nArguments; i++ { - // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // The length of the argument value, in bytes (this count does not include itself). Can be zero. // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. - argumentLength := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - if argumentLength == -1 { + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { arguments[i] = nil } else { // The value of the argument, in the format indicated by the associated format code. n is the above length. - argumentValue := src[rp:rp+argumentLength] + argumentValue := src[rp : rp+argumentLength] rp += argumentLength arguments[i] = argumentValue - } - } + } + } dst.Arguments = arguments // The format code for the function result. Must presently be zero (text) or one (binary). resultFormatCode := binary.BigEndian.Uint16(src[rp:]) if resultFormatCode != 0 && resultFormatCode != 1 { - return &invalidMessageFormatErr{messageType: "FunctionCall"} - } - dst.ResultFormatCode = resultFormatCode + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode return nil } @@ -78,17 +78,17 @@ func (src *FunctionCall) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, src.Function) dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { - dst = pgio.AppendUint16(dst, argFormatCode) - } + dst = pgio.AppendUint16(dst, argFormatCode) + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { - if argument == nil { - dst = pgio.AppendInt32(dst, -1) - } else { - dst = pgio.AppendInt32(dst, int32(len(argument))) - dst = append(dst, argument...) - } - } + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } dst = pgio.AppendUint16(dst, src.ResultFormatCode) pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return dst diff --git a/function_call_test.go b/function_call_test.go index f158656..8c08bb2 100644 --- a/function_call_test.go +++ b/function_call_test.go @@ -1,7 +1,8 @@ package pgproto3 import ( - "github.com/go-test/deep" + "encoding/binary" + "reflect" "testing" ) @@ -17,7 +18,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { fields fields wantErr bool }{ - {"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false}, + {"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false}, {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, } @@ -30,15 +31,32 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { ResultFormatCode: tt.fields.ResultFormatCode, } encoded := src.Encode([]byte{}) - decoded := &FunctionCall{} - err := decoded.Decode(encoded[5:]) - if (err != nil) != tt.wantErr { - t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := deep.Equal(src, decoded); diff != nil { - t.Error(diff) + dst := &FunctionCall{} + // Check the header + msgTypeCode := encoded[0] + if msgTypeCode != 'F' { + t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode) + return + } + // Check length, does not include type code character + l := binary.BigEndian.Uint32(encoded[1:5]) + if int(l) != (len(encoded) - 1) { + t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) + } + // Check decoding works as expected + err := dst.Decode(encoded[5:]) + if err != nil { + if !tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + if !reflect.DeepEqual(src, dst) { + t.Error("difference after encode / decode cycle") + t.Errorf("src = %v", src) + t.Errorf("dst = %v", dst) } }) } -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index 030953b..36041a9 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgproto3/v2 go 1.12 require ( - github.com/go-test/deep v1.0.8 github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 765190c..dd9cd04 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= -github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= @@ -11,7 +9,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=