Skip to content

Commit

Permalink
routing: add TlvTrafficShaper to bandwidth hints
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeTsagk authored and guggero committed May 22, 2024
1 parent ec7dc03 commit 37cbfea
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 38 deletions.
124 changes: 109 additions & 15 deletions routing/bandwidth.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package routing

import (
"fmt"

"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)

// bandwidthHints provides hints about the currently available balance in our
Expand All @@ -18,7 +23,39 @@ type bandwidthHints interface {
// will be used. If the channel is unavailable, a zero amount is
// returned.
availableChanBandwidth(channelID uint64,
amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool)
amount lnwire.MilliSatoshi,
htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, bool)
}

// TlvTrafficShaper is an interface that allows the sender to determine if a
// payment should be carried by a channel based on the TLV records that may be
// present in the `update_add_htlc` message or the channel commitment itself.
type TlvTrafficShaper interface {
AuxHtlcModifier

// HandleTraffic is called in order to check if the channel identified
// by the provided channel ID may have external mechanisms that would
// allow it to carry out the payment.
HandleTraffic(cid lnwire.ShortChannelID,
fundingBlob fn.Option[tlv.Blob]) (bool, error)

// PaymentBandwidth returns the available bandwidth for a custom channel
// decided by the given channel aux blob and HTLC blob. A return value
// of 0 means there is no bandwidth available. To find out if a channel
// is a custom channel that should be handled by the traffic shaper, the
// HandleTraffic method should be called first.
PaymentBandwidth(htlcBlob,
commitmentBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, error)
}

// AuxHtlcModifier is an interface that allows the sender to modify the outgoing
// HTLC of a payment by changing the amount or the wire message tlv records.
type AuxHtlcModifier interface {
// ProduceHtlcExtraData is a function that, based on the previous extra
// data blob of an HTLC, may produce a different blob or modify the
// amount of bitcoin this htlc should carry.
ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi,
htlcBlob tlv.Blob) (btcutil.Amount, tlv.Blob, error)
}

