Skip to content

Commit

Permalink
feat(pkg/scale): add Encoder with Encode method (#2741)
Browse files Browse the repository at this point in the history
- Change `encodeState` to use `io.Writer` instead of `bytes.Buffer`
- Define `Encoder` with `Encode(value interface{}) error` method
- Define constructor `NewEncoder(writer io.Writer) *Encoder`
- Add unit tests for encoder
  • Loading branch information
qdm12 authored Aug 16, 2022
1 parent 363c080 commit af5c63f
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 46 deletions.
37 changes: 30 additions & 7 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,48 @@ import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math/big"
"reflect"
)

// Encoder scale encodes to a given io.Writer.
type Encoder struct {
encodeState
}

// NewEncoder creates a new encoder with the given writer.
func NewEncoder(writer io.Writer) (encoder *Encoder) {
return &Encoder{
encodeState: encodeState{
Writer: writer,
fieldScaleIndicesCache: cache,
},
}
}

// Encode scale encodes value to the encoder writer.
func (e *Encoder) Encode(value interface{}) (err error) {
return e.marshal(value)
}

// Marshal takes in an interface{} and attempts to marshal into []byte
func Marshal(v interface{}) (b []byte, err error) {
buffer := bytes.NewBuffer(nil)
es := encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
err = es.marshal(v)
if err != nil {
return
}
b = es.Bytes()
b = buffer.Bytes()
return
}

type encodeState struct {
bytes.Buffer
io.Writer
*fieldScaleIndicesCache
}

Expand Down Expand Up @@ -64,9 +87,9 @@ func (es *encodeState) marshal(in interface{}) (err error) {
elem := reflect.ValueOf(in).Elem()
switch elem.IsValid() {
case false:
err = es.WriteByte(0)
_, err = es.Write([]byte{0})
default:
err = es.WriteByte(1)
_, err = es.Write([]byte{1})
if err != nil {
return
}
Expand Down Expand Up @@ -133,13 +156,13 @@ func (es *encodeState) encodeResult(res Result) (err error) {
var in interface{}
switch res.mode {
case OK:
err = es.WriteByte(0)
_, err = es.Write([]byte{0})
if err != nil {
return
}
in = res.ok
case Err:
err = es.WriteByte(1)
_, err = es.Write([]byte{1})
if err != nil {
return
}
Expand All @@ -159,7 +182,7 @@ func (es *encodeState) encodeCustomVaryingDataType(in interface{}) (err error) {
}

func (es *encodeState) encodeVaryingDataType(vdt VaryingDataType) (err error) {
err = es.WriteByte(byte(vdt.value.Index()))
_, err = es.Write([]byte{byte(vdt.value.Index())})
if err != nil {
return
}
Expand Down
168 changes: 135 additions & 33 deletions pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,74 @@
package scale

import (
"bytes"
"math/big"
"reflect"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_NewEncoder(t *testing.T) {
t.Parallel()

cache.Lock()
defer cache.Unlock()

writer := bytes.NewBuffer(nil)
encoder := NewEncoder(writer)

expectedEncoder := &Encoder{
encodeState: encodeState{
Writer: writer,
fieldScaleIndicesCache: cache,
},
}

assert.Equal(t, expectedEncoder, encoder)
}

func Test_Encoder_Encode(t *testing.T) {
t.Parallel()

buffer := bytes.NewBuffer(nil)
encoder := NewEncoder(buffer)

err := encoder.Encode(uint16(1))
require.NoError(t, err)

err = encoder.Encode(uint8(2))
require.NoError(t, err)

array := [2]byte{4, 5}
err = encoder.Encode(array)
require.NoError(t, err)

type T struct {
Array [2]byte
}

someStruct := T{Array: [2]byte{6, 7}}
err = encoder.Encode(someStruct)
require.NoError(t, err)

structSlice := []T{{Array: [2]byte{8, 9}}}
err = encoder.Encode(structSlice)
require.NoError(t, err)

written := buffer.Bytes()
expectedWritten := []byte{
1, 0,
2,
4, 5,
6, 7,
4, 8, 9,
}
assert.Equal(t, expectedWritten, written)
}

type test struct {
name string
in interface{}
Expand Down Expand Up @@ -869,12 +931,15 @@ type MyStructWithPrivate struct {
func Test_encodeState_encodeFixedWidthInteger(t *testing.T) {
for _, tt := range fixedWidthIntegerTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -883,12 +948,15 @@ func Test_encodeState_encodeFixedWidthInteger(t *testing.T) {
func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) {
for _, tt := range variableWidthIntegerTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -897,12 +965,15 @@ func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) {
func Test_encodeState_encodeBigInt(t *testing.T) {
for _, tt := range bigIntTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeBigInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBigInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBigInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -911,12 +982,15 @@ func Test_encodeState_encodeBigInt(t *testing.T) {
func Test_encodeState_encodeUint128(t *testing.T) {
for _, tt := range uint128Tests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeUin128() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeUin128() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeUin128() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -925,12 +999,16 @@ func Test_encodeState_encodeUint128(t *testing.T) {
func Test_encodeState_encodeBytes(t *testing.T) {
for _, tt := range stringTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}

buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeBytes() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBytes() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBytes() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -939,12 +1017,16 @@ func Test_encodeState_encodeBytes(t *testing.T) {
func Test_encodeState_encodeBool(t *testing.T) {
for _, tt := range boolTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}

buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeBool() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBool() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBool() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -953,12 +1035,16 @@ func Test_encodeState_encodeBool(t *testing.T) {
func Test_encodeState_encodeStruct(t *testing.T) {
for _, tt := range structTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeStruct() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -967,12 +1053,16 @@ func Test_encodeState_encodeStruct(t *testing.T) {
func Test_encodeState_encodeSlice(t *testing.T) {
for _, tt := range sliceTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeSlice() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeSlice() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeSlice() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -981,12 +1071,16 @@ func Test_encodeState_encodeSlice(t *testing.T) {
func Test_encodeState_encodeArray(t *testing.T) {
for _, tt := range arrayTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeArray() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeArray() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeArray() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -1007,12 +1101,16 @@ func Test_marshal_optionality(t *testing.T) {
}
for _, tt := range ptrTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand Down Expand Up @@ -1043,12 +1141,16 @@ func Test_marshal_optionality_nil_cases(t *testing.T) {
}
for _, tt := range ptrTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand Down
Loading

0 comments on commit af5c63f

Please sign in to comment.