Skip to content

Commit

Permalink
update file interface to use shwap types
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss committed May 28, 2024
1 parent 6e9f1da commit a2df5d6
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 112 deletions.
14 changes: 7 additions & 7 deletions share/eds/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,24 @@ func CollectSharesByNamespace(
utils.SetStatusAndEnd(span, err)
}()

rootCIDs := ipld.FilterRootByNamespace(root, namespace)
if len(rootCIDs) == 0 {
rowIdxs := share.RowsWithNamespace(root, namespace)
if len(rowIdxs) == 0 {
return []share.NamespacedRow{}, nil
}

errGroup, ctx := errgroup.WithContext(ctx)
shares = make([]share.NamespacedRow, len(rootCIDs))
for i, rootCID := range rootCIDs {
shares = make([]share.NamespacedRow, len(rowIdxs))
for i, rowIdx := range rowIdxs {
// shadow loop variables, to ensure correct values are captured
i, rootCID := i, rootCID
rowIdx, rowRoot := rowIdx, root.RowRoots[rowIdx]
errGroup.Go(func() error {
row, proof, err := ipld.GetSharesByNamespace(ctx, bg, rootCID, namespace, len(root.RowRoots))
row, proof, err := ipld.GetSharesByNamespace(ctx, bg, rowRoot, namespace, len(root.RowRoots))
shares[i] = share.NamespacedRow{
Shares: row,
Proof: proof,
}
if err != nil {
return fmt.Errorf("retrieving shares by namespace %s for row %x: %w", namespace.String(), rootCID, err)
return fmt.Errorf("retrieving shares by namespace %s for row %d: %w", namespace.String(), rowIdx, err)
}
return nil
})
Expand Down
5 changes: 3 additions & 2 deletions share/ipld/get_shares.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ func GetShares(ctx context.Context, bg blockservice.BlockGetter, root cid.Cid, s
func GetSharesByNamespace(
ctx context.Context,
bGetter blockservice.BlockGetter,
root cid.Cid,
root []byte,
namespace share.Namespace,
maxShares int,
) ([]share.Share, *nmt.Proof, error) {
rootCid := MustCidFromNamespacedSha256(root)
data := NewNamespaceData(maxShares, namespace, WithLeaves(), WithProofs())
err := data.CollectLeavesByNamespace(ctx, bGetter, root)
err := data.CollectLeavesByNamespace(ctx, bGetter, rootCid)
if err != nil {
return nil, nil, err
}
Expand Down
10 changes: 4 additions & 6 deletions share/ipld/get_shares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ func TestGetSharesByNamespace(t *testing.T) {
rowRoots, err := eds.RowRoots()
require.NoError(t, err)
for _, row := range rowRoots {
rcid := MustCidFromNamespacedSha256(row)
rowShares, _, err := GetSharesByNamespace(ctx, bServ, rcid, namespace, len(rowRoots))
rowShares, _, err := GetSharesByNamespace(ctx, bServ, row, namespace, len(rowRoots))
if errors.Is(err, ErrNamespaceOutsideRange) {
continue
}
Expand Down Expand Up @@ -363,8 +362,7 @@ func TestGetSharesWithProofsByNamespace(t *testing.T) {
rowRoots, err := eds.RowRoots()
require.NoError(t, err)
for _, row := range rowRoots {
rcid := MustCidFromNamespacedSha256(row)
rowShares, proof, err := GetSharesByNamespace(ctx, bServ, rcid, namespace, len(rowRoots))
rowShares, proof, err := GetSharesByNamespace(ctx, bServ, row, namespace, len(rowRoots))
if namespace.IsOutsideRange(row, row) {
require.ErrorIs(t, err, ErrNamespaceOutsideRange)
continue
Expand All @@ -386,15 +384,15 @@ func TestGetSharesWithProofsByNamespace(t *testing.T) {
share.NewSHA256Hasher(),
namespace.ToNMT(),
leaves,
NamespacedSha256FromCID(rcid))
row)
require.True(t, verified)

// verify inclusion
verified = proof.VerifyInclusion(
share.NewSHA256Hasher(),
namespace.ToNMT(),
rowShares,
NamespacedSha256FromCID(rcid))
row)
require.True(t, verified)
}
}
Expand Down
51 changes: 51 additions & 0 deletions share/root.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package share

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"

"github.com/celestiaorg/celestia-app/pkg/da"
"github.com/celestiaorg/rsmt2d"
)
Expand All @@ -9,6 +15,30 @@ import (
// In practice, it is a commitment to all the Data in a square.
type Root = da.DataAvailabilityHeader

// DataHash is a representation of the Root hash.
type DataHash []byte

func (dh DataHash) Validate() error {
if len(dh) != 32 {
return fmt.Errorf("invalid hash size, expected 32, got %d", len(dh))
}
return nil
}

func (dh DataHash) String() string {
return fmt.Sprintf("%X", []byte(dh))
}

// IsEmptyRoot check whether DataHash corresponds to the root of an empty block EDS.
func (dh DataHash) IsEmptyRoot() bool {
return bytes.Equal(EmptyRoot().Hash(), dh)
}

// NewSHA256Hasher returns a new instance of a SHA-256 hasher.
func NewSHA256Hasher() hash.Hash {
return sha256.New()
}

// NewRoot generates Root(DataAvailabilityHeader) using the
// provided extended data square.
func NewRoot(eds *rsmt2d.ExtendedDataSquare) (*Root, error) {
Expand All @@ -29,3 +59,24 @@ func RowsWithNamespace(root *Root, namespace Namespace) (idxs []int) {
}
return
}

// RootHashForCoordinates returns the root hash for the given coordinates.
func RootHashForCoordinates(r *Root, axisType rsmt2d.Axis, rowIdx, colIdx uint) []byte {
if axisType == rsmt2d.Row {
return r.RowRoots[rowIdx]
}
return r.ColumnRoots[colIdx]
}

// MustDataHashFromString converts a hex string to a valid datahash.
func MustDataHashFromString(datahash string) DataHash {
dh, err := hex.DecodeString(datahash)
if err != nil {
panic(fmt.Sprintf("datahash conversion: passed string was not valid hex: %s", datahash))
}
err = DataHash(dh).Validate()
if err != nil {
panic(fmt.Sprintf("datahash validation: passed hex string failed: %s", err))
}
return dh
}
48 changes: 0 additions & 48 deletions share/share.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package share

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"

"github.com/celestiaorg/celestia-app/pkg/appconsts"
"github.com/celestiaorg/nmt"
Expand Down Expand Up @@ -72,48 +69,3 @@ func (s *ShareWithProof) Validate(rootHash []byte, x, y, edsSize int) bool {
rootHash,
)
}

// DataHash is a representation of the Root hash.
type DataHash []byte

func (dh DataHash) Validate() error {
if len(dh) != 32 {
return fmt.Errorf("invalid hash size, expected 32, got %d", len(dh))
}
return nil
}

func (dh DataHash) String() string {
return fmt.Sprintf("%X", []byte(dh))
}

// IsEmptyRoot check whether DataHash corresponds to the root of an empty block EDS.
func (dh DataHash) IsEmptyRoot() bool {
return bytes.Equal(EmptyRoot().Hash(), dh)
}

// MustDataHashFromString converts a hex string to a valid datahash.
func MustDataHashFromString(datahash string) DataHash {
dh, err := hex.DecodeString(datahash)
if err != nil {
panic(fmt.Sprintf("datahash conversion: passed string was not valid hex: %s", datahash))
}
err = DataHash(dh).Validate()
if err != nil {
panic(fmt.Sprintf("datahash validation: passed hex string failed: %s", err))
}
return dh
}

// NewSHA256Hasher returns a new instance of a SHA-256 hasher.
func NewSHA256Hasher() hash.Hash {
return sha256.New()
}

// RootHashForCoordinates returns the root hash for the given coordinates.
func RootHashForCoordinates(r *Root, axisType rsmt2d.Axis, rowIdx, colIdx uint) []byte {
if axisType == rsmt2d.Row {
return r.RowRoots[rowIdx]
}
return r.ColumnRoots[colIdx]
}
19 changes: 19 additions & 0 deletions share/store/file/axis_half.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package file

import (
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

type AxisHalf struct {
Shares []share.Share
IsParity bool
}

func (a AxisHalf) ToRow() shwap.Row {
side := shwap.Left
if a.IsParity {
side = shwap.Right
}
return shwap.NewRow(a.Shares, side)
}
11 changes: 6 additions & 5 deletions share/store/eds_file.go → share/store/file/eds_file.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package store
package file

import (
"context"
Expand All @@ -7,18 +7,19 @@ import (
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

type EdsFile interface {
io.Closer
// Size returns square size of the file.
Size() int
// Share returns share and corresponding proof for the given axis and share index in this axis.
Share(ctx context.Context, x, y int) (*share.ShareWithProof, error)
// AxisHalf returns shares for the first half of the axis of the given type and index.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error)
Share(ctx context.Context, rowIdx, colIdx int) (*shwap.Sample, error)
// AxisHalf returns Shares for the first half of the axis of the given type and index.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
// Data returns data for the given namespace and row index.
Data(ctx context.Context, namespace share.Namespace, rowIdx int) (share.NamespacedRow, error)
Data(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error)
// EDS returns extended data square stored in the file.
EDS(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error)
}
77 changes: 42 additions & 35 deletions share/store/mem_file.go → share/store/file/mem_file.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package store
package file

import (
"context"
Expand All @@ -9,6 +9,7 @@ import (

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/ipld"
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ EdsFile = (*MemFile)(nil)
Expand All @@ -27,12 +28,12 @@ func (f *MemFile) Size() int {

func (f *MemFile) Share(
_ context.Context,
x, y int,
) (*share.ShareWithProof, error) {
rowIdx, colIdx int,
) (*shwap.Sample, error) {
axisType := rsmt2d.Row
axisIdx, shrIdx := y, x
axisIdx, shrIdx := rowIdx, colIdx

shares := f.axis(axisType, axisIdx)
shares := getAxis(f.Eds, axisType, axisIdx)
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(f.Size()/2), uint(axisIdx))
for _, shr := range shares {
err := tree.Push(shr)
Expand All @@ -46,62 +47,68 @@ func (f *MemFile) Share(
return nil, err
}

return &share.ShareWithProof{
Share: shares[shrIdx],
Proof: &proof,
Axis: axisType,
return &shwap.Sample{
Share: shares[shrIdx],
Proof: &proof,
ProofType: axisType,
}, nil
}

func (f *MemFile) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) {
return f.axis(axisType, axisIdx)[:f.Size()/2], nil
func (f *MemFile) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
return AxisHalf{
Shares: getAxis(f.Eds, axisType, axisIdx)[:f.Size()/2],
IsParity: false,
}, nil
}

func (f *MemFile) Data(_ context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) {
shares := getAxis(f.Eds, rsmt2d.Row, rowIdx)
return ndDataFromShares(shares, namespace, rowIdx)
}

func (f *MemFile) EDS(_ context.Context) (*rsmt2d.ExtendedDataSquare, error) {
return f.Eds, nil
}

func (f *MemFile) Data(_ context.Context, namespace share.Namespace, rowIdx int) (share.NamespacedRow, error) {
shares := f.axis(rsmt2d.Row, rowIdx)
func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share {
switch axisType {
case rsmt2d.Row:
return eds.Row(uint(axisIdx))
case rsmt2d.Col:
return eds.Col(uint(axisIdx))
default:
panic("unknown axis")
}
}

func ndDataFromShares(shares []share.Share, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) {
bserv := ipld.NewMemBlockservice()
batchAdder := ipld.NewNmtNodeAdder(context.TODO(), bserv, ipld.MaxSizeBatchOption(len(shares)))
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(rowIdx),
nmt.NodeVisitor(batchAdder.Visit))
for _, shr := range shares {
err := tree.Push(shr)
if err != nil {
return share.NamespacedRow{}, err
return shwap.RowNamespaceData{}, err
}
}

root, err := tree.Root()
if err != nil {
return share.NamespacedRow{}, err
return shwap.RowNamespaceData{}, err
}

err = batchAdder.Commit()
if err != nil {
return share.NamespacedRow{}, err
return shwap.RowNamespaceData{}, err
}

cid := ipld.MustCidFromNamespacedSha256(root)
row, proof, err := ipld.GetSharesByNamespace(context.TODO(), bserv, cid, namespace, len(shares))
row, proof, err := ipld.GetSharesByNamespace(context.TODO(), bserv, root, namespace, len(shares))
if err != nil {
return share.NamespacedRow{}, err
return shwap.RowNamespaceData{}, err
}
return share.NamespacedRow{
return shwap.RowNamespaceData{
Shares: row,
Proof: proof,
}, nil
}

func (f *MemFile) EDS(_ context.Context) (*rsmt2d.ExtendedDataSquare, error) {
return f.Eds, nil
}

func (f *MemFile) axis(axisType rsmt2d.Axis, axisIdx int) []share.Share {
switch axisType {
case rsmt2d.Row:
return f.Eds.Row(uint(axisIdx))
case rsmt2d.Col:
return f.Eds.Col(uint(axisIdx))
default:
panic("unknown axis")
}
}
Loading

0 comments on commit a2df5d6

Please sign in to comment.