// getLinkQuery is the function signature used to lookup a link.
Expand All @@ -29,8 +66,9 @@ type getLinkQuery func(lnwire.ShortChannelID) (
// uses the link lookup provided to query the link for our latest local channel
// balances.
type bandwidthManager struct {
getLink getLinkQuery
localChans map[lnwire.ShortChannelID]struct{}
getLink getLinkQuery
localChans map[lnwire.ShortChannelID]struct{}
trafficShaper fn.Option[TlvTrafficShaper]
}

// newBandwidthManager creates a bandwidth manager for the source node provided
Expand All @@ -40,11 +78,13 @@ type bandwidthManager struct {
// allows us to reduce the number of extraneous attempts as we can skip channels
// that are inactive, or just don't have enough bandwidth to carry the payment.
func newBandwidthManager(graph routingGraph, sourceNode route.Vertex,
linkQuery getLinkQuery) (*bandwidthManager, error) {
linkQuery getLinkQuery,
trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) {

manager := &bandwidthManager{
getLink: linkQuery,
localChans: make(map[lnwire.ShortChannelID]struct{}),
getLink: linkQuery,
localChans: make(map[lnwire.ShortChannelID]struct{}),
trafficShaper: trafficShaper,
}

// First, we'll collect the set of outbound edges from the target
Expand All @@ -71,7 +111,8 @@ func newBandwidthManager(graph routingGraph, sourceNode route.Vertex,
// queried is one of our local channels, so any failure to retrieve the link
// is interpreted as the link being offline.
func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID,
amount lnwire.MilliSatoshi) lnwire.MilliSatoshi {
amount lnwire.MilliSatoshi,
htlcBlob fn.Option[tlv.Blob]) lnwire.MilliSatoshi {

link, err := b.getLink(cid)
if err != nil {
Expand All @@ -89,30 +130,83 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID,
return 0
}

// If our link isn't currently in a state where it can add another
// outgoing htlc, treat the link as unusable.
var (
auxBandwidth lnwire.MilliSatoshi
auxBandwidthDetermined bool
)
err = fn.MapOptionZ(b.trafficShaper, func(ts TlvTrafficShaper) error {
fundingBlob := link.FundingCustomBlob()
shouldHandle, err := ts.HandleTraffic(cid, fundingBlob)
if err != nil {
return fmt.Errorf("traffic shaper failed to decide "+
"whether to handle traffic: %w", err)
}

log.Debugf("ShortChannelID=%v: external traffic shaper is "+
"handling traffic: %v", cid, shouldHandle)

// If this channel isn't handled by the external traffic shaper,
// we'll return early.
if !shouldHandle {
return nil
}

// Ask for a specific bandwidth to be used for the channel.
commitmentBlob := link.CommitmentCustomBlob()
auxBandwidth, err = ts.PaymentBandwidth(
htlcBlob, commitmentBlob,
)
if err != nil {
return fmt.Errorf("failed to get bandwidth from "+
"external traffic shaper: %w", err)
}

log.Debugf("ShortChannelID=%v: external traffic shaper "+
"reported available bandwidth: %v", cid, auxBandwidth)

auxBandwidthDetermined = true

return nil
})
if err != nil {
log.Errorf("ShortChannelID=%v: failed to get bandwidth from "+
"external traffic shaper: %v", cid, err)

return 0
}

// If our link isn't currently in a state where it can add
// another outgoing htlc, treat the link as unusable.
if err := link.MayAddOutgoingHtlc(amount); err != nil {
log.Warnf("ShortChannelID=%v: cannot add outgoing htlc: %v",
cid, err)
log.Warnf("ShortChannelID=%v: cannot add outgoing "+
"htlc: %v", cid, err)
return 0
}

// Otherwise, we'll return the current best estimate for the available
// bandwidth for the link.
// If the external traffic shaper determined the bandwidth, we'll return
// that value, even if it is zero (which would mean no bandwidth is
// available on that channel).
if auxBandwidthDetermined {
return auxBandwidth
}

// Otherwise, we'll return the current best estimate for the
// available bandwidth for the link.
return link.Bandwidth()
}

// availableChanBandwidth returns the total available bandwidth for a channel
// and a bool indicating whether the channel hint was found. If the channel is
// unavailable, a zero amount is returned.
func (b *bandwidthManager) availableChanBandwidth(channelID uint64,
amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) {
amount lnwire.MilliSatoshi,
htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, bool) {

shortID := lnwire.NewShortChanIDFromInt(channelID)
_, ok := b.localChans[shortID]
if !ok {
return 0, false
}

return b.getBandwidth(shortID, amount), true
return b.getBandwidth(shortID, amount, htlcBlob), true
}
3 changes: 3 additions & 0 deletions routing/bandwidth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/btcsuite/btcd/btcutil"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -115,11 +116,13 @@ func TestBandwidthManager(t *testing.T) {

m, err := newBandwidthManager(
g, sourceNode.pubkey, testCase.linkQuery,
fn.None[TlvTrafficShaper](),
)
require.NoError(t, err)

bandwidth, found := m.availableChanBandwidth(
testCase.channelID, 10,
fn.None[[]byte](),
)
require.Equal(t, testCase.expectedBandwidth, bandwidth)
require.Equal(t, testCase.expectFound, found)
Expand Down
5 changes: 4 additions & 1 deletion routing/integrated_routing_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
Expand All @@ -24,7 +25,8 @@ type mockBandwidthHints struct {
}

func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64,
_ lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) {
_ lnwire.MilliSatoshi,
htlcBlob fn.Option[[]byte]) (lnwire.MilliSatoshi, bool) {

if m.hints == nil {
return 0, false
Expand Down Expand Up @@ -229,6 +231,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
// Find a route.
route, err := session.RequestRoute(
amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0,
nil,
)
if err != nil {
return attempts, err
Expand Down
12 changes: 8 additions & 4 deletions routing/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -98,7 +99,8 @@ type mockPaymentSessionSourceOld struct {
var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil)

func (m *mockPaymentSessionSourceOld) NewPaymentSession(
_ *LightningPayment) (PaymentSession, error) {
_ *LightningPayment,
_ fn.Option[TlvTrafficShaper]) (PaymentSession, error) {

return &mockPaymentSessionOld{
routes: m.routes,
Expand Down Expand Up @@ -160,7 +162,7 @@ type mockPaymentSessionOld struct {
var _ PaymentSession = (*mockPaymentSessionOld)(nil)

func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
_, height uint32) (*route.Route, error) {
_, height uint32, records record.CustomSet) (*route.Route, error) {

if m.release != nil {
m.release <- struct{}{}
Expand Down Expand Up @@ -613,7 +615,8 @@ type mockPaymentSessionSource struct {
var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)

func (m *mockPaymentSessionSource) NewPaymentSession(
payment *LightningPayment) (PaymentSession, error) {
payment *LightningPayment,
tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) {

args := m.Called(payment)
return args.Get(0).(PaymentSession), args.Error(1)
Expand Down Expand Up @@ -673,7 +676,8 @@ type mockPaymentSession struct {
var _ PaymentSession = (*mockPaymentSession)(nil)

func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32) (*route.Route, error) {
activeShards, height uint32,
records record.CustomSet) (*route.Route, error) {

args := m.Called(maxAmt, feeLimit, activeShards, height)

Expand Down
6 changes: 4 additions & 2 deletions routing/pathfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)

const (
Expand Down Expand Up @@ -477,7 +479,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
}

bandwidth, ok := bandwidthHints.availableChanBandwidth(
chanID, 0,
chanID, 0, fn.None[tlv.Blob](),
)

// If the bandwidth is not available, use the channel capacity.
Expand Down Expand Up @@ -1031,7 +1033,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,

edge := edgeUnifier.getEdge(
netAmountReceived, g.bandwidthHints,
partialPath.outboundFee,
partialPath.outboundFee, fn.None[tlv.Blob](),
)

if edge == nil {
Expand Down
2 changes: 1 addition & 1 deletion routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func (p *paymentLifecycle) requestRoute(
// Query our payment session to construct a route.
rt, err := p.paySession.RequestRoute(
ps.RemainingAmt, remainingFees,
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight),
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), nil,
)

// Exit early if there's no error.
Expand Down
7 changes: 5 additions & 2 deletions routing/payment_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
)

Expand Down Expand Up @@ -138,7 +139,8 @@ type PaymentSession interface {
// A noRouteError is returned if a non-critical error is encountered
// during path finding.
RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32) (*route.Route, error)
activeShards, height uint32,
firstHopTLVs record.CustomSet) (*route.Route, error)

// UpdateAdditionalEdge takes an additional channel edge policy
// (private channels) and applies the update from the message. Returns
Expand Down Expand Up @@ -228,7 +230,8 @@ func newPaymentSession(p *LightningPayment,
// NOTE: This function is safe for concurrent access.
// NOTE: Part of the PaymentSession interface.
func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32) (*route.Route, error) {
activeShards, height uint32,
firstHopTLVs record.CustomSet) (*route.Route, error) {

if p.empty {
return nil, errEmptyPaySession
Expand Down
10 changes: 8 additions & 2 deletions routing/payment_session_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/zpay32"
Expand Down Expand Up @@ -42,6 +43,10 @@ type SessionSource struct {
// PathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probability.
PathFindingConfig PathFindingConfig

// TrafficShaper is an optional traffic shaper that can be used to
// control the outgoing channel of a payment.
TrafficShaper fn.Option[TlvTrafficShaper]
}

// getRoutingGraph returns a routing graph and a clean-up function for
Expand All @@ -63,12 +68,13 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
// view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the
// payment's destination.
func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) {
func (m *SessionSource) NewPaymentSession(p *LightningPayment,
trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) {

getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) {
return newBandwidthManager(
graph, m.SourceNode.PubKeyBytes, m.GetLink,
trafficShaper,
)
}

Expand Down
2 changes: 1 addition & 1 deletion routing/payment_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestRequestRoute(t *testing.T) {
}

route, err := session.RequestRoute(
payment.Amount, payment.FeeLimit, 0, height,
payment.Amount, payment.FeeLimit, 0, height, nil,
)
if err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit 37cbfea

Please sign in to comment.