diff --git a/les/backend.go b/les/backend.go index 646c81a7b13e..658c73c6ee8f 100644 --- a/les/backend.go +++ b/les/backend.go @@ -19,6 +19,7 @@ package les import ( "fmt" + "sync" "time" "github.com/ethereum/go-ethereum/accounts" @@ -38,6 +39,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/params" rpc "github.com/ethereum/go-ethereum/rpc" ) @@ -49,9 +51,13 @@ type LightEthereum struct { // Channel for shutting down the service shutdownChan chan bool // Handlers + peers *peerSet txPool *light.TxPool blockchain *light.LightChain protocolManager *ProtocolManager + serverPool *serverPool + reqDist *requestDistributor + retriever *retrieveManager // DB interfaces chainDb ethdb.Database // Block chain database @@ -63,6 +69,9 @@ type LightEthereum struct { networkId uint64 netRPCService *ethapi.PublicNetAPI + + quitSync chan struct{} + wg sync.WaitGroup } func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { @@ -76,20 +85,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { } log.Info("Initialised chain configuration", "config", chainConfig) - odr := NewLesOdr(chainDb) - relay := NewLesTxRelay() + peers := newPeerSet() + quitSync := make(chan struct{}) + eth := &LightEthereum{ - odr: odr, - relay: relay, - chainDb: chainDb, chainConfig: chainConfig, + chainDb: chainDb, eventMux: ctx.EventMux, + peers: peers, + reqDist: newRequestDistributor(peers, quitSync), accountManager: ctx.AccountManager, engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), shutdownChan: make(chan bool), networkId: config.NetworkId, } - if eth.blockchain, err = light.NewLightChain(odr, eth.chainConfig, eth.engine, eth.eventMux); err != nil { + + eth.relay = NewLesTxRelay(peers, eth.reqDist) + eth.serverPool = newServerPool(chainDb, quitSync, ð.wg) + eth.retriever = newRetrieveManager(peers, eth.reqDist, eth.serverPool) + eth.odr = NewLesOdr(chainDb, eth.retriever) + if eth.blockchain, err = light.NewLightChain(eth.odr, eth.chainConfig, eth.engine, eth.eventMux); err != nil { return nil, err } // Rewind the chain in case of an incompatible config upgrade. @@ -100,13 +115,9 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { } eth.txPool = light.NewTxPool(eth.chainConfig, eth.eventMux, eth.blockchain, eth.relay) - lightSync := config.SyncMode == downloader.LightSync - if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, lightSync, config.NetworkId, eth.eventMux, eth.engine, eth.blockchain, nil, chainDb, odr, relay); err != nil { + if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, true, config.NetworkId, eth.eventMux, eth.engine, eth.peers, eth.blockchain, nil, chainDb, eth.odr, eth.relay, quitSync, ð.wg); err != nil { return nil, err } - relay.ps = eth.protocolManager.peers - relay.reqDist = eth.protocolManager.reqDist - eth.ApiBackend = &LesApiBackend{eth, nil} gpoParams := config.GPO if gpoParams.Default == nil { @@ -116,6 +127,10 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { return eth, nil } +func lesTopic(genesisHash common.Hash) discv5.Topic { + return discv5.Topic("LES@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) +} + type LightDummyAPI struct{} // Etherbase is the address that mining rewards will be send to @@ -188,7 +203,8 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { func (s *LightEthereum) Start(srvr *p2p.Server) error { log.Warn("Light client mode is an experimental feature") s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId) - s.protocolManager.Start(srvr) + s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash())) + s.protocolManager.Start() return nil } diff --git a/les/distributor.go b/les/distributor.go index 71afe2b73ed9..e8ef5b02e295 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -34,11 +34,11 @@ var ErrNoPeers = errors.New("no suitable peers available") type requestDistributor struct { reqQueue *list.List lastReqOrder uint64 + peers map[distPeer]struct{} + peerLock sync.RWMutex stopChn, loopChn chan struct{} loopNextSent bool lock sync.Mutex - - getAllPeers func() map[distPeer]struct{} } // distPeer is an LES server peer interface for the request distributor. @@ -71,15 +71,39 @@ type distReq struct { } // newRequestDistributor creates a new request distributor -func newRequestDistributor(getAllPeers func() map[distPeer]struct{}, stopChn chan struct{}) *requestDistributor { - r := &requestDistributor{ - reqQueue: list.New(), - loopChn: make(chan struct{}, 2), - stopChn: stopChn, - getAllPeers: getAllPeers, +func newRequestDistributor(peers *peerSet, stopChn chan struct{}) *requestDistributor { + d := &requestDistributor{ + reqQueue: list.New(), + loopChn: make(chan struct{}, 2), + stopChn: stopChn, + peers: make(map[distPeer]struct{}), + } + if peers != nil { + peers.notify(d) } - go r.loop() - return r + go d.loop() + return d +} + +// registerPeer implements peerSetNotify +func (d *requestDistributor) registerPeer(p *peer) { + d.peerLock.Lock() + d.peers[p] = struct{}{} + d.peerLock.Unlock() +} + +// unregisterPeer implements peerSetNotify +func (d *requestDistributor) unregisterPeer(p *peer) { + d.peerLock.Lock() + delete(d.peers, p) + d.peerLock.Unlock() +} + +// registerTestPeer adds a new test peer +func (d *requestDistributor) registerTestPeer(p distPeer) { + d.peerLock.Lock() + d.peers[p] = struct{}{} + d.peerLock.Unlock() } // distMaxWait is the maximum waiting time after which further necessary waiting @@ -152,8 +176,7 @@ func (sp selectPeerItem) Weight() int64 { // nextRequest returns the next possible request from any peer, along with the // associated peer and necessary waiting time func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { - peers := d.getAllPeers() - + checkedPeers := make(map[distPeer]struct{}) elem := d.reqQueue.Front() var ( bestPeer distPeer @@ -162,11 +185,14 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { sel *weightedRandomSelect ) - for (len(peers) > 0 || elem == d.reqQueue.Front()) && elem != nil { + d.peerLock.RLock() + defer d.peerLock.RUnlock() + + for (len(d.peers) > 0 || elem == d.reqQueue.Front()) && elem != nil { req := elem.Value.(*distReq) canSend := false - for peer, _ := range peers { - if peer.canQueue() && req.canSend(peer) { + for peer, _ := range d.peers { + if _, ok := checkedPeers[peer]; !ok && peer.canQueue() && req.canSend(peer) { canSend = true cost := req.getCost(peer) wait, bufRemain := peer.waitBefore(cost) @@ -182,7 +208,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { bestWait = wait } } - delete(peers, peer) + checkedPeers[peer] = struct{}{} } } next := elem.Next() diff --git a/les/distributor_test.go b/les/distributor_test.go index ae184b21bc6b..4e7f8bd291f6 100644 --- a/les/distributor_test.go +++ b/les/distributor_test.go @@ -122,20 +122,14 @@ func testRequestDistributor(t *testing.T, resend bool) { stop := make(chan struct{}) defer close(stop) + dist := newRequestDistributor(nil, stop) var peers [testDistPeerCount]*testDistPeer for i, _ := range peers { peers[i] = &testDistPeer{} go peers[i].worker(t, !resend, stop) + dist.registerTestPeer(peers[i]) } - dist := newRequestDistributor(func() map[distPeer]struct{} { - m := make(map[distPeer]struct{}) - for _, peer := range peers { - m[peer] = struct{}{} - } - return m - }, stop) - var wg sync.WaitGroup for i := 1; i <= testDistReqCount; i++ { diff --git a/les/fetcher.go b/les/fetcher.go index a294d00d5292..4fc142f0f4c7 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -116,6 +116,7 @@ func newLightFetcher(pm *ProtocolManager) *lightFetcher { syncDone: make(chan *peer), maxConfirmedTd: big.NewInt(0), } + pm.peers.notify(f) go f.syncLoop() return f } @@ -209,8 +210,8 @@ func (f *lightFetcher) syncLoop() { } } -// addPeer adds a new peer to the fetcher's peer set -func (f *lightFetcher) addPeer(p *peer) { +// registerPeer adds a new peer to the fetcher's peer set +func (f *lightFetcher) registerPeer(p *peer) { p.lock.Lock() p.hasBlock = func(hash common.Hash, number uint64) bool { return f.peerHasBlock(p, hash, number) @@ -223,8 +224,8 @@ func (f *lightFetcher) addPeer(p *peer) { f.peers[p] = &fetcherPeerInfo{nodeByHash: make(map[common.Hash]*fetcherTreeNode)} } -// removePeer removes a new peer from the fetcher's peer set -func (f *lightFetcher) removePeer(p *peer) { +// unregisterPeer removes a new peer from the fetcher's peer set +func (f *lightFetcher) unregisterPeer(p *peer) { p.lock.Lock() p.hasBlock = nil p.lock.Unlock() @@ -416,7 +417,7 @@ func (f *lightFetcher) nextRequest() (*distReq, uint64) { f.syncing = bestSyncing var rq *distReq - reqID := getNextReqID() + reqID := genReqID() if f.syncing { rq = &distReq{ getCost: func(dp distPeer) uint64 { diff --git a/les/handler.go b/les/handler.go index 64023af0f5af..77bc077a2ee6 100644 --- a/les/handler.go +++ b/les/handler.go @@ -102,7 +102,9 @@ type ProtocolManager struct { odr *LesOdr server *LesServer serverPool *serverPool + lesTopic discv5.Topic reqDist *requestDistributor + retriever *retrieveManager downloader *downloader.Downloader fetcher *lightFetcher @@ -123,12 +125,12 @@ type ProtocolManager struct { // wait group is used for graceful shutdowns during downloading // and processing - wg sync.WaitGroup + wg *sync.WaitGroup } // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. -func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay) (*ProtocolManager, error) { +func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { // Create the protocol manager with the base fields manager := &ProtocolManager{ lightSync: lightSync, @@ -136,15 +138,20 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network blockchain: blockchain, chainConfig: chainConfig, chainDb: chainDb, + odr: odr, networkId: networkId, txpool: txpool, txrelay: txrelay, - odr: odr, - peers: newPeerSet(), + peers: peers, newPeerCh: make(chan *peer), - quitSync: make(chan struct{}), + quitSync: quitSync, + wg: wg, noMorePeers: make(chan struct{}), } + if odr != nil { + manager.retriever = odr.retriever + manager.reqDist = odr.retriever.dist + } // Initiate a sub-protocol for every implemented version we can handle manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions)) for i, version := range ProtocolVersions { @@ -202,84 +209,22 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash, nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash, blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer) + manager.peers.notify((*downloaderPeerNotify)(manager)) + manager.fetcher = newLightFetcher(manager) } - manager.reqDist = newRequestDistributor(func() map[distPeer]struct{} { - m := make(map[distPeer]struct{}) - peers := manager.peers.AllPeers() - for _, peer := range peers { - m[peer] = struct{}{} - } - return m - }, manager.quitSync) - if odr != nil { - odr.removePeer = removePeer - odr.reqDist = manager.reqDist - } - - /*validator := func(block *types.Block, parent *types.Block) error { - return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) - } - heighter := func() uint64 { - return chainman.LastBlockNumberU64() - } - manager.fetcher = fetcher.New(chainman.GetBlockNoOdr, validator, nil, heighter, chainman.InsertChain, manager.removePeer) - */ return manager, nil } +// removePeer initiates disconnection from a peer by removing it from the peer set func (pm *ProtocolManager) removePeer(id string) { - // Short circuit if the peer was already removed - peer := pm.peers.Peer(id) - if peer == nil { - return - } - log.Debug("Removing light Ethereum peer", "peer", id) - if err := pm.peers.Unregister(id); err != nil { - if err == errNotRegistered { - return - } - } - // Unregister the peer from the downloader and Ethereum peer set - if pm.lightSync { - pm.downloader.UnregisterPeer(id) - if pm.txrelay != nil { - pm.txrelay.removePeer(id) - } - if pm.fetcher != nil { - pm.fetcher.removePeer(peer) - } - } - // Hard disconnect at the networking layer - if peer != nil { - peer.Peer.Disconnect(p2p.DiscUselessPeer) - } + pm.peers.Unregister(id) } -func (pm *ProtocolManager) Start(srvr *p2p.Server) { - var topicDisc *discv5.Network - if srvr != nil { - topicDisc = srvr.DiscV5 - } - lesTopic := discv5.Topic("LES@" + common.Bytes2Hex(pm.blockchain.Genesis().Hash().Bytes()[0:8])) +func (pm *ProtocolManager) Start() { if pm.lightSync { - // start sync handler - if srvr != nil { // srvr is nil during testing - pm.serverPool = newServerPool(pm.chainDb, []byte("serverPool/"), srvr, lesTopic, pm.quitSync, &pm.wg) - pm.odr.serverPool = pm.serverPool - pm.fetcher = newLightFetcher(pm) - } go pm.syncer() } else { - if topicDisc != nil { - go func() { - logger := log.New("topic", lesTopic) - logger.Info("Starting topic registration") - defer logger.Info("Terminated topic registration") - - topicDisc.RegisterTopic(lesTopic, pm.quitSync) - }() - } go func() { for range pm.newPeerCh { } @@ -342,65 +287,10 @@ func (pm *ProtocolManager) handle(p *peer) error { }() // Register the peer in the downloader. If the downloader considers it banned, we disconnect if pm.lightSync { - requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error { - reqID := getNextReqID() - rq := &distReq{ - getCost: func(dp distPeer) uint64 { - peer := dp.(*peer) - return peer.GetRequestCost(GetBlockHeadersMsg, amount) - }, - canSend: func(dp distPeer) bool { - return dp.(*peer) == p - }, - request: func(dp distPeer) func() { - peer := dp.(*peer) - cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) - peer.fcServer.QueueRequest(reqID, cost) - return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) } - }, - } - _, ok := <-pm.reqDist.queue(rq) - if !ok { - return ErrNoPeers - } - return nil - } - requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error { - reqID := getNextReqID() - rq := &distReq{ - getCost: func(dp distPeer) uint64 { - peer := dp.(*peer) - return peer.GetRequestCost(GetBlockHeadersMsg, amount) - }, - canSend: func(dp distPeer) bool { - return dp.(*peer) == p - }, - request: func(dp distPeer) func() { - peer := dp.(*peer) - cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) - peer.fcServer.QueueRequest(reqID, cost) - return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) } - }, - } - _, ok := <-pm.reqDist.queue(rq) - if !ok { - return ErrNoPeers - } - return nil - } - if err := pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd, - requestHeadersByHash, requestHeadersByNumber, nil, nil, nil); err != nil { - return err - } - if pm.txrelay != nil { - pm.txrelay.addPeer(p) - } - p.lock.Lock() head := p.headInfo p.lock.Unlock() if pm.fetcher != nil { - pm.fetcher.addPeer(p) pm.fetcher.announce(p, head) } @@ -926,7 +816,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } if deliverMsg != nil { - err := pm.odr.Deliver(p, deliverMsg) + err := pm.retriever.deliver(p, deliverMsg) if err != nil { p.responseErrors++ if p.responseErrors > maxResponseErrors { @@ -946,3 +836,64 @@ func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo { Head: self.blockchain.LastBlockHash(), } } + +// downloaderPeerNotify implements peerSetNotify +type downloaderPeerNotify ProtocolManager + +func (d *downloaderPeerNotify) registerPeer(p *peer) { + pm := (*ProtocolManager)(d) + + requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error { + reqID := genReqID() + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*peer) + return peer.GetRequestCost(GetBlockHeadersMsg, amount) + }, + canSend: func(dp distPeer) bool { + return dp.(*peer) == p + }, + request: func(dp distPeer) func() { + peer := dp.(*peer) + cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) + peer.fcServer.QueueRequest(reqID, cost) + return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) } + }, + } + _, ok := <-pm.reqDist.queue(rq) + if !ok { + return ErrNoPeers + } + return nil + } + requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error { + reqID := genReqID() + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*peer) + return peer.GetRequestCost(GetBlockHeadersMsg, amount) + }, + canSend: func(dp distPeer) bool { + return dp.(*peer) == p + }, + request: func(dp distPeer) func() { + peer := dp.(*peer) + cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) + peer.fcServer.QueueRequest(reqID, cost) + return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) } + }, + } + _, ok := <-pm.reqDist.queue(rq) + if !ok { + return ErrNoPeers + } + return nil + } + + pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd, requestHeadersByHash, requestHeadersByNumber, nil, nil, nil) +} + +func (d *downloaderPeerNotify) unregisterPeer(p *peer) { + pm := (*ProtocolManager)(d) + pm.downloader.UnregisterPeer(p.id) +} diff --git a/les/handler_test.go b/les/handler_test.go index 0b94d0d30b2d..5df1d3463aca 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" @@ -42,7 +43,8 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{} func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) } func testGetBlockHeaders(t *testing.T, protocol int) { - pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil) + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil, nil, nil, db) bc := pm.blockchain.(*core.BlockChain) peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() @@ -170,7 +172,8 @@ func testGetBlockHeaders(t *testing.T, protocol int) { func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) } func testGetBlockBodies(t *testing.T, protocol int) { - pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil) + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil, nil, nil, db) bc := pm.blockchain.(*core.BlockChain) peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() @@ -246,7 +249,8 @@ func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) } func testGetCode(t *testing.T, protocol int) { // Assemble the test environment - pm, _, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) bc := pm.blockchain.(*core.BlockChain) peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() @@ -278,7 +282,8 @@ func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) } func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment - pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) bc := pm.blockchain.(*core.BlockChain) peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() @@ -304,7 +309,8 @@ func TestGetProofsLes1(t *testing.T) { testGetReceipt(t, 1) } func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment - pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) bc := pm.blockchain.(*core.BlockChain) peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() diff --git a/les/helper_test.go b/les/helper_test.go index 7e442c131b3c..52fddd117a4d 100644 --- a/les/helper_test.go +++ b/les/helper_test.go @@ -25,7 +25,6 @@ import ( "math/big" "sync" "testing" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" @@ -132,22 +131,22 @@ func testRCL() RequestCostList { // newTestProtocolManager creates a new protocol manager for testing purposes, // with the given number of blocks already known, and potential notification // channels for different events. -func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr, error) { +func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen), peers *peerSet, odr *LesOdr, db ethdb.Database) (*ProtocolManager, error) { var ( evmux = new(event.TypeMux) engine = ethash.NewFaker() - db, _ = ethdb.NewMemDatabase() gspec = core.Genesis{ Config: params.TestChainConfig, Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}, } genesis = gspec.MustCommit(db) - odr *LesOdr - chain BlockChain + chain BlockChain ) + if peers == nil { + peers = newPeerSet() + } if lightSync { - odr = NewLesOdr(db) chain, _ = light.NewLightChain(odr, gspec.Config, engine, evmux) } else { blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, evmux, vm.Config{}) @@ -158,9 +157,9 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor chain = blockchain } - pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, chain, nil, db, odr, nil) + pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) if err != nil { - return nil, nil, nil, err + return nil, err } if !lightSync { srv := &LesServer{protocolManager: pm} @@ -174,20 +173,20 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor srv.fcManager = flowcontrol.NewClientManager(50, 10, 1000000000) srv.fcCostStats = newCostStats(nil) } - pm.Start(nil) - return pm, db, odr, nil + pm.Start() + return pm, nil } // newTestProtocolManagerMust creates a new protocol manager for testing purposes, // with the given number of blocks already known, and potential notification // channels for different events. In case of an error, the constructor force- // fails the test. -func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr) { - pm, db, odr, err := newTestProtocolManager(lightSync, blocks, generator) +func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen), peers *peerSet, odr *LesOdr, db ethdb.Database) *ProtocolManager { + pm, err := newTestProtocolManager(lightSync, blocks, generator, peers, odr, db) if err != nil { t.Fatalf("Failed to create protocol manager: %v", err) } - return pm, db, odr + return pm } // testTxPool is a fake, helper transaction pool for testing purposes @@ -342,30 +341,3 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu func (p *testPeer) close() { p.app.Close() } - -type testServerPool struct { - peer *peer - lock sync.RWMutex -} - -func (p *testServerPool) setPeer(peer *peer) { - p.lock.Lock() - defer p.lock.Unlock() - - p.peer = peer -} - -func (p *testServerPool) getAllPeers() map[distPeer]struct{} { - p.lock.RLock() - defer p.lock.RUnlock() - - m := make(map[distPeer]struct{}) - if p.peer != nil { - m[p.peer] = struct{}{} - } - return m -} - -func (p *testServerPool) adjustResponseTime(*poolEntry, time.Duration, bool) { - -} diff --git a/les/odr.go b/les/odr.go index 684f36c761dc..3f7584b48e12 100644 --- a/les/odr.go +++ b/les/odr.go @@ -18,45 +18,24 @@ package les import ( "context" - "crypto/rand" - "encoding/binary" - "sync" - "time" - "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" ) -var ( - softRequestTimeout = time.Millisecond * 500 - hardRequestTimeout = time.Second * 10 -) - -// peerDropFn is a callback type for dropping a peer detected as malicious. -type peerDropFn func(id string) - -type odrPeerSelector interface { - adjustResponseTime(*poolEntry, time.Duration, bool) -} - +// LesOdr implements light.OdrBackend type LesOdr struct { - light.OdrBackend - db ethdb.Database - stop chan struct{} - removePeer peerDropFn - mlock, clock sync.Mutex - sentReqs map[uint64]*sentReq - serverPool odrPeerSelector - reqDist *requestDistributor + db ethdb.Database + stop chan struct{} + retriever *retrieveManager } -func NewLesOdr(db ethdb.Database) *LesOdr { +func NewLesOdr(db ethdb.Database, retriever *retrieveManager) *LesOdr { return &LesOdr{ - db: db, - stop: make(chan struct{}), - sentReqs: make(map[uint64]*sentReq), + db: db, + retriever: retriever, + stop: make(chan struct{}), } } @@ -68,17 +47,6 @@ func (odr *LesOdr) Database() ethdb.Database { return odr.db } -// validatorFunc is a function that processes a message. -type validatorFunc func(ethdb.Database, *Msg) error - -// sentReq is a request waiting for an answer that satisfies its valFunc -type sentReq struct { - valFunc validatorFunc - sentTo map[*peer]chan struct{} - lock sync.RWMutex // protects acces to sentTo - answered chan struct{} // closed and set to nil when any peer answers it -} - const ( MsgBlockBodies = iota MsgCode @@ -94,156 +62,29 @@ type Msg struct { Obj interface{} } -// Deliver is called by the LES protocol manager to deliver ODR reply messages to waiting requests -func (self *LesOdr) Deliver(peer *peer, msg *Msg) error { - var delivered chan struct{} - self.mlock.Lock() - req, ok := self.sentReqs[msg.ReqID] - self.mlock.Unlock() - if ok { - req.lock.Lock() - delivered, ok = req.sentTo[peer] - req.lock.Unlock() - } - - if !ok { - return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID) - } - - if err := req.valFunc(self.db, msg); err != nil { - peer.Log().Warn("Invalid odr response", "err", err) - return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID) - } - close(delivered) - req.lock.Lock() - delete(req.sentTo, peer) - if req.answered != nil { - close(req.answered) - req.answered = nil - } - req.lock.Unlock() - return nil -} - -func (self *LesOdr) requestPeer(req *sentReq, peer *peer, delivered, timeout chan struct{}, reqWg *sync.WaitGroup) { - stime := mclock.Now() - defer func() { - req.lock.Lock() - delete(req.sentTo, peer) - req.lock.Unlock() - reqWg.Done() - }() - - select { - case <-delivered: - if self.serverPool != nil { - self.serverPool.adjustResponseTime(peer.poolEntry, time.Duration(mclock.Now()-stime), false) - } - return - case <-time.After(softRequestTimeout): - close(timeout) - case <-self.stop: - return - } - - select { - case <-delivered: - case <-time.After(hardRequestTimeout): - peer.Log().Debug("Request timed out hard") - go self.removePeer(peer.id) - case <-self.stop: - return - } - if self.serverPool != nil { - self.serverPool.adjustResponseTime(peer.poolEntry, time.Duration(mclock.Now()-stime), true) - } -} - -// networkRequest sends a request to known peers until an answer is received -// or the context is cancelled -func (self *LesOdr) networkRequest(ctx context.Context, lreq LesOdrRequest) error { - answered := make(chan struct{}) - req := &sentReq{ - valFunc: lreq.Validate, - sentTo: make(map[*peer]chan struct{}), - answered: answered, // reply delivered by any peer - } - - exclude := make(map[*peer]struct{}) - - reqWg := new(sync.WaitGroup) - reqWg.Add(1) - defer reqWg.Done() +// Retrieve tries to fetch an object from the LES network. +// If the network retrieval was successful, it stores the object in local db. +func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { + lreq := LesRequest(req) - var timeout chan struct{} - reqID := getNextReqID() + reqID := genReqID() rq := &distReq{ getCost: func(dp distPeer) uint64 { return lreq.GetCost(dp.(*peer)) }, canSend: func(dp distPeer) bool { p := dp.(*peer) - _, ok := exclude[p] - return !ok && lreq.CanSend(p) + return lreq.CanSend(p) }, request: func(dp distPeer) func() { p := dp.(*peer) - exclude[p] = struct{}{} - delivered := make(chan struct{}) - timeout = make(chan struct{}) - req.lock.Lock() - req.sentTo[p] = delivered - req.lock.Unlock() - reqWg.Add(1) cost := lreq.GetCost(p) p.fcServer.QueueRequest(reqID, cost) - go self.requestPeer(req, p, delivered, timeout, reqWg) return func() { lreq.Request(reqID, p) } }, } - self.mlock.Lock() - self.sentReqs[reqID] = req - self.mlock.Unlock() - - go func() { - reqWg.Wait() - self.mlock.Lock() - delete(self.sentReqs, reqID) - self.mlock.Unlock() - }() - - for { - peerChn := self.reqDist.queue(rq) - select { - case <-ctx.Done(): - self.reqDist.cancel(rq) - return ctx.Err() - case <-answered: - self.reqDist.cancel(rq) - return nil - case _, ok := <-peerChn: - if !ok { - return ErrNoPeers - } - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-answered: - return nil - case <-timeout: - } - } -} - -// Retrieve tries to fetch an object from the LES network. -// If the network retrieval was successful, it stores the object in local db. -func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { - lreq := LesRequest(req) - err = self.networkRequest(ctx, lreq) - if err == nil { + if err = self.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(self.db, msg) }); err == nil { // retrieved from network, store in db req.StoreResult(self.db) } else { @@ -251,9 +92,3 @@ func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err err } return } - -func getNextReqID() uint64 { - var rnd [8]byte - rand.Read(rnd[:]) - return binary.BigEndian.Uint64(rnd[:]) -} diff --git a/les/odr_test.go b/les/odr_test.go index 532de4d80be7..7b34996ceea8 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -158,15 +158,15 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { // Assemble the test environment - pm, db, odr := newTestProtocolManagerMust(t, false, 4, testChainGen) - lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil) + peers := newPeerSet() + dist := newRequestDistributor(peers, make(chan struct{})) + rm := newRetrieveManager(peers, dist, nil) + db, _ := ethdb.NewMemDatabase() + ldb, _ := ethdb.NewMemDatabase() + odr := NewLesOdr(ldb, rm) + pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) + lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) - pool := &testServerPool{} - lpm.reqDist = newRequestDistributor(pool.getAllPeers, lpm.quitSync) - odr.reqDist = lpm.reqDist - pool.setPeer(lpeer) - odr.serverPool = pool - lpeer.hasBlock = func(common.Hash, uint64) bool { return true } select { case <-time.After(time.Millisecond * 100): case err := <-err1: @@ -198,13 +198,19 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { } // temporarily remove peer to test odr fails - pool.setPeer(nil) // expect retrievals to fail (except genesis block) without a les peer + peers.Unregister(lpeer.id) + time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed test(expFail) - pool.setPeer(lpeer) // expect all retrievals to pass + peers.Register(lpeer) + time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed + lpeer.lock.Lock() + lpeer.hasBlock = func(common.Hash, uint64) bool { return true } + lpeer.lock.Unlock() test(5) - pool.setPeer(nil) // still expect all retrievals to pass, now data should be cached locally + peers.Unregister(lpeer.id) + time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed test(5) } diff --git a/les/peer.go b/les/peer.go index ab55bafe3e90..791d0da24f35 100644 --- a/les/peer.go +++ b/les/peer.go @@ -166,9 +166,9 @@ func (p *peer) GetRequestCost(msgcode uint64, amount int) uint64 { // HasBlock checks if the peer has a given block func (p *peer) HasBlock(hash common.Hash, number uint64) bool { p.lock.RLock() - hashBlock := p.hasBlock + hasBlock := p.hasBlock p.lock.RUnlock() - return hashBlock != nil && hashBlock(hash, number) + return hasBlock != nil && hasBlock(hash, number) } // SendAnnounce announces the availability of a number of blocks through @@ -433,12 +433,20 @@ func (p *peer) String() string { ) } +// peerSetNotify is a callback interface to notify services about added or +// removed peers +type peerSetNotify interface { + registerPeer(*peer) + unregisterPeer(*peer) +} + // peerSet represents the collection of active peers currently participating in // the Light Ethereum sub-protocol. type peerSet struct { - peers map[string]*peer - lock sync.RWMutex - closed bool + peers map[string]*peer + lock sync.RWMutex + notifyList []peerSetNotify + closed bool } // newPeerSet creates a new peer set to track the active participants. @@ -448,6 +456,17 @@ func newPeerSet() *peerSet { } } +// notify adds a service to be notified about added or removed peers +func (ps *peerSet) notify(n peerSetNotify) { + ps.lock.Lock() + defer ps.lock.Unlock() + + ps.notifyList = append(ps.notifyList, n) + for _, p := range ps.peers { + go n.registerPeer(p) + } +} + // Register injects a new peer into the working set, or returns an error if the // peer is already known. func (ps *peerSet) Register(p *peer) error { @@ -462,11 +481,14 @@ func (ps *peerSet) Register(p *peer) error { } ps.peers[p.id] = p p.sendQueue = newExecQueue(100) + for _, n := range ps.notifyList { + go n.registerPeer(p) + } return nil } // Unregister removes a remote peer from the active set, disabling any further -// actions to/from that particular entity. +// actions to/from that particular entity. It also initiates disconnection at the networking layer. func (ps *peerSet) Unregister(id string) error { ps.lock.Lock() defer ps.lock.Unlock() @@ -474,7 +496,11 @@ func (ps *peerSet) Unregister(id string) error { if p, ok := ps.peers[id]; !ok { return errNotRegistered } else { + for _, n := range ps.notifyList { + go n.unregisterPeer(p) + } p.sendQueue.quit() + p.Peer.Disconnect(p2p.DiscUselessPeer) } delete(ps.peers, id) return nil diff --git a/les/request_test.go b/les/request_test.go index ba1fc15bd741..3add5f20d7e4 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -68,15 +68,16 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.Odr func testAccess(t *testing.T, protocol int, fn accessTestFn) { // Assemble the test environment - pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) - lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil) + peers := newPeerSet() + dist := newRequestDistributor(peers, make(chan struct{})) + rm := newRetrieveManager(peers, dist, nil) + db, _ := ethdb.NewMemDatabase() + ldb, _ := ethdb.NewMemDatabase() + odr := NewLesOdr(ldb, rm) + + pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) + lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) - pool := &testServerPool{} - lpm.reqDist = newRequestDistributor(pool.getAllPeers, lpm.quitSync) - odr.reqDist = lpm.reqDist - pool.setPeer(lpeer) - odr.serverPool = pool - lpeer.hasBlock = func(common.Hash, uint64) bool { return true } select { case <-time.After(time.Millisecond * 100): case err := <-err1: @@ -108,10 +109,16 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { } // temporarily remove peer to test odr fails - pool.setPeer(nil) + peers.Unregister(lpeer.id) + time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed // expect retrievals to fail (except genesis block) without a les peer test(0) - pool.setPeer(lpeer) + + peers.Register(lpeer) + time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed + lpeer.lock.Lock() + lpeer.hasBlock = func(common.Hash, uint64) bool { return true } + lpeer.lock.Unlock() // expect all retrievals to pass test(5) } diff --git a/les/retrieve.go b/les/retrieve.go new file mode 100644 index 000000000000..b060e0b0d850 --- /dev/null +++ b/les/retrieve.go @@ -0,0 +1,395 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package light implements on-demand retrieval capable state and chain objects +// for the Ethereum Light Client. +package les + +import ( + "context" + "crypto/rand" + "encoding/binary" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" +) + +var ( + retryQueue = time.Millisecond * 100 + softRequestTimeout = time.Millisecond * 500 + hardRequestTimeout = time.Second * 10 +) + +// retrieveManager is a layer on top of requestDistributor which takes care of +// matching replies by request ID and handles timeouts and resends if necessary. +type retrieveManager struct { + dist *requestDistributor + peers *peerSet + serverPool peerSelector + + lock sync.RWMutex + sentReqs map[uint64]*sentReq +} + +// validatorFunc is a function that processes a reply message +type validatorFunc func(distPeer, *Msg) error + +// peerSelector receives feedback info about response times and timeouts +type peerSelector interface { + adjustResponseTime(*poolEntry, time.Duration, bool) +} + +// sentReq represents a request sent and tracked by retrieveManager +type sentReq struct { + rm *retrieveManager + req *distReq + id uint64 + validate validatorFunc + + eventsCh chan reqPeerEvent + stopCh chan struct{} + stopped bool + err error + + lock sync.RWMutex // protect access to sentTo map + sentTo map[distPeer]sentReqToPeer + + reqQueued bool // a request has been queued but not sent + reqSent bool // a request has been sent but not timed out + reqSrtoCount int // number of requests that reached soft (but not hard) timeout +} + +// sentReqToPeer notifies the request-from-peer goroutine (tryRequest) about a response +// delivered by the given peer. Only one delivery is allowed per request per peer, +// after which delivered is set to true, the validity of the response is sent on the +// valid channel and no more responses are accepted. +type sentReqToPeer struct { + delivered bool + valid chan bool +} + +// reqPeerEvent is sent by the request-from-peer goroutine (tryRequest) to the +// request state machine (retrieveLoop) through the eventsCh channel. +type reqPeerEvent struct { + event int + peer distPeer +} + +const ( + rpSent = iota // if peer == nil, not sent (no suitable peers) + rpSoftTimeout + rpHardTimeout + rpDeliveredValid + rpDeliveredInvalid +) + +// newRetrieveManager creates the retrieve manager +func newRetrieveManager(peers *peerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager { + return &retrieveManager{ + peers: peers, + dist: dist, + serverPool: serverPool, + sentReqs: make(map[uint64]*sentReq), + } +} + +// retrieve sends a request (to multiple peers if necessary) and waits for an answer +// that is delivered through the deliver function and successfully validated by the +// validator callback. It returns when a valid answer is delivered or the context is +// cancelled. +func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc) error { + sentReq := rm.sendReq(reqID, req, val) + select { + case <-sentReq.stopCh: + case <-ctx.Done(): + sentReq.stop(ctx.Err()) + } + return sentReq.getError() +} + +// sendReq starts a process that keeps trying to retrieve a valid answer for a +// request from any suitable peers until stopped or succeeded. +func (rm *retrieveManager) sendReq(reqID uint64, req *distReq, val validatorFunc) *sentReq { + r := &sentReq{ + rm: rm, + req: req, + id: reqID, + sentTo: make(map[distPeer]sentReqToPeer), + stopCh: make(chan struct{}), + eventsCh: make(chan reqPeerEvent, 10), + validate: val, + } + + canSend := req.canSend + req.canSend = func(p distPeer) bool { + // add an extra check to canSend: the request has not been sent to the same peer before + r.lock.RLock() + _, sent := r.sentTo[p] + r.lock.RUnlock() + return !sent && canSend(p) + } + + request := req.request + req.request = func(p distPeer) func() { + // before actually sending the request, put an entry into the sentTo map + r.lock.Lock() + r.sentTo[p] = sentReqToPeer{false, make(chan bool, 1)} + r.lock.Unlock() + return request(p) + } + rm.lock.Lock() + rm.sentReqs[reqID] = r + rm.lock.Unlock() + + go r.retrieveLoop() + return r +} + +// deliver is called by the LES protocol manager to deliver reply messages to waiting requests +func (rm *retrieveManager) deliver(peer distPeer, msg *Msg) error { + rm.lock.RLock() + req, ok := rm.sentReqs[msg.ReqID] + rm.lock.RUnlock() + + if ok { + return req.deliver(peer, msg) + } + return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID) +} + +// reqStateFn represents a state of the retrieve loop state machine +type reqStateFn func() reqStateFn + +// retrieveLoop is the retrieval state machine event loop +func (r *sentReq) retrieveLoop() { + go r.tryRequest() + r.reqQueued = true + state := r.stateRequesting + + for state != nil { + state = state() + } + + r.rm.lock.Lock() + delete(r.rm.sentReqs, r.id) + r.rm.lock.Unlock() +} + +// stateRequesting: a request has been queued or sent recently; when it reaches soft timeout, +// a new request is sent to a new peer +func (r *sentReq) stateRequesting() reqStateFn { + select { + case ev := <-r.eventsCh: + r.update(ev) + switch ev.event { + case rpSent: + if ev.peer == nil { + // request send failed, no more suitable peers + if r.waiting() { + // we are already waiting for sent requests which may succeed so keep waiting + return r.stateNoMorePeers + } + // nothing to wait for, no more peers to ask, return with error + r.stop(ErrNoPeers) + // no need to go to stopped state because waiting() already returned false + return nil + } + case rpSoftTimeout: + // last request timed out, try asking a new peer + go r.tryRequest() + r.reqQueued = true + return r.stateRequesting + case rpDeliveredValid: + r.stop(nil) + return r.stateStopped + } + return r.stateRequesting + case <-r.stopCh: + return r.stateStopped + } +} + +// stateNoMorePeers: could not send more requests because no suitable peers are available. +// Peers may become suitable for a certain request later or new peers may appear so we +// keep trying. +func (r *sentReq) stateNoMorePeers() reqStateFn { + select { + case <-time.After(retryQueue): + go r.tryRequest() + r.reqQueued = true + return r.stateRequesting + case ev := <-r.eventsCh: + r.update(ev) + if ev.event == rpDeliveredValid { + r.stop(nil) + return r.stateStopped + } + return r.stateNoMorePeers + case <-r.stopCh: + return r.stateStopped + } +} + +// stateStopped: request succeeded or cancelled, just waiting for some peers +// to either answer or time out hard +func (r *sentReq) stateStopped() reqStateFn { + for r.waiting() { + r.update(<-r.eventsCh) + } + return nil +} + +// update updates the queued/sent flags and timed out peers counter according to the event +func (r *sentReq) update(ev reqPeerEvent) { + switch ev.event { + case rpSent: + r.reqQueued = false + if ev.peer != nil { + r.reqSent = true + } + case rpSoftTimeout: + r.reqSent = false + r.reqSrtoCount++ + case rpHardTimeout, rpDeliveredValid, rpDeliveredInvalid: + r.reqSrtoCount-- + } +} + +// waiting returns true if the retrieval mechanism is waiting for an answer from +// any peer +func (r *sentReq) waiting() bool { + return r.reqQueued || r.reqSent || r.reqSrtoCount > 0 +} + +// tryRequest tries to send the request to a new peer and waits for it to either +// succeed or time out if it has been sent. It also sends the appropriate reqPeerEvent +// messages to the request's event channel. +func (r *sentReq) tryRequest() { + sent := r.rm.dist.queue(r.req) + var p distPeer + select { + case p = <-sent: + case <-r.stopCh: + if r.rm.dist.cancel(r.req) { + p = nil + } else { + p = <-sent + } + } + + r.eventsCh <- reqPeerEvent{rpSent, p} + if p == nil { + return + } + + reqSent := mclock.Now() + srto, hrto := false, false + + r.lock.RLock() + s, ok := r.sentTo[p] + r.lock.RUnlock() + if !ok { + panic(nil) + } + + defer func() { + // send feedback to server pool and remove peer if hard timeout happened + pp, ok := p.(*peer) + if ok && r.rm.serverPool != nil { + respTime := time.Duration(mclock.Now() - reqSent) + r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto) + } + if hrto { + pp.Log().Debug("Request timed out hard") + if r.rm.peers != nil { + r.rm.peers.Unregister(pp.id) + } + } + + r.lock.Lock() + delete(r.sentTo, p) + r.lock.Unlock() + }() + + select { + case ok := <-s.valid: + if ok { + r.eventsCh <- reqPeerEvent{rpDeliveredValid, p} + } else { + r.eventsCh <- reqPeerEvent{rpDeliveredInvalid, p} + } + return + case <-time.After(softRequestTimeout): + srto = true + r.eventsCh <- reqPeerEvent{rpSoftTimeout, p} + } + + select { + case ok := <-s.valid: + if ok { + r.eventsCh <- reqPeerEvent{rpDeliveredValid, p} + } else { + r.eventsCh <- reqPeerEvent{rpDeliveredInvalid, p} + } + case <-time.After(hardRequestTimeout): + hrto = true + r.eventsCh <- reqPeerEvent{rpHardTimeout, p} + } +} + +// deliver a reply belonging to this request +func (r *sentReq) deliver(peer distPeer, msg *Msg) error { + r.lock.Lock() + defer r.lock.Unlock() + + s, ok := r.sentTo[peer] + if !ok || s.delivered { + return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID) + } + valid := r.validate(peer, msg) == nil + r.sentTo[peer] = sentReqToPeer{true, s.valid} + s.valid <- valid + if !valid { + return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID) + } + return nil +} + +// stop stops the retrieval process and sets an error code that will be returned +// by getError +func (r *sentReq) stop(err error) { + r.lock.Lock() + if !r.stopped { + r.stopped = true + r.err = err + close(r.stopCh) + } + r.lock.Unlock() +} + +// getError returns any retrieval error (either internally generated or set by the +// stop function) after stopCh has been closed +func (r *sentReq) getError() error { + return r.err +} + +// genReqID generates a new random request ID +func genReqID() uint64 { + var rnd [8]byte + rand.Read(rnd[:]) + return binary.BigEndian.Uint64(rnd[:]) +} diff --git a/les/server.go b/les/server.go index 22fe59b7ac75..2ff715ea8581 100644 --- a/les/server.go +++ b/les/server.go @@ -32,6 +32,7 @@ import ( "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -41,17 +42,24 @@ type LesServer struct { fcManager *flowcontrol.ClientManager // nil if our node is client only fcCostStats *requestCostStats defParams *flowcontrol.ServerParams + lesTopic discv5.Topic + quitSync chan struct{} stopped bool } func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { - pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil) + quitSync := make(chan struct{}) + pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) if err != nil { return nil, err } pm.blockLoop() - srv := &LesServer{protocolManager: pm} + srv := &LesServer{ + protocolManager: pm, + quitSync: quitSync, + lesTopic: lesTopic(eth.BlockChain().Genesis().Hash()), + } pm.server = srv srv.defParams = &flowcontrol.ServerParams{ @@ -69,7 +77,14 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { - s.protocolManager.Start(srvr) + s.protocolManager.Start() + go func() { + logger := log.New("topic", s.lesTopic) + logger.Info("Starting topic registration") + defer logger.Info("Terminated topic registration") + + srvr.DiscV5.RegisterTopic(s.lesTopic, s.quitSync) + }() } // Stop stops the LES service diff --git a/les/serverpool.go b/les/serverpool.go index 64fe991c63be..f4e4df2fbfb5 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -102,6 +102,8 @@ type serverPool struct { wg *sync.WaitGroup connWg sync.WaitGroup + topic discv5.Topic + discSetPeriod chan time.Duration discNodes chan *discv5.Node discLookups chan bool @@ -118,11 +120,9 @@ type serverPool struct { } // newServerPool creates a new serverPool instance -func newServerPool(db ethdb.Database, dbPrefix []byte, server *p2p.Server, topic discv5.Topic, quit chan struct{}, wg *sync.WaitGroup) *serverPool { +func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup) *serverPool { pool := &serverPool{ db: db, - dbKey: append(dbPrefix, []byte(topic)...), - server: server, quit: quit, wg: wg, entries: make(map[discover.NodeID]*poolEntry), @@ -135,19 +135,25 @@ func newServerPool(db ethdb.Database, dbPrefix []byte, server *p2p.Server, topic } pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) pool.newQueue = newPoolEntryQueue(maxNewEntries, pool.removeEntry) - wg.Add(1) + return pool +} + +func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { + pool.server = server + pool.topic = topic + pool.dbKey = append([]byte("serverPool/"), []byte(topic)...) + pool.wg.Add(1) pool.loadNodes() - pool.checkDial() + go pool.eventLoop() + + pool.checkDial() if pool.server.DiscV5 != nil { pool.discSetPeriod = make(chan time.Duration, 1) pool.discNodes = make(chan *discv5.Node, 100) pool.discLookups = make(chan bool, 100) - go pool.server.DiscV5.SearchTopic(topic, pool.discSetPeriod, pool.discNodes, pool.discLookups) + go pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, pool.discNodes, pool.discLookups) } - - go pool.eventLoop() - return pool } // connect should be called upon any incoming connection. If the connection has been @@ -485,7 +491,7 @@ func (pool *serverPool) checkDial() { // dial initiates a new connection func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { - if entry.state != psNotConnected { + if pool.server == nil || entry.state != psNotConnected { return } entry.state = psDialed diff --git a/les/txrelay.go b/les/txrelay.go index 1ca3467e4dc8..7a02cc837e67 100644 --- a/les/txrelay.go +++ b/les/txrelay.go @@ -39,26 +39,28 @@ type LesTxRelay struct { reqDist *requestDistributor } -func NewLesTxRelay() *LesTxRelay { - return &LesTxRelay{ +func NewLesTxRelay(ps *peerSet, reqDist *requestDistributor) *LesTxRelay { + r := &LesTxRelay{ txSent: make(map[common.Hash]*ltrInfo), txPending: make(map[common.Hash]struct{}), + ps: ps, + reqDist: reqDist, } + ps.notify(r) + return r } -func (self *LesTxRelay) addPeer(p *peer) { +func (self *LesTxRelay) registerPeer(p *peer) { self.lock.Lock() defer self.lock.Unlock() - self.ps.Register(p) self.peerList = self.ps.AllPeers() } -func (self *LesTxRelay) removePeer(id string) { +func (self *LesTxRelay) unregisterPeer(p *peer) { self.lock.Lock() defer self.lock.Unlock() - self.ps.Unregister(id) self.peerList = self.ps.AllPeers() } @@ -112,7 +114,7 @@ func (self *LesTxRelay) send(txs types.Transactions, count int) { pp := p ll := list - reqID := getNextReqID() + reqID := genReqID() rq := &distReq{ getCost: func(dp distPeer) uint64 { peer := dp.(*peer)