Skip to content

Commit

Permalink
Added FunctionCall support
Browse files Browse the repository at this point in the history
Added support for FunctionCall message as per
https://www.postgresql.org/docs/11/protocol-message-formats.html

Adds unit test for Encode / Decode cycle and invalid message format
errors.

Fixes #23
  • Loading branch information
MFAshby authored and jackc committed Nov 6, 2021
1 parent 5c447ff commit 9275da5
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 4 deletions.
12 changes: 8 additions & 4 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Backend struct {
describe Describe
execute Execute
flush Flush
functionCall FunctionCall
gssEncRequest GSSEncRequest
parse Parse
query Query
Expand All @@ -29,10 +30,11 @@ type Backend struct {
sync Sync
terminate Terminate

bodyLen int
msgType byte
partialMsg bool
authType uint32
bodyLen int
msgType byte
partialMsg bool
authType uint32

}

const (
Expand Down Expand Up @@ -125,6 +127,8 @@ func (b *Backend) Receive() (FrontendMessage, error) {
msg = &b.describe
case 'E':
msg = &b.execute
case 'F':
msg = &b.functionCall
case 'f':
msg = &b.copyFail
case 'd':
Expand Down
104 changes: 104 additions & 0 deletions function_call.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package pgproto3

import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)

type FunctionCall struct{
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
ResultFormatCode uint16
}

// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*FunctionCall) Frontend() {}

// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCall) Decode(src []byte) error {
*dst = FunctionCall{}
rp := 0
// Specifies the object ID of the function to call.
dst.Function = binary.BigEndian.Uint32(src[rp:])
rp += 4
// The number of argument format codes that follow (denoted C below).
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
// or one, in which case the specified format code is applied to all arguments;
// or it can equal the actual number of arguments.
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
argumentCodes := make([]uint16, nArgumentCodes)
for i := 0; i < nArgumentCodes; i++ {
// The argument format codes. Each must presently be zero (text) or one (binary).
ac := binary.BigEndian.Uint16(src[rp:])
if ac != 0 && ac != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
argumentCodes[i] = ac
rp += 2
}
dst.ArgFormatCodes = argumentCodes

// Specifies the number of arguments being supplied to the function.
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
arguments := make([][]byte, nArguments)
for i := 0; i < nArguments; i++ {
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if argumentLength == -1 {
arguments[i] = nil
} else {
// The value of the argument, in the format indicated by the associated format code. n is the above length.
argumentValue := src[rp:rp+argumentLength]
rp += argumentLength
arguments[i] = argumentValue
}
}
dst.Arguments = arguments
// The format code for the function result. Must presently be zero (text) or one (binary).
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
if resultFormatCode != 0 && resultFormatCode != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
dst.ResultFormatCode = resultFormatCode
return nil
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) []byte {
dst = append(dst, 'F')
sp := len(dst)
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
dst = pgio.AppendUint32(dst, src.Function)
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(argument)))
dst = append(dst, argument...)
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}

// MarshalJSON implements encoding/json.Marshaler.
func (src FunctionCall) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "FunctionCall",
})
}
44 changes: 44 additions & 0 deletions function_call_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package pgproto3

import (
"github.com/go-test/deep"
"testing"
)

func TestFunctionCall_EncodeDecode(t *testing.T) {
type fields struct {
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
ResultFormatCode uint16
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false},
{"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true},
{"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
src := &FunctionCall{
Function: tt.fields.Function,
ArgFormatCodes: tt.fields.ArgFormatCodes,
Arguments: tt.fields.Arguments,
ResultFormatCode: tt.fields.ResultFormatCode,
}
encoded := src.Encode([]byte{})
decoded := &FunctionCall{}
err := decoded.Decode(encoded[5:])
if (err != nil) != tt.wantErr {
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := deep.Equal(src, decoded); diff != nil {
t.Error(diff)
}
})
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/jackc/pgproto3/v2
go 1.12

require (
github.com/go-test/deep v1.0.8
github.com/jackc/chunkreader/v2 v2.0.0
github.com/jackc/pgio v1.0.0
github.com/stretchr/testify v1.4.0
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
Expand All @@ -9,6 +11,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

0 comments on commit 9275da5

Please sign in to comment.