Skip to content

Commit

Permalink
Update p2p proto definitons
Browse files Browse the repository at this point in the history
and make each request use seperate pids
  • Loading branch information
omerfirmak committed Sep 20, 2023
1 parent 0bfac9d commit 5b8088c
Show file tree
Hide file tree
Showing 22 changed files with 4,497 additions and 2,195 deletions.
115 changes: 31 additions & 84 deletions p2p/starknet/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ import (
type NewStreamFunc func(ctx context.Context, pids ...protocol.ID) (network.Stream, error)

type Client struct {
newStream NewStreamFunc
protocolID protocol.ID
log utils.Logger
newStream NewStreamFunc
network utils.Network
log utils.Logger
}

func NewClient(newStream NewStreamFunc, protocolID protocol.ID, log utils.Logger) *Client {
func NewClient(newStream NewStreamFunc, snNetwork utils.Network, log utils.Logger) *Client {
return &Client{
newStream: newStream,
protocolID: protocolID,
log: log,
newStream: newStream,
network: snNetwork,
log: log,
}
}

func (c *Client) sendAndCloseWrite(stream network.Stream, req proto.Message) error {
func sendAndCloseWrite(stream network.Stream, req proto.Message) error {
reqBytes, err := proto.Marshal(req)
if err != nil {
return err
Expand All @@ -39,101 +39,48 @@ func (c *Client) sendAndCloseWrite(stream network.Stream, req proto.Message) err
return stream.CloseWrite()
}

func (c *Client) receiveInto(stream network.Stream, res proto.Message) error {
func receiveInto(stream network.Stream, res proto.Message) error {
return protodelim.UnmarshalFrom(&byteReader{stream}, res)
}

func (c *Client) sendAndReceiveInto(ctx context.Context, req, res proto.Message) error {
stream, err := c.newStream(ctx, c.protocolID)
if err != nil {
return err
}
defer stream.Close() // todo: dont ignore close errors

if err = c.sendAndCloseWrite(stream, req); err != nil {
return err
}

return c.receiveInto(stream, res)
}

func (c *Client) GetBlocks(ctx context.Context, req *spec.GetBlocks) (Stream[*spec.BlockHeader], error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetBlocks{
GetBlocks: req,
},
}

stream, err := c.newStream(ctx, c.protocolID)
func requestAndReceiveStream[ReqT proto.Message, ResT proto.Message](ctx context.Context,
newStream NewStreamFunc, protocolID protocol.ID, req ReqT,
) (Stream[ResT], error) {
stream, err := newStream(ctx, protocolID)
if err != nil {
return nil, err
}
if err := c.sendAndCloseWrite(stream, &wrappedReq); err != nil {
if err := sendAndCloseWrite(stream, req); err != nil {
return nil, err
}

return func() (*spec.BlockHeader, bool) {
var res spec.BlockHeader
if err := c.receiveInto(stream, &res); err != nil {
return func() (ResT, bool) {
var zero ResT
res := zero.ProtoReflect().New().Interface()
if err := receiveInto(stream, res); err != nil {
stream.Close() // todo: dont ignore close errors
return nil, false
return zero, false
}
return &res, true
return res.(ResT), true
}, nil
}

func (c *Client) GetSignatures(ctx context.Context, req *spec.GetSignatures) (*spec.Signatures, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetSignatures{
GetSignatures: req,
},
}

var res spec.Signatures
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestBlockHeaders(ctx context.Context, req *spec.BlockHeadersRequest) (Stream[*spec.BlockHeadersResponse], error) {
return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse](ctx, c.newStream, BlockHeadersPID(c.network), req)
}

func (c *Client) GetEvents(ctx context.Context, req *spec.GetEvents) (*spec.Events, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetEvents{
GetEvents: req,
},
}

var res spec.Events
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestBlockBodies(ctx context.Context, req *spec.BlockBodiesRequest) (Stream[*spec.BlockBodiesResponse], error) {
return requestAndReceiveStream[*spec.BlockBodiesRequest, *spec.BlockBodiesResponse](ctx, c.newStream, BlockBodiesPID(c.network), req)
}

func (c *Client) GetReceipts(ctx context.Context, req *spec.GetReceipts) (*spec.Receipts, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetReceipts{
GetReceipts: req,
},
}

var res spec.Receipts
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (Stream[*spec.EventsResponse], error) {
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network), req)
}

func (c *Client) GetTransactions(ctx context.Context, req *spec.GetTransactions) (*spec.Transactions, error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetTransactions{
GetTransactions: req,
},
}
func (c *Client) RequestReceipts(ctx context.Context, req *spec.ReceiptsRequest) (Stream[*spec.ReceiptsResponse], error) {
return requestAndReceiveStream[*spec.ReceiptsRequest, *spec.ReceiptsResponse](ctx, c.newStream, ReceiptsPID(c.network), req)
}

var res spec.Transactions
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
return nil, err
}
return &res, nil
func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (Stream[*spec.TransactionsResponse], error) {
return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse](ctx, c.newStream, TransactionsPID(c.network), req)
}
159 changes: 76 additions & 83 deletions p2p/starknet/handlers.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
//go:generate protoc --go_out=./ --proto_path=./ --go_opt=Mp2p/proto/requests.proto=./spec --go_opt=Mp2p/proto/transaction.proto=./spec --go_opt=Mp2p/proto/state.proto=./spec --go_opt=Mp2p/proto/snapshot.proto=./spec --go_opt=Mp2p/proto/receipt.proto=./spec --go_opt=Mp2p/proto/mempool.proto=./spec --go_opt=Mp2p/proto/event.proto=./spec --go_opt=Mp2p/proto/block.proto=./spec --go_opt=Mp2p/proto/common.proto=./spec p2p/proto/transaction.proto p2p/proto/state.proto p2p/proto/snapshot.proto p2p/proto/common.proto p2p/proto/block.proto p2p/proto/event.proto p2p/proto/receipt.proto p2p/proto/requests.proto
//go:generate protoc --go_out=./ --proto_path=./ --go_opt=Mp2p/proto/transaction.proto=./spec --go_opt=Mp2p/proto/state.proto=./spec --go_opt=Mp2p/proto/snapshot.proto=./spec --go_opt=Mp2p/proto/receipt.proto=./spec --go_opt=Mp2p/proto/mempool.proto=./spec --go_opt=Mp2p/proto/event.proto=./spec --go_opt=Mp2p/proto/block.proto=./spec --go_opt=Mp2p/proto/common.proto=./spec p2p/proto/transaction.proto p2p/proto/state.proto p2p/proto/snapshot.proto p2p/proto/common.proto p2p/proto/block.proto p2p/proto/event.proto p2p/proto/receipt.proto
package starknet

import (
"bytes"
"errors"
"fmt"
"sync"

"github.com/NethermindEth/juno/adapters/core2p2p"
"github.com/NethermindEth/juno/adapters/p2p2core"
"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/p2p/starknet/spec"
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/network"
Expand Down Expand Up @@ -43,133 +39,130 @@ func getBuffer() *bytes.Buffer {
return buffer
}

func (h *Handler) StreamHandler(stream network.Stream) {
func streamHandler[ReqT proto.Message](stream network.Stream,
reqHandler func(req ReqT) (Stream[proto.Message], error), log utils.SimpleLogger,
) {
defer func() {
if err := stream.Close(); err != nil {
h.log.Debugw("Error closing stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
log.Debugw("Error closing stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)

Check warning on line 47 in p2p/starknet/handlers.go

View check run for this annotation

Codecov / codecov/patch

p2p/starknet/handlers.go#L47

Added line #L47 was not covered by tests
}
}()

buffer := getBuffer()
defer bufferPool.Put(buffer)

if _, err := buffer.ReadFrom(stream); err != nil {
h.log.Debugw("Error reading from stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
log.Debugw("Error reading from stream", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)

Check warning on line 55 in p2p/starknet/handlers.go

View check run for this annotation

Codecov / codecov/patch

p2p/starknet/handlers.go#L55

Added line #L55 was not covered by tests
return
}

var req spec.Request
if err := proto.Unmarshal(buffer.Bytes(), &req); err != nil {
h.log.Debugw("Error unmarshalling message", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
var zero ReqT
req := zero.ProtoReflect().New().Interface()
if err := proto.Unmarshal(buffer.Bytes(), req); err != nil {
log.Debugw("Error unmarshalling message", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)

Check warning on line 62 in p2p/starknet/handlers.go

View check run for this annotation

Codecov / codecov/patch

p2p/starknet/handlers.go#L62

Added line #L62 was not covered by tests
return
}

response, err := h.reqHandler(&req)
response, err := reqHandler(req.(ReqT))
if err != nil {
h.log.Debugw("Error handling request", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err, "request", req.String())
log.Debugw("Error handling request", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)

Check warning on line 68 in p2p/starknet/handlers.go

View check run for this annotation

Codecov / codecov/patch

p2p/starknet/handlers.go#L68

Added line #L68 was not covered by tests
return
}

for msg, valid := response(); valid; msg, valid = response() {
if _, err := protodelim.MarshalTo(stream, msg); err != nil { // todo: figure out if we need buffered io here
h.log.Debugw("Error writing response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
log.Debugw("Error writing response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)

Check warning on line 74 in p2p/starknet/handlers.go

View check run for this annotation

Codecov / codecov/patch

p2p/starknet/handlers.go#L74

Added line #L74 was not covered by tests
}
}
}

func (h *Handler) reqHandler(req *spec.Request) (Stream[proto.Message], error) {
var singleResponse proto.Message
var err error
switch typedReq := req.GetReq().(type) {
case *spec.Request_GetBlocks:
return h.HandleGetBlocks(typedReq.GetBlocks)
case *spec.Request_GetSignatures:
singleResponse, err = h.HandleGetSignatures(typedReq.GetSignatures)
case *spec.Request_GetEvents:
singleResponse, err = h.HandleGetEvents(typedReq.GetEvents)
case *spec.Request_GetReceipts:
singleResponse, err = h.HandleGetReceipts(typedReq.GetReceipts)
case *spec.Request_GetTransactions:
singleResponse, err = h.HandleGetTransactions(typedReq.GetTransactions)
default:
return nil, fmt.Errorf("unhandled request %T", typedReq)
}
func (h *Handler) BlockHeadersHandler(stream network.Stream) {
streamHandler[*spec.BlockHeadersRequest](stream, h.OnBlockHeadersRequest, h.log)
}

if err != nil {
return nil, err
}
return StaticStream[proto.Message](singleResponse), nil
func (h *Handler) BlockBodiesHandler(stream network.Stream) {
streamHandler[*spec.BlockBodiesRequest](stream, h.OnBlockBodiesRequest, h.log)
}

func (h *Handler) EventsHandler(stream network.Stream) {
streamHandler[*spec.EventsRequest](stream, h.OnEventsRequest, h.log)
}

func (h *Handler) HandleGetBlocks(req *spec.GetBlocks) (Stream[proto.Message], error) {
func (h *Handler) ReceiptsHandler(stream network.Stream) {
streamHandler[*spec.ReceiptsRequest](stream, h.OnReceiptsRequest, h.log)
}

func (h *Handler) TransactionsHandler(stream network.Stream) {
streamHandler[*spec.TransactionsRequest](stream, h.OnTransactionsRequest, h.log)
}

func (h *Handler) OnBlockHeadersRequest(req *spec.BlockHeadersRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
count := uint32(0)
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.BlockHeader{
State: &spec.Merkle{
NLeaves: count - 1,
},
fmt.Println("counting ", count-1)
return &spec.BlockHeadersResponse{
BlockNumber: count - 1,
}, true
}, nil
}

func (h *Handler) HandleGetSignatures(req *spec.GetSignatures) (*spec.Signatures, error) {
func (h *Handler) OnBlockBodiesRequest(req *spec.BlockBodiesRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
return &spec.Signatures{
Id: req.Id,
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.BlockBodiesResponse{
BlockNumber: count - 1,
}, true
}, nil
}

func (h *Handler) HandleGetEvents(req *spec.GetEvents) (*spec.Events, error) {
block, err := h.blockByID(req.Id)
if err != nil {
return nil, err
}

var result spec.Events
for _, receipt := range block.Receipts {
for _, ev := range receipt.Events {
event := &spec.Event{
FromAddress: core2p2p.AdaptFelt(ev.From),
Keys: utils.Map(ev.Keys, core2p2p.AdaptFelt),
Data: utils.Map(ev.Data, core2p2p.AdaptFelt),
}

result.Events = append(result.Events, event)
func (h *Handler) OnEventsRequest(req *spec.EventsRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
}

return &result, nil
count++
return &spec.EventsResponse{
BlockNumber: count - 1,
}, true
}, nil
}

func (h *Handler) HandleGetReceipts(req *spec.GetReceipts) (*spec.Receipts, error) {
func (h *Handler) OnReceiptsRequest(req *spec.ReceiptsRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
magic := 37
return &spec.Receipts{
Receipts: make([]*spec.Receipt, magic),
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.ReceiptsResponse{
BlockNumber: count - 1,
}, true
}, nil
}

func (h *Handler) HandleGetTransactions(req *spec.GetTransactions) (*spec.Transactions, error) {
func (h *Handler) OnTransactionsRequest(req *spec.TransactionsRequest) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
magic := 1337
return &spec.Transactions{
Transactions: make([]*spec.Transaction, magic),
count := uint64(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.TransactionsResponse{
BlockNumber: count - 1,
}, true
}, nil
}

func (h *Handler) blockByID(id *spec.BlockID) (*core.Block, error) {
switch {
case id == nil:
return nil, errors.New("block id is nil")
case id.Hash != nil:
hash := p2p2core.AdaptHash(id.Hash)
return h.bcReader.BlockByHash(hash)
default:
return h.bcReader.BlockByNumber(id.Height)
}
}
Loading

0 comments on commit 5b8088c

Please sign in to comment.