diff --git a/pkg/scale/bitvec.go b/pkg/scale/bitvec.go new file mode 100644 index 00000000000..ea892600c16 --- /dev/null +++ b/pkg/scale/bitvec.go @@ -0,0 +1,108 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package scale + +const ( + // maxLen equivalent of `ARCH32BIT_BITSLICE_MAX_BITS` in parity-scale-codec + maxLen = 268435455 + // bitSize is the number of bits in a byte + bitSize = 8 +) + +// BitVec represents rust's `bitvec::BitVec` in SCALE +// It is encoded as a compact u32 representing the number of bits in the vector +// followed by the actual bits, rounded up to the nearest byte +type BitVec interface { + // Bits returns the bits in the BitVec + Bits() []uint8 + // Bytes returns the byte representation of the Bits + Bytes() []byte + // Size returns the number of bits in the BitVec + Size() uint +} + +// bitVec implements BitVec +type bitVec struct { + size uint `scale:"1"` + bits []uint8 `scale:"2"` +} + +// NewBitVec returns a new BitVec with the given bits +func NewBitVec(bits []uint8) BitVec { + var size uint + if bits != nil { + size = uint(len(bits)) + } + + return &bitVec{ + size: size, + bits: bits, + } +} + +// Bits returns the bits in the BitVec +func (bv *bitVec) Bits() []uint8 { + return bv.bits +} + +// Bytes returns the byte representation of the BitVec.Bits +func (bv *bitVec) Bytes() []byte { + var b []byte + for i := uint(0); i < bv.size; i += bitSize { + end := i + bitSize + if end > bv.size { + end = bv.size + } + chunk := bv.bits[i:end] + b = append(b, bitsToBytes(chunk)...) + } + return b +} + +// Size returns the number of bits in the BitVec +func (bv *bitVec) Size() uint { + return bv.size +} + +// bitsToBytes converts a slice of bits to a slice of bytes +func bitsToBytes(bits []uint8) []byte { + bitLength := len(bits) + numOfBytes := (bitLength + (bitSize - 1)) / bitSize + bytes := make([]byte, numOfBytes) + + if len(bits)%bitSize != 0 { + // Pad with zeros to make the number of bits a multiple of bitSize + pad := make([]uint8, bitSize-len(bits)%bitSize) + bits = append(bits, pad...) + } + + for i := 0; i < bitLength; i++ { + if bits[i] == 1 { + byteIndex := i / 8 + bitIndex := i % 8 + bytes[byteIndex] |= 1 << bitIndex + } + } + + return bytes +} + +// bytesToBits converts a slice of bytes to a slice of bits +func bytesToBits(b []byte, size uint) []uint8 { + var bits []uint8 + for _, uint8val := range b { + end := size + if end > bitSize { + end = bitSize + } + size -= end + + for j := uint(0); j < end; j++ { + bit := (uint8val >> j) & 1 + bits = append(bits, bit) + } + } + + return bits +} diff --git a/pkg/scale/bitvec_test.go b/pkg/scale/bitvec_test.go new file mode 100644 index 00000000000..08b523a57d8 --- /dev/null +++ b/pkg/scale/bitvec_test.go @@ -0,0 +1,183 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package scale + +import ( + "testing" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/stretchr/testify/require" +) + +func NewTestBitVec(size uint, bits []uint8) BitVec { + return &bitVec{ + size: size, + bits: bits, + } +} + +func TestBitVec(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + wantBitVec BitVec + wantErr bool + }{ + { + name: "empty_bitvec", + in: "0x00", + wantBitVec: NewBitVec(nil), + wantErr: false, + }, + { + name: "1_byte", + in: "0x2055", + wantBitVec: NewBitVec([]uint8{1, 0, 1, 0, 1, 0, 1, 0}), + wantErr: false, + }, + { + name: "4_bytes", + in: "0x645536aa01", + wantBitVec: NewBitVec([]uint8{ + 1, 0, 1, 0, 1, 0, 1, 0, + 0, 1, 1, 0, 1, 1, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, + 1}), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resultBytes, err := common.HexToBytes(tt.in) + require.NoError(t, err) + + bv := NewBitVec(nil) + err = Unmarshal(resultBytes, &bv) + require.NoError(t, err) + + require.Equal(t, tt.wantBitVec.Size(), bv.Size()) + require.Equal(t, tt.wantBitVec.Size(), bv.Size()) + + b, err := Marshal(bv) + require.NoError(t, err) + require.Equal(t, resultBytes, b) + }) + } +} + +func TestBitVecBytes(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in BitVec + want []byte + wantErr bool + }{ + { + name: "empty_bitvec", + in: NewBitVec(nil), + want: []byte(nil), + wantErr: false, + }, + { + name: "1_byte", + in: NewBitVec([]uint8{1, 0, 1, 0, 1, 0, 1, 0}), + want: []byte{0x55}, + wantErr: false, + }, + { + name: "4_bytes", + in: NewBitVec([]uint8{ + 1, 0, 1, 0, 1, 0, 1, 0, + 0, 1, 1, 0, 1, 1, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, + 1}), + want: []byte{0x55, 0x36, 0xaa, 0x1}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.in.Bytes()) + }) + } +} + +func TestBitVecBytesToBits(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in []byte + want []uint8 + wantErr bool + }{ + { + name: "empty", + in: []byte(nil), + want: []uint8(nil), + wantErr: false, + }, + { + name: "1_byte", + in: []byte{0x55}, + want: []uint8{1, 0, 1, 0, 1, 0, 1, 0}, + wantErr: false, + }, + { + name: "4_bytes", + in: []byte{0x55, 0x36, 0xaa, 0x1}, + want: []uint8{1, 0, 1, 0, 1, 0, 1, 0, + 0, 1, 1, 0, 1, 1, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 0, 0, 0, 0, 0}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, bytesToBits(tt.in, uint(len(tt.in)*bitSize))) + }) + } +} + +func TestBitVecBitsToBytes(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in []uint8 + want []byte + wantErr bool + }{ + { + name: "empty", + in: []uint8(nil), + want: []byte{}, + wantErr: false, + }, + { + name: "1_byte", + in: []uint8{1, 0, 1, 0, 1, 0, 1, 0}, + want: []byte{0x55}, + wantErr: false, + }, + { + name: "4_bytes", + in: []uint8{1, 0, 1, 0, 1, 0, 1, 0, + 0, 1, 1, 0, 1, 1, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 0, 0, 0, 0, 0}, + want: []byte{0x55, 0x36, 0xaa, 0x1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, bitsToBytes(tt.in)) + }) + } +} diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index d4cc4f84a3c..75dd11b144a 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -114,6 +114,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeBigInt(dstv) case *Uint128: err = ds.decodeUint128(dstv) + case BitVec: + err = ds.decodeBitVec(dstv) case int, uint: err = ds.decodeUint(dstv) case int8, uint8, int16, uint16, int32, uint32, int64, uint64: @@ -752,3 +754,27 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) { dstv.Set(reflect.ValueOf(ui128)) return } + +// decodeBitVec accepts a byte array representing a SCALE encoded +// BitVec and performs SCALE decoding of the BitVec +func (ds *decodeState) decodeBitVec(dstv reflect.Value) error { + var size uint + if err := ds.decodeUint(reflect.ValueOf(&size).Elem()); err != nil { + return err + } + + if size > maxLen { + return fmt.Errorf("%w: %d", errBitVecTooLong, size) + } + + numBytes := (size + 7) / 8 + b := make([]byte, numBytes) + _, err := ds.Read(b) + if err != nil { + return err + } + + bitvec := NewBitVec(bytesToBits(b, size)) + dstv.Set(reflect.ValueOf(bitvec)) + return nil +} diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 94d51f993b8..cfc9e9f216d 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -9,6 +9,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" @@ -247,6 +249,28 @@ func Test_decodeState_decodeMap(t *testing.T) { } } +func Test_decodeState_decodeBitVec(t *testing.T) { + for _, tt := range bitVecTests { + t.Run(tt.name, func(t *testing.T) { + dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface() + if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { + t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(dst, tt.in) { + t.Errorf("decodeState.unmarshal() = %v, want %v", dst, tt.in) + } + }) + } +} + +func Test_decodeState_decodeBitVecMaxLen(t *testing.T) { + t.Parallel() + bitvec := NewBitVec(nil) + maxLen10 := []byte{38, 0, 0, 64, 0} // maxLen + 10 + err := Unmarshal(maxLen10, &bitvec) + require.Error(t, err, errBitVecTooLong) +} + func Test_unmarshal_optionality(t *testing.T) { var ptrTests tests for _, t := range append(tests{}, allTests...) { diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index d312b85f913..5a079191360 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -73,6 +73,8 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeBigInt(in) case *Uint128: err = es.encodeUint128(in) + case BitVec: + err = es.encodeBitVec(in) case []byte: err = es.encodeBytes(in) case string: @@ -423,3 +425,20 @@ func (es *encodeState) encodeUint128(i *Uint128) (err error) { err = binary.Write(es, binary.LittleEndian, padBytes(i.Bytes(), binary.LittleEndian)) return } + +// encodeBitVec encodes a BitVec +func (es *encodeState) encodeBitVec(bitvec BitVec) (err error) { + if bitvec.Size() > maxLen { + err = fmt.Errorf("%w: %d", errBitVecTooLong, bitvec.Size()) + return + } + + err = es.encodeUint(bitvec.Size()) + if err != nil { + return + } + + data := bitvec.Bytes() + _, err = es.Write(data) + return +} diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index 627924541b1..6bac37245d8 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -929,6 +929,28 @@ var ( }, } + bitVecTests = tests{ + { + name: "BitVec{Size:__0,_Bits:__nil}", + in: NewBitVec(nil), + want: []byte{0}, + }, + { + name: "BitVec{Size:_8}", + in: NewBitVec([]uint8{1, 0, 1, 0, 1, 0, 1, 0}), + want: []byte{0x20, 0x55}, + }, + { + name: "BitVec{Size:_25}", + in: NewBitVec([]uint8{ + 1, 0, 1, 0, 1, 0, 1, 0, + 0, 1, 1, 0, 1, 1, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, + 1}), + want: []byte{0x64, 0x55, 0x36, 0xaa, 0x1}, + }, + } + allTests = newTests( fixedWidthIntegerTests, variableWidthIntegerTests, stringTests, boolTests, structTests, sliceTests, arrayTests, @@ -1159,6 +1181,31 @@ func Test_encodeState_encodeMap(t *testing.T) { } } +func Test_encodeState_encodeBitVec(t *testing.T) { + for _, tt := range bitVecTests { + t.Run(tt.name, func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } + if err := es.marshal(tt.in); (err != nil) != tt.wantErr { + t.Errorf("encodeState.encodeBitVec() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeBitVec() = %v, want %v", buffer.Bytes(), tt.want) + } + }) + } +} + +func Test_encodeState_encodeBitVecMaxLen(t *testing.T) { + t.Parallel() + bitvec := NewTestBitVec(maxLen+10, nil) + _, err := Marshal(bitvec) + require.Error(t, err, errBitVecTooLong) +} + func Test_marshal_optionality(t *testing.T) { var ptrTests tests for i := range allTests { diff --git a/pkg/scale/errors.go b/pkg/scale/errors.go index 0fcaed6861b..0f784d31597 100644 --- a/pkg/scale/errors.go +++ b/pkg/scale/errors.go @@ -13,6 +13,7 @@ var ( errUnsupportedOption = errors.New("unsupported option") errUnknownVaryingDataTypeValue = errors.New("unable to find VaryingDataTypeValue with index") errUint128IsNil = errors.New("uint128 in nil") + errBitVecTooLong = errors.New("bitvec too long") ErrResultNotSet = errors.New("result not set") ErrResultAlreadySet = errors.New("result already has an assigned value") ErrUnsupportedVaryingDataTypeValue = errors.New("unsupported VaryingDataTypeValue")