Skip to content

Commit

Permalink
refactor(shwap): Extract eds interface (#3452)
Browse files Browse the repository at this point in the history
- Reintroduce file interface as eds interface. Change aims to allow
usage of EDS interface outside of storage package and to be high level
interface of EDS methods.
- Renames of eds interface methods to align with returned shwap types
names
    - Share() -> Sample
    - Data -> Row data
- Extracts New<shwap_type_name>FromEDS functions to eds file methods
    - moves associated tests to eds pkg
    
**Additional refactoring:**
- **Change Interface Name**: Realized that 'EDS' is a terrible name for
an interface. Renamed `eds.EDS` to `eds.Accessor` to more accurately
reflect its functionality rather than its internal content.

- **Separate Closer**: Extracted `Closer` from `Accessor`. Now it is
available in a new composite interface `AccessorCloser`.

- **Rename InMem**: Renamed `InMem` to `rsmt2d` to better align with its
usage.

- **Decouple NamespacedData**: Separated `NamespacedData` from the
`rsmt2d` implementation. It is now a standalone function.

- **Update EDS Method**: Replaced the `EDS()` method with `Flattened`,
similar to `rsmt2d`. Considered introducing two separate methods,
`Flattened` and `FlattenedODS`, with the latter to be potentially added
later. Proposed to park this suggestion in an issue for future
consideration.
  • Loading branch information
walldiss committed Jun 5, 2024
1 parent 9e82fd6 commit 60e757e
Show file tree
Hide file tree
Showing 16 changed files with 372 additions and 372 deletions.
34 changes: 34 additions & 0 deletions share/new_eds/accessor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package eds

import (
"context"
"io"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// Accessor is an interface for accessing extended data square data.
type Accessor interface {
// Size returns square size of the Accessor.
Size(ctx context.Context) int
// Sample returns share and corresponding proof for row and column indices. Implementation can
// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned
// Sample.
Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error)
// AxisHalf returns half of shares axis of the given type and index. Side is determined by
// implementation. Implementations should indicate the side in the returned AxisHalf.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
// RowNamespaceData returns data for the given namespace and row index.
RowNamespaceData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error)
// Shares returns data shares extracted from the Accessor.
Shares(ctx context.Context) ([]share.Share, error)
}

// AccessorCloser is an interface that groups Accessor and io.Closer interfaces.
type AccessorCloser interface {
Accessor
io.Closer
}
7 changes: 5 additions & 2 deletions share/store/file/axis_half.go → share/new_eds/axis_half.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package file
package eds

