Skip to content

Commit

Permalink
Fix unit test, it should return after any error is returned from Decode
Browse files Browse the repository at this point in the history
function whether expected or not, rather than continue and try to
compare invalid decoded results.

Extend the unit test slightly to check the header.

Remove go-test/deep dependency in favour of standard library reflect
package.
  • Loading branch information
MFAshby authored and jackc committed Nov 6, 2021
1 parent 9275da5 commit 3d9a54f
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 45 deletions.
60 changes: 30 additions & 30 deletions function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"github.com/jackc/pgio"
)

type FunctionCall struct{
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
type FunctionCall struct {
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
ResultFormatCode uint16
}

Expand All @@ -24,9 +24,9 @@ func (dst *FunctionCall) Decode(src []byte) error {
// Specifies the object ID of the function to call.
dst.Function = binary.BigEndian.Uint32(src[rp:])
rp += 4
// The number of argument format codes that follow (denoted C below).
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
// or one, in which case the specified format code is applied to all arguments;
// The number of argument format codes that follow (denoted C below).
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
// or one, in which case the specified format code is applied to all arguments;
// or it can equal the actual number of arguments.
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
Expand All @@ -37,36 +37,36 @@ func (dst *FunctionCall) Decode(src []byte) error {
if ac != 0 && ac != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
argumentCodes[i] = ac
rp += 2
}
argumentCodes[i] = ac
rp += 2
}
dst.ArgFormatCodes = argumentCodes

// Specifies the number of arguments being supplied to the function.
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
arguments := make([][]byte, nArguments)
for i := 0; i < nArguments; i++ {
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if argumentLength == -1 {
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if argumentLength == -1 {
arguments[i] = nil
} else {
// The value of the argument, in the format indicated by the associated format code. n is the above length.
argumentValue := src[rp:rp+argumentLength]
argumentValue := src[rp : rp+argumentLength]
rp += argumentLength
arguments[i] = argumentValue
}
}
}
}
dst.Arguments = arguments
// The format code for the function result. Must presently be zero (text) or one (binary).
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
if resultFormatCode != 0 && resultFormatCode != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
dst.ResultFormatCode = resultFormatCode
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
dst.ResultFormatCode = resultFormatCode
return nil
}

Expand All @@ -78,17 +78,17 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
dst = pgio.AppendUint32(dst, src.Function)
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
dst = pgio.AppendUint16(dst, argFormatCode)
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(argument)))
dst = append(dst, argument...)
}
}
if argument == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(argument)))
dst = append(dst, argument...)
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
Expand Down
40 changes: 29 additions & 11 deletions function_call_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package pgproto3

import (
"github.com/go-test/deep"
"encoding/binary"
"reflect"
"testing"
)

Expand All @@ -17,7 +18,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
fields fields
wantErr bool
}{
{"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false},
{"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false},
{"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true},
{"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true},
}
Expand All @@ -30,15 +31,32 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
ResultFormatCode: tt.fields.ResultFormatCode,
}
encoded := src.Encode([]byte{})
decoded := &FunctionCall{}
err := decoded.Decode(encoded[5:])
if (err != nil) != tt.wantErr {
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := deep.Equal(src, decoded); diff != nil {
t.Error(diff)
dst := &FunctionCall{}
// Check the header
msgTypeCode := encoded[0]
if msgTypeCode != 'F' {
t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode)
return
}
// Check length, does not include type code character
l := binary.BigEndian.Uint32(encoded[1:5])
if int(l) != (len(encoded) - 1) {
t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
}
// Check decoding works as expected
err := dst.Decode(encoded[5:])
if err != nil {
if !tt.wantErr {
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
}
return
}

if !reflect.DeepEqual(src, dst) {
t.Error("difference after encode / decode cycle")
t.Errorf("src = %v", src)
t.Errorf("dst = %v", dst)
}
})
}
}
}
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module github.com/jackc/pgproto3/v2
go 1.12

require (
github.com/go-test/deep v1.0.8
github.com/jackc/chunkreader/v2 v2.0.0
github.com/jackc/pgio v1.0.0
github.com/stretchr/testify v1.4.0
Expand Down
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
Expand All @@ -11,7 +9,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

0 comments on commit 3d9a54f

Please sign in to comment.