Skip to content

Commit

Permalink
add BitVec
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn committed May 11, 2023
1 parent d594e1e commit 13d1a3b
Show file tree
Hide file tree
Showing 7 changed files with 408 additions and 0 deletions.
108 changes: 108 additions & 0 deletions pkg/scale/bitvec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package scale

const (
// maxLen equivalent of `ARCH32BIT_BITSLICE_MAX_BITS` in parity-scale-codec
maxLen = 268435455
// bitSize is the number of bits in a byte
bitSize = 8
)

// BitVec represents rust's `bitvec::BitVec` in SCALE
// It is encoded as a compact u32 representing the number of bits in the vector
// followed by the actual bits, rounded up to the nearest byte
type BitVec interface {
// Bits returns the bits in the BitVec
Bits() []uint8
// Bytes returns the byte representation of the Bits
Bytes() []byte
// Size returns the number of bits in the BitVec
Size() uint
}

// bitVec implements BitVec
type bitVec struct {
size uint `scale:"1"`
bits []uint8 `scale:"2"`
}

// NewBitVec returns a new BitVec with the given bits
func NewBitVec(bits []uint8) BitVec {
var size uint
if bits != nil {
size = uint(len(bits))
}

return &bitVec{
size: size,
bits: bits,
}
}

// Bits returns the bits in the BitVec
func (bv *bitVec) Bits() []uint8 {
return bv.bits
}

// Bytes returns the byte representation of the BitVec.Bits
func (bv *bitVec) Bytes() []byte {
var b []byte
for i := uint(0); i < bv.size; i += bitSize {
end := i + bitSize
if end > bv.size {
end = bv.size
}
chunk := bv.bits[i:end]
b = append(b, bitsToBytes(chunk)...)
}
return b
}

// Size returns the number of bits in the BitVec
func (bv *bitVec) Size() uint {
return bv.size
}

// bitsToBytes converts a slice of bits to a slice of bytes
func bitsToBytes(bits []uint8) []byte {
bitLength := len(bits)
numOfBytes := (bitLength + (bitSize - 1)) / bitSize
bytes := make([]byte, numOfBytes)

if len(bits)%bitSize != 0 {
// Pad with zeros to make the number of bits a multiple of bitSize
pad := make([]uint8, bitSize-len(bits)%bitSize)
bits = append(bits, pad...)
}

for i := 0; i < bitLength; i++ {
if bits[i] == 1 {
byteIndex := i / 8
bitIndex := i % 8
bytes[byteIndex] |= 1 << bitIndex
}
}

return bytes
}

// bytesToBits converts a slice of bytes to a slice of bits
func bytesToBits(b []byte, size uint) []uint8 {
var bits []uint8
for _, uint8val := range b {
end := size
if end > bitSize {
end = bitSize
}
size -= end

for j := uint(0); j < end; j++ {
bit := (uint8val >> j) & 1
bits = append(bits, bit)
}
}

return bits
}
183 changes: 183 additions & 0 deletions pkg/scale/bitvec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package scale

import (
"testing"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/require"
)

func NewTestBitVec(size uint, bits []uint8) BitVec {
return &bitVec{
size: size,
bits: bits,
}
}

func TestBitVec(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
wantBitVec BitVec
wantErr bool
}{
{
name: "empty_bitvec",
in: "0x00",
wantBitVec: NewBitVec(nil),
wantErr: false,
},
{
name: "1_byte",
in: "0x2055",
wantBitVec: NewBitVec([]uint8{1, 0, 1, 0, 1, 0, 1, 0}),
wantErr: false,
},
{
name: "4_bytes",
in: "0x645536aa01",
wantBitVec: NewBitVec([]uint8{
1, 0, 1, 0, 1, 0, 1, 0,
0, 1, 1, 0, 1, 1, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1,
1}),
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resultBytes, err := common.HexToBytes(tt.in)
require.NoError(t, err)

bv := NewBitVec(nil)
err = Unmarshal(resultBytes, &bv)
require.NoError(t, err)

require.Equal(t, tt.wantBitVec.Size(), bv.Size())
require.Equal(t, tt.wantBitVec.Size(), bv.Size())

b, err := Marshal(bv)
require.NoError(t, err)
require.Equal(t, resultBytes, b)
})
}
}

func TestBitVecBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in BitVec
want []byte
wantErr bool
}{
{
name: "empty_bitvec",
in: NewBitVec(nil),
want: []byte(nil),
wantErr: false,
},
{
name: "1_byte",
in: NewBitVec([]uint8{1, 0, 1, 0, 1, 0, 1, 0}),
want: []byte{0x55},
wantErr: false,
},
{
name: "4_bytes",
in: NewBitVec([]uint8{
1, 0, 1, 0, 1, 0, 1, 0,
0, 1, 1, 0, 1, 1, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1,
1}),
want: []byte{0x55, 0x36, 0xaa, 0x1},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.in.Bytes())
})
}
}

func TestBitVecBytesToBits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in []byte
want []uint8
wantErr bool
}{
{
name: "empty",
in: []byte(nil),
want: []uint8(nil),
wantErr: false,
},
{
name: "1_byte",
in: []byte{0x55},
want: []uint8{1, 0, 1, 0, 1, 0, 1, 0},
wantErr: false,
},
{
name: "4_bytes",
in: []byte{0x55, 0x36, 0xaa, 0x1},
want: []uint8{1, 0, 1, 0, 1, 0, 1, 0,
0, 1, 1, 0, 1, 1, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1,
1, 0, 0, 0, 0, 0, 0, 0},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, bytesToBits(tt.in, uint(len(tt.in)*bitSize)))
})
}
}

func TestBitVecBitsToBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in []uint8
want []byte
wantErr bool
}{
{
name: "empty",
in: []uint8(nil),
want: []byte{},
wantErr: false,
},
{
name: "1_byte",
in: []uint8{1, 0, 1, 0, 1, 0, 1, 0},
want: []byte{0x55},
wantErr: false,
},
{
name: "4_bytes",
in: []uint8{1, 0, 1, 0, 1, 0, 1, 0,
0, 1, 1, 0, 1, 1, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1,
1, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x55, 0x36, 0xaa, 0x1},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, bitsToBytes(tt.in))
})
}
}
26 changes: 26 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
err = ds.decodeBigInt(dstv)
case *Uint128:
err = ds.decodeUint128(dstv)
case BitVec:
err = ds.decodeBitVec(dstv)
case int, uint:
err = ds.decodeUint(dstv)
case int8, uint8, int16, uint16, int32, uint32, int64, uint64:
Expand Down Expand Up @@ -752,3 +754,27 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) {
dstv.Set(reflect.ValueOf(ui128))
return
}

// decodeBitVec accepts a byte array representing a SCALE encoded
// BitVec and performs SCALE decoding of the BitVec
func (ds *decodeState) decodeBitVec(dstv reflect.Value) error {
var size uint
if err := ds.decodeUint(reflect.ValueOf(&size).Elem()); err != nil {
return err
}

if size > maxLen {
return fmt.Errorf("%w: %d", errBitVecTooLong, size)
}

numBytes := (size + 7) / 8
b := make([]byte, numBytes)
_, err := ds.Read(b)
if err != nil {
return err
}

bitvec := NewBitVec(bytesToBits(b, size))
dstv.Set(reflect.ValueOf(bitvec))
return nil
}
24 changes: 24 additions & 0 deletions pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/require"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -247,6 +249,28 @@ func Test_decodeState_decodeMap(t *testing.T) {
}
}

func Test_decodeState_decodeBitVec(t *testing.T) {
for _, tt := range bitVecTests {
t.Run(tt.name, func(t *testing.T) {
dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface()
if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(dst, tt.in) {
t.Errorf("decodeState.unmarshal() = %v, want %v", dst, tt.in)
}
})
}
}

func Test_decodeState_decodeBitVecMaxLen(t *testing.T) {
t.Parallel()
bitvec := NewBitVec(nil)
maxLen10 := []byte{38, 0, 0, 64, 0} // maxLen + 10
err := Unmarshal(maxLen10, &bitvec)
require.Error(t, err, errBitVecTooLong)
}

func Test_unmarshal_optionality(t *testing.T) {
var ptrTests tests
for _, t := range append(tests{}, allTests...) {
Expand Down
19 changes: 19 additions & 0 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.encodeBigInt(in)
case *Uint128:
err = es.encodeUint128(in)
case BitVec:
err = es.encodeBitVec(in)
case []byte:
err = es.encodeBytes(in)
case string:
Expand Down Expand Up @@ -423,3 +425,20 @@ func (es *encodeState) encodeUint128(i *Uint128) (err error) {
err = binary.Write(es, binary.LittleEndian, padBytes(i.Bytes(), binary.LittleEndian))
return
}

// encodeBitVec encodes a BitVec
func (es *encodeState) encodeBitVec(bitvec BitVec) (err error) {
if bitvec.Size() > maxLen {
err = fmt.Errorf("%w: %d", errBitVecTooLong, bitvec.Size())
return
}

err = es.encodeUint(bitvec.Size())
if err != nil {
return
}

data := bitvec.Bytes()
_, err = es.Write(data)
return
}
Loading

0 comments on commit 13d1a3b

Please sign in to comment.