import (
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// AxisHalf represents a half of data for a row or column in the EDS.
type AxisHalf struct {
Shares []share.Share
Shares []share.Share
// IsParity indicates whether the half is parity or data.
IsParity bool
}

// ToRow converts the AxisHalf to a shwap.Row.
func (a AxisHalf) ToRow() shwap.Row {
side := shwap.Left
if a.IsParity {
Expand Down
31 changes: 31 additions & 0 deletions share/new_eds/nd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package eds

import (
"context"
"fmt"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// NamespacedData extracts shares for a specific namespace from an EDS, considering
// each row independently. It uses root to determine which rows to extract data from,
// avoiding the need to recalculate the row roots for each row.
func NamespacedData(
ctx context.Context,
root *share.Root,
eds Accessor,
namespace share.Namespace,
) (shwap.NamespacedData, error) {
rowIdxs := share.RowsWithNamespace(root, namespace)
rows := make(shwap.NamespacedData, len(rowIdxs))
var err error
for i, idx := range rowIdxs {
rows[i], err = eds.RowNamespaceData(ctx, namespace, idx)
if err != nil {
return nil, fmt.Errorf("failed to process row %d: %w", idx, err)
}
}

return rows, nil
}
32 changes: 32 additions & 0 deletions share/new_eds/nd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package eds

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/sharetest"
)

func TestNamespacedData(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)

const odsSize = 8
sharesAmount := odsSize * odsSize
namespace := sharetest.RandV0Namespace()
for amount := 1; amount < sharesAmount; amount++ {
eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize)
rsmt2d := Rsmt2D{ExtendedDataSquare: eds}
nd, err := NamespacedData(ctx, root, rsmt2d, namespace)
require.NoError(t, err)
require.True(t, len(nd) > 0)
require.Len(t, nd.Flatten(), amount)

err = nd.Validate(root, namespace)
require.NoError(t, err)
}
}
116 changes: 116 additions & 0 deletions share/new_eds/rsmt2d.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package eds

import (
"context"
"fmt"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ Accessor = Rsmt2D{}

// Rsmt2D is a rsmt2d based in-memory implementation of Accessor.
type Rsmt2D struct {
*rsmt2d.ExtendedDataSquare
}

// Size returns the size of the Extended Data Square.
func (eds Rsmt2D) Size(context.Context) int {
return int(eds.Width())
}

// Sample returns share and corresponding proof for row and column indices.
func (eds Rsmt2D) Sample(
_ context.Context,
rowIdx, colIdx int,
) (shwap.Sample, error) {
return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row)
}

// SampleForProofAxis samples a share from an Extended Data Square based on the provided
// row and column indices and proof axis. It returns a sample with the share and proof.
func (eds Rsmt2D) SampleForProofAxis(
rowIdx, colIdx int,
proofType rsmt2d.Axis,
) (shwap.Sample, error) {
axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType)
shares := getAxis(eds.ExtendedDataSquare, proofType, axisIdx)

tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(eds.Width()/2), uint(axisIdx))
for _, shr := range shares {
err := tree.Push(shr)
if err != nil {
return shwap.Sample{}, fmt.Errorf("while pushing shares to NMT: %w", err)
}
}

prf, err := tree.ProveRange(shrIdx, shrIdx+1)
if err != nil {
return shwap.Sample{}, fmt.Errorf("while proving range share over NMT: %w", err)
}

return shwap.Sample{
Share: shares[shrIdx],
Proof: &prf,
ProofType: proofType,
}, nil
}

// AxisHalf returns Shares for the first half of the axis of the given type and index.
func (eds Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
shares := getAxis(eds.ExtendedDataSquare, axisType, axisIdx)
halfShares := shares[:eds.Width()/2]
return AxisHalf{
Shares: halfShares,
IsParity: false,
}, nil
}

// HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and
// side.
func (eds Rsmt2D) HalfRow(idx int, side shwap.RowSide) shwap.Row {
shares := eds.ExtendedDataSquare.Row(uint(idx))
return shwap.RowFromShares(shares, side)
}

// RowNamespaceData returns data for the given namespace and row index.
func (eds Rsmt2D) RowNamespaceData(
_ context.Context,
namespace share.Namespace,
rowIdx int,
) (shwap.RowNamespaceData, error) {
shares := eds.Row(uint(rowIdx))
return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx)
}

// Shares returns data shares extracted from the EDS. It returns new copy of the shares each
// time.
func (eds Rsmt2D) Shares(_ context.Context) ([]share.Share, error) {
return eds.ExtendedDataSquare.Flattened(), nil
}

func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share {
switch axisType {
case rsmt2d.Row:
return eds.Row(uint(axisIdx))
case rsmt2d.Col:
return eds.Col(uint(axisIdx))
default:
panic("unknown axis")
}
}

func relativeIndexes(rowIdx, colIdx int, axisType rsmt2d.Axis) (axisIdx, shrIdx int) {
switch axisType {
case rsmt2d.Row:
return rowIdx, colIdx
case rsmt2d.Col:
return colIdx, rowIdx
default:
panic(fmt.Sprintf("invalid proof type: %d", axisType))
}
}
74 changes: 74 additions & 0 deletions share/new_eds/rsmt2d_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package eds

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/shwap"
)

func TestRsmt2dSample(t *testing.T) {
eds, root := randRsmt2dAccsessor(t, 8)

width := int(eds.Width())
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
shr, err := eds.Sample(context.TODO(), rowIdx, colIdx)
require.NoError(t, err)

err = shr.Validate(root, rowIdx, colIdx)
require.NoError(t, err)
}
}
}

func TestRsmt2dHalfRowFrom(t *testing.T) {
const odsSize = 8
eds, _ := randRsmt2dAccsessor(t, odsSize)

for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for _, side := range []shwap.RowSide{shwap.Left, shwap.Right} {
row := eds.HalfRow(rowIdx, side)

want := eds.Row(uint(rowIdx))
shares, err := row.Shares()
require.NoError(t, err)
require.Equal(t, want, shares)
}
}
}

func TestRsmt2dSampleForProofAxis(t *testing.T) {
const odsSize = 8
eds := edstest.RandEDS(t, odsSize)
accessor := Rsmt2D{ExtendedDataSquare: eds}

for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for colIdx := 0; colIdx < odsSize*2; colIdx++ {
sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType)
require.NoError(t, err)

want := eds.GetCell(uint(rowIdx), uint(colIdx))
require.Equal(t, want, sample.Share)
require.Equal(t, proofType, sample.ProofType)
require.NotNil(t, sample.Proof)
require.Equal(t, sample.Proof.End()-sample.Proof.Start(), 1)
require.Len(t, sample.Proof.Nodes(), 4)
}
}
}
}

func randRsmt2dAccsessor(t *testing.T, size int) (Rsmt2D, *share.Root) {
eds := edstest.RandEDS(t, size)
root, err := share.NewRoot(eds)
require.NoError(t, err)
return Rsmt2D{ExtendedDataSquare: eds}, root
}
26 changes: 0 additions & 26 deletions share/shwap/namespace_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,13 @@ package shwap
import (
"fmt"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
)

// NamespacedData stores collections of RowNamespaceData, each representing shares and their proofs
// within a namespace.
type NamespacedData []RowNamespaceData

// NamespacedDataFromEDS extracts shares for a specific namespace from an EDS, considering
// each row independently.
func NamespacedDataFromEDS(
square *rsmt2d.ExtendedDataSquare,
namespace share.Namespace,
) (NamespacedData, error) {
root, err := share.NewRoot(square)
if err != nil {
return nil, fmt.Errorf("error computing root: %w", err)
}

rowIdxs := share.RowsWithNamespace(root, namespace)
rows := make(NamespacedData, len(rowIdxs))
for i, idx := range rowIdxs {
shares := square.Row(uint(idx))
rows[i], err = RowNamespaceDataFromShares(shares, namespace, idx)
if err != nil {
return nil, fmt.Errorf("failed to process row %d: %w", idx, err)
}
}

return rows, nil
}

// Flatten combines all shares from all rows within the namespace into a single slice.
func (ns NamespacedData) Flatten() []share.Share {
var shares []share.Share
Expand Down
14 changes: 7 additions & 7 deletions share/shwap/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap/pb"
Expand Down Expand Up @@ -35,14 +34,12 @@ func NewRow(halfShares []share.Share, side RowSide) Row {

// RowFromEDS constructs a new Row from an Extended Data Square based on the specified index and
// side.
func RowFromEDS(square *rsmt2d.ExtendedDataSquare, idx int, side RowSide) Row {
sqrLn := int(square.Width())
shares := square.Row(uint(idx))
func RowFromShares(shares []share.Share, side RowSide) Row {
var halfShares []share.Share
if side == Right {
halfShares = shares[sqrLn/2:] // Take the right half of the shares.
halfShares = shares[len(shares)/2:] // Take the right half of the shares.
} else {
halfShares = shares[:sqrLn/2] // Take the left half of the shares.
halfShares = shares[:len(shares)/2] // Take the left half of the shares.
}

return NewRow(halfShares, side)
Expand Down Expand Up @@ -95,7 +92,10 @@ func (r Row) Validate(dah *share.Root, idx int) error {
return fmt.Errorf("invalid RowSide: %d", r.side)
}

return r.verifyInclusion(dah, idx)
if err := r.verifyInclusion(dah, idx); err != nil {
return fmt.Errorf("%w: %w", ErrFailedVerification, err)
}
return nil
}

// verifyInclusion verifies the integrity of the row's shares against the provided root hash for the
Expand Down
Loading

0 comments on commit 60e757e

Please sign in to comment.