Skip to content

Commit

Permalink
feat/scale: add BitVec (#3253)
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn authored and kishansagathiya committed Jul 15, 2024
1 parent f5b309b commit 1bcad51
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 1 deletion.
33 changes: 33 additions & 0 deletions pkg/scale/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,39 @@ SCALE uses a compact encoding for variable width unsigned integers.
| `Compact<u64>` | `uint` |
| `Compact<u128>` | `*big.Int` |

### BitVec

SCALE uses a bit vector to encode a sequence of booleans. The bit vector is encoded as a compact length followed by a byte array.
The byte array is a sequence of bytes where each bit represents a boolean value.

**Note: This is a work in progress.**
The current implementation of BitVec is just bare bones. It does not implement any of the methods of the `BitVec` type in Rust.

```go
import (
"fmt"
"github.com/ChainSafe/gossamer/pkg/scale"
)

func ExampleBitVec() {
bitvec := NewBitVec([]bool{true, false, true, false, true, false, true, false})
bytes, err := scale.Marshal(bitvec)
if err != nil {
panic(err)
}

var unmarshaled BitVec
err = scale.Unmarshal(bytes, &unmarshaled)
if err != nil {
panic(err)
}

// [true false true false true false true false]
fmt.Printf("%v", unmarshaled.Bits())
}
```


## Usage

### Basic Example
Expand Down
87 changes: 87 additions & 0 deletions pkg/scale/bitvec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package scale

const (
// maxLen equivalent of `ARCH32BIT_BITSLICE_MAX_BITS` in parity-scale-codec
maxLen = 268435455
// byteSize is the number of bits in a byte
byteSize = 8
)

// BitVec is the implementation of the bit vector
type BitVec struct {
bits []bool
}

// NewBitVec returns a new BitVec with the given bits
// This isn't a complete implementation of the bit vector
// It is only used for ParachainHost runtime exports
// TODO: Implement the full bit vector
// https://github.com/ChainSafe/gossamer/issues/3248
func NewBitVec(bits []bool) BitVec {
return BitVec{
bits: bits,
}
}

// Bits returns the bits in the BitVec
func (bv *BitVec) Bits() []bool {
return bv.bits
}

// Bytes returns the byte representation of the BitVec.Bits
func (bv *BitVec) Bytes() []byte {
return bitsToBytes(bv.bits)
}

// Size returns the number of bits in the BitVec
func (bv *BitVec) Size() uint {
return uint(len(bv.bits))
}

// bitsToBytes converts a slice of bits to a slice of bytes
// Uses lsb ordering
// TODO: Implement msb ordering
// https://github.com/ChainSafe/gossamer/issues/3248
func bitsToBytes(bits []bool) []byte {
bitLength := len(bits)
numOfBytes := (bitLength + (byteSize - 1)) / byteSize
bytes := make([]byte, numOfBytes)

if len(bits)%byteSize != 0 {
// Pad with zeros to make the number of bits a multiple of byteSize
pad := make([]bool, byteSize-len(bits)%byteSize)
bits = append(bits, pad...)
}

for i := 0; i < bitLength; i++ {
if bits[i] {
byteIndex := i / byteSize
bitIndex := i % byteSize
bytes[byteIndex] |= 1 << bitIndex
}
}

return bytes
}

// bytesToBits converts a slice of bytes to a slice of bits
func bytesToBits(b []byte, size uint) []bool {
var bits []bool
for _, uint8val := range b {
end := size
if end > byteSize {
end = byteSize
}
size -= end

for j := uint(0); j < end; j++ {
bit := (uint8val>>j)&1 == 1
bits = append(bits, bit)
}
}

return bits
}
186 changes: 186 additions & 0 deletions pkg/scale/bitvec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package scale

import (
"testing"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/require"
)

func TestBitVec(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
wantBitVec BitVec
wantErr bool
}{
{
name: "empty_bitvec",
in: "0x00",
wantBitVec: NewBitVec(nil),
wantErr: false,
},
{
name: "1_byte",
in: "0x2055",
wantBitVec: NewBitVec([]bool{true, false, true, false, true, false, true, false}),
wantErr: false,
},
{
name: "4_bytes",
in: "0x645536aa01",
wantBitVec: NewBitVec([]bool{
true, false, true, false, true, false, true, false,
false, true, true, false, true, true, false, false,
false, true, false, true, false, true, false, true,
true,
}),
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
resultBytes, err := common.HexToBytes(tt.in)
require.NoError(t, err)

bv := NewBitVec(nil)
err = Unmarshal(resultBytes, &bv)
require.NoError(t, err)

require.Equal(t, tt.wantBitVec.Size(), bv.Size())
require.Equal(t, tt.wantBitVec.Size(), bv.Size())

b, err := Marshal(bv)
require.NoError(t, err)
require.Equal(t, resultBytes, b)
})
}
}

func TestBitVecBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in BitVec
want []byte
wantErr bool
}{
{
name: "empty_bitvec",
in: NewBitVec(nil),
want: []byte(nil),
wantErr: false,
},
{
name: "1_byte",
in: NewBitVec([]bool{true, false, true, false, true, false, true, false}),
want: []byte{0x55},
wantErr: false,
},
{
name: "4_bytes",
in: NewBitVec([]bool{
true, false, true, false, true, false, true, false,
false, true, true, false, true, true, false, false,
false, true, false, true, false, true, false, true,
true,
}),
want: []byte{0x55, 0x36, 0xaa, 0x1},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, tt.in.Bytes())
})
}
}

func TestBitVecBytesToBits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in []byte
want []bool
wantErr bool
}{
{
name: "empty",
in: []byte(nil),
want: []bool(nil),
wantErr: false,
},
{
name: "1_byte",
in: []byte{0x55},
want: []bool{true, false, true, false, true, false, true, false},
wantErr: false,
},
{
name: "4_bytes",
in: []byte{0x55, 0x36, 0xaa, 0x1},
want: []bool{
true, false, true, false, true, false, true, false,
false, true, true, false, true, true, false, false,
false, true, false, true, false, true, false, true,
true, false, false, false, false, false, false, false,
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, bytesToBits(tt.in, uint(len(tt.in)*byteSize)))
})
}
}

func TestBitVecBitsToBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in []bool
want []byte
wantErr bool
}{
{
name: "empty",
in: []bool(nil),
want: []byte{},
wantErr: false,
},
{
name: "1_byte",
in: []bool{true, false, true, false, true, false, true, false},
want: []byte{0x55},
wantErr: false,
},
{
name: "4_bytes",
in: []bool{
true, false, true, false, true, false, true, false,
false, true, true, false, true, true, false, false,
false, true, false, true, false, true, false, true,
true,
},
want: []byte{0x55, 0x36, 0xaa, 0x1},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, bitsToBytes(tt.in))
})
}
}
30 changes: 30 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
err = ds.decodeBigInt(dstv)
case *Uint128:
err = ds.decodeUint128(dstv)
case BitVec:
err = ds.decodeBitVec(dstv)
case int, uint:
err = ds.decodeUint(dstv)
case int8, uint8, int16, uint16, int32, uint32, int64, uint64:
Expand Down Expand Up @@ -747,3 +749,31 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) {
dstv.Set(reflect.ValueOf(ui128))
return
}

// decodeBitVec accepts a byte array representing a SCALE encoded
// BitVec and performs SCALE decoding of the BitVec
func (ds *decodeState) decodeBitVec(dstv reflect.Value) error {
var size uint
if err := ds.decodeUint(reflect.ValueOf(&size).Elem()); err != nil {
return err
}

if size > maxLen {
return fmt.Errorf("%w: %d", errBitVecTooLong, size)
}

numBytes := (size + (byteSize - 1)) / byteSize
b := make([]byte, numBytes)
_, err := ds.Read(b)
if err != nil {
return err
}

bitvec := NewBitVec(bytesToBits(b, size))
if len(bitvec.bits) > int(size) {
return fmt.Errorf("bitvec length mismatch: expected %d, got %d", size, len(bitvec.bits))
}

dstv.Set(reflect.ValueOf(bitvec))
return nil
}
24 changes: 24 additions & 0 deletions pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/require"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -254,6 +256,28 @@ func Test_decodeState_decodeMap(t *testing.T) {
}
}

func Test_decodeState_decodeBitVec(t *testing.T) {
for _, tt := range bitVecTests {
t.Run(tt.name, func(t *testing.T) {
dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface()
if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(dst, tt.in) {
t.Errorf("decodeState.unmarshal() = %v, want %v", dst, tt.in)
}
})
}
}

func Test_decodeState_decodeBitVecMaxLen(t *testing.T) {
t.Parallel()
bitvec := NewBitVec(nil)
maxLen10 := []byte{38, 0, 0, 64, 0} // maxLen + 10
err := Unmarshal(maxLen10, &bitvec)
require.Error(t, err, errBitVecTooLong)
}

func Test_unmarshal_optionality(t *testing.T) {
var ptrTests tests
for _, t := range append(tests{}, allTests...) {
Expand Down
Loading

0 comments on commit 1bcad51

Please sign in to comment.