diff --git a/config_builder.go b/config_builder.go index 7af3273824..5679de10bf 100644 --- a/config_builder.go +++ b/config_builder.go @@ -44,6 +44,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/btcwallet" "github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/macaroons" + "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" @@ -157,6 +158,10 @@ type AuxComponents struct { // AuxLeafStore is an optional data source that can be used by custom // channels to fetch+store various data. AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + + // TrafficShaper is an optional traffic shaper that can be used to + // control the outgoing channel of a payment. + TrafficShaper fn.Option[routing.TlvTrafficShaper] } // DefaultWalletImpl is the default implementation of our normal, btcwallet diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index eda85cfb0f..0bada0d10b 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" ) // InvoiceDatabase is an interface which represents the persistent subsystem @@ -271,6 +272,16 @@ type ChannelLink interface { // have buffered messages. AttachMailBox(MailBox) + // FundingCustomBlob returns the custom funding blob of the channel that + // this link is associated with. The funding blob represents static + // information about the channel that was created at channel funding + // time. + FundingCustomBlob() fn.Option[tlv.Blob] + + // CommitmentCustomBlob returns the custom blob of the current local + // commitment of the channel that this link is associated with. + CommitmentCustomBlob() fn.Option[tlv.Blob] + // Start/Stop are used to initiate the start/stop of the channel link // functioning. Start() error diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 8eca0f004d..ab4284d62a 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -3775,3 +3775,16 @@ func (l *channelLink) fail(linkErr LinkFailureError, l.failed = true l.cfg.OnChannelFailure(l.ChanID(), l.ShortChanID(), linkErr) } + +// FundingCustomBlob returns the custom funding blob of the channel that this +// link is associated with. The funding blob represents static information about +// the channel that was created at channel funding time. +func (l *channelLink) FundingCustomBlob() fn.Option[tlv.Blob] { + return l.channel.State().CustomBlob +} + +// CommitmentCustomBlob returns the custom blob of the current local commitment +// of the channel that this link is associated with. +func (l *channelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { + return l.channel.LocalCommitmentBlob() +} diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index a0f38c74fe..c56b47548d 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" @@ -35,6 +36,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/ticker" + "github.com/lightningnetwork/lnd/tlv" ) func isAlias(scid lnwire.ShortChannelID) bool { @@ -912,6 +914,10 @@ func (f *mockChannelLink) ChannelPoint() wire.OutPoint { return wire.OutPoint{} } +func (f *mockChannelLink) ChannelCustomBlob() fn.Option[tlv.Blob] { + return fn.Option[tlv.Blob]{} +} + func (f *mockChannelLink) Stop() {} func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } @@ -942,6 +948,14 @@ func (f *mockChannelLink) OnCommitOnce(LinkDirection, func()) { // TODO(proofofkeags): Implement } +func (f *mockChannelLink) FundingCustomBlob() fn.Option[tlv.Blob] { + return fn.None[tlv.Blob]() +} + +func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { + return fn.None[tlv.Blob]() +} + var _ ChannelLink = (*mockChannelLink)(nil) func newDB() (*channeldb.DB, func(), error) { diff --git a/lnwallet/channel.go b/lnwallet/channel.go index ad54668b4d..51bea4ba43 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -9589,3 +9589,19 @@ func (lc *LightningChannel) MultiSigKeys() (keychain.KeyDescriptor, return lc.channelState.LocalChanCfg.MultiSigKey, lc.channelState.RemoteChanCfg.MultiSigKey } + +// LocalCommitmentBlob returns the custom blob of the local commitment. +func (lc *LightningChannel) LocalCommitmentBlob() fn.Option[tlv.Blob] { + lc.RLock() + defer lc.RUnlock() + + chanState := lc.channelState + localBalance := chanState.LocalCommitment.CustomBlob + + return fn.MapOption(func(b tlv.Blob) tlv.Blob { + newBlob := make([]byte, len(b)) + copy(newBlob, b) + + return newBlob + })(localBalance) +} diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 19c6087018..564168225d 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -1,10 +1,15 @@ package routing import ( + "fmt" + + "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) // bandwidthHints provides hints about the currently available balance in our @@ -18,7 +23,39 @@ type bandwidthHints interface { // will be used. If the channel is unavailable, a zero amount is // returned. availableChanBandwidth(channelID uint64, - amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) + amount lnwire.MilliSatoshi, + htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, bool) +} + +// TlvTrafficShaper is an interface that allows the sender to determine if a +// payment should be carried by a channel based on the TLV records that may be +// present in the `update_add_htlc` message or the channel commitment itself. +type TlvTrafficShaper interface { + AuxHtlcModifier + + // HandleTraffic is called in order to check if the channel identified + // by the provided channel ID may have external mechanisms that would + // allow it to carry out the payment. + HandleTraffic(cid lnwire.ShortChannelID, + fundingBlob fn.Option[tlv.Blob]) (bool, error) + + // PaymentBandwidth returns the available bandwidth for a custom channel + // decided by the given channel aux blob and HTLC blob. A return value + // of 0 means there is no bandwidth available. To find out if a channel + // is a custom channel that should be handled by the traffic shaper, the + // HandleTraffic method should be called first. + PaymentBandwidth(htlcBlob, + commitmentBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, error) +} + +// AuxHtlcModifier is an interface that allows the sender to modify the outgoing +// HTLC of a payment by changing the amount or the wire message tlv records. +type AuxHtlcModifier interface { + // ProduceHtlcExtraData is a function that, based on the previous extra + // data blob of an HTLC, may produce a different blob or modify the + // amount of bitcoin this htlc should carry. + ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + htlcBlob tlv.Blob) (btcutil.Amount, tlv.Blob, error) } // getLinkQuery is the function signature used to lookup a link. @@ -29,8 +66,9 @@ type getLinkQuery func(lnwire.ShortChannelID) ( // uses the link lookup provided to query the link for our latest local channel // balances. type bandwidthManager struct { - getLink getLinkQuery - localChans map[lnwire.ShortChannelID]struct{} + getLink getLinkQuery + localChans map[lnwire.ShortChannelID]struct{} + trafficShaper fn.Option[TlvTrafficShaper] } // newBandwidthManager creates a bandwidth manager for the source node provided @@ -40,11 +78,13 @@ type bandwidthManager struct { // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, - linkQuery getLinkQuery) (*bandwidthManager, error) { + linkQuery getLinkQuery, + trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { manager := &bandwidthManager{ - getLink: linkQuery, - localChans: make(map[lnwire.ShortChannelID]struct{}), + getLink: linkQuery, + localChans: make(map[lnwire.ShortChannelID]struct{}), + trafficShaper: trafficShaper, } // First, we'll collect the set of outbound edges from the target @@ -71,7 +111,8 @@ func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, // queried is one of our local channels, so any failure to retrieve the link // is interpreted as the link being offline. func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, - amount lnwire.MilliSatoshi) lnwire.MilliSatoshi { + amount lnwire.MilliSatoshi, + htlcBlob fn.Option[tlv.Blob]) lnwire.MilliSatoshi { link, err := b.getLink(cid) if err != nil { @@ -89,16 +130,68 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, return 0 } - // If our link isn't currently in a state where it can add another - // outgoing htlc, treat the link as unusable. + var ( + auxBandwidth lnwire.MilliSatoshi + auxBandwidthDetermined bool + ) + err = fn.MapOptionZ(b.trafficShaper, func(ts TlvTrafficShaper) error { + fundingBlob := link.FundingCustomBlob() + shouldHandle, err := ts.HandleTraffic(cid, fundingBlob) + if err != nil { + return fmt.Errorf("traffic shaper failed to decide "+ + "whether to handle traffic: %w", err) + } + + log.Debugf("ShortChannelID=%v: external traffic shaper is "+ + "handling traffic: %v", cid, shouldHandle) + + // If this channel isn't handled by the external traffic shaper, + // we'll return early. + if !shouldHandle { + return nil + } + + // Ask for a specific bandwidth to be used for the channel. + commitmentBlob := link.CommitmentCustomBlob() + auxBandwidth, err = ts.PaymentBandwidth( + htlcBlob, commitmentBlob, + ) + if err != nil { + return fmt.Errorf("failed to get bandwidth from "+ + "external traffic shaper: %w", err) + } + + log.Debugf("ShortChannelID=%v: external traffic shaper "+ + "reported available bandwidth: %v", cid, auxBandwidth) + + auxBandwidthDetermined = true + + return nil + }) + if err != nil { + log.Errorf("ShortChannelID=%v: failed to get bandwidth from "+ + "external traffic shaper: %v", cid, err) + + return 0 + } + + // If our link isn't currently in a state where it can add + // another outgoing htlc, treat the link as unusable. if err := link.MayAddOutgoingHtlc(amount); err != nil { - log.Warnf("ShortChannelID=%v: cannot add outgoing htlc: %v", - cid, err) + log.Warnf("ShortChannelID=%v: cannot add outgoing "+ + "htlc: %v", cid, err) return 0 } - // Otherwise, we'll return the current best estimate for the available - // bandwidth for the link. + // If the external traffic shaper determined the bandwidth, we'll return + // that value, even if it is zero (which would mean no bandwidth is + // available on that channel). + if auxBandwidthDetermined { + return auxBandwidth + } + + // Otherwise, we'll return the current best estimate for the + // available bandwidth for the link. return link.Bandwidth() } @@ -106,7 +199,8 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, // and a bool indicating whether the channel hint was found. If the channel is // unavailable, a zero amount is returned. func (b *bandwidthManager) availableChanBandwidth(channelID uint64, - amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) { + amount lnwire.MilliSatoshi, + htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, bool) { shortID := lnwire.NewShortChanIDFromInt(channelID) _, ok := b.localChans[shortID] @@ -114,5 +208,5 @@ func (b *bandwidthManager) availableChanBandwidth(channelID uint64, return 0, false } - return b.getBandwidth(shortID, amount), true + return b.getBandwidth(shortID, amount, htlcBlob), true } diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index ef12d69737..c1423aa9e6 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" @@ -115,11 +116,13 @@ func TestBandwidthManager(t *testing.T) { m, err := newBandwidthManager( g, sourceNode.pubkey, testCase.linkQuery, + fn.None[TlvTrafficShaper](), ) require.NoError(t, err) bandwidth, found := m.availableChanBandwidth( testCase.channelID, 10, + fn.None[[]byte](), ) require.Equal(t, testCase.expectedBandwidth, bandwidth) require.Equal(t, testCase.expectFound, found) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 4215d3b254..6febaaa174 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -24,7 +25,8 @@ type mockBandwidthHints struct { } func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64, - _ lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) { + _ lnwire.MilliSatoshi, + htlcBlob fn.Option[[]byte]) (lnwire.MilliSatoshi, bool) { if m.hints == nil { return 0, false @@ -229,6 +231,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, // Find a route. route, err := session.RequestRoute( amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0, + nil, ) if err != nil { return attempts, err diff --git a/routing/mock_test.go b/routing/mock_test.go index f712c420de..d64eb2b45a 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -9,6 +9,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -98,7 +99,8 @@ type mockPaymentSessionSourceOld struct { var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) func (m *mockPaymentSessionSourceOld) NewPaymentSession( - _ *LightningPayment) (PaymentSession, error) { + _ *LightningPayment, + _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) { return &mockPaymentSessionOld{ routes: m.routes, @@ -160,7 +162,7 @@ type mockPaymentSessionOld struct { var _ PaymentSession = (*mockPaymentSessionOld)(nil) func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, - _, height uint32) (*route.Route, error) { + _, height uint32, records record.CustomSet) (*route.Route, error) { if m.release != nil { m.release <- struct{}{} @@ -613,7 +615,8 @@ type mockPaymentSessionSource struct { var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( - payment *LightningPayment) (PaymentSession, error) { + payment *LightningPayment, + tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { args := m.Called(payment) return args.Get(0).(PaymentSession), args.Error(1) @@ -673,7 +676,8 @@ type mockPaymentSession struct { var _ PaymentSession = (*mockPaymentSession)(nil) func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) { + activeShards, height uint32, + records record.CustomSet) (*route.Route, error) { args := m.Called(maxAmt, feeLimit, activeShards, height) diff --git a/routing/pathfind.go b/routing/pathfind.go index add833ecc7..4aa726d130 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -13,9 +13,11 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/feature" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -433,6 +435,10 @@ type RestrictParams struct { // BlindedPayment is necessary to determine the hop size of the // last/exit hop. BlindedPayment *BlindedPayment + + // FirstHopCustomRecords includes any records that should be included in + // the update_add_htlc message towards our peer. + FirstHopCustomRecords record.CustomSet } // PathFindingConfig defines global parameters that control the trade-off in @@ -459,9 +465,11 @@ type PathFindingConfig struct { // available balance. func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, - g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { + g routingGraph, htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, + lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi + cb := func(channel *channeldb.DirectedChannel) error { if !channel.OutPolicySet { return nil @@ -477,7 +485,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } bandwidth, ok := bandwidthHints.availableChanBandwidth( - chanID, 0, + chanID, 0, htlcBlob, ) // If the bandwidth is not available, use the channel capacity. @@ -491,7 +499,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, max = bandwidth } - total += bandwidth + total = overflowSafeAdd(total, bandwidth) return nil } @@ -599,8 +607,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self := g.graph.sourceNode() if source == self { + customRecords := lnwire.CustomRecords(r.FirstHopCustomRecords) + firstHopData, err := customRecords.Serialize() + if err != nil { + return nil, 0, err + } + max, total, err := getOutgoingBalance( self, outgoingChanMap, g.bandwidthHints, g.graph, + fn.Some[tlv.Blob](firstHopData), ) if err != nil { return nil, 0, err @@ -1029,9 +1044,18 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, continue } + firstHopTLVs := lnwire.CustomRecords( + r.FirstHopCustomRecords, + ) + firstHopData, err := firstHopTLVs.Serialize() + if err != nil { + return nil, 0, err + } + edge := edgeUnifier.getEdge( netAmountReceived, g.bandwidthHints, partialPath.outboundFee, + fn.Some[tlv.Blob](firstHopData), ) if edge == nil { @@ -1223,3 +1247,14 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, // The final hop does not have a short chanID set. return finalHop.PayloadSize(0) } + +// overflowSafeAdd adds two MilliSatoshi values and returns the result. If an +// overflow could occur, the maximum uint64 value is returned instead. +func overflowSafeAdd(x, y lnwire.MilliSatoshi) lnwire.MilliSatoshi { + if y > math.MaxUint64-x { + // Overflow would occur, return maximum uint64 value. + return math.MaxUint64 + } + + return x + y +} diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 15d18f7da2..9b4a0d60c2 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -10,6 +10,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -360,6 +361,7 @@ func (p *paymentLifecycle) requestRoute( rt, err := p.paySession.RequestRoute( ps.RemainingAmt, remainingFees, uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), + p.firstHopTLVs, ) // Exit early if there's no error. @@ -677,6 +679,44 @@ func (p *paymentLifecycle) sendAttempt( CustomRecords: lnwire.CustomRecords(p.firstHopTLVs), } + // If we had custom records in the HTLC, then we'll encode that here + // now. We allow the traffic shaper (if there is one) to overwrite the + // custom records below. But if there is no traffic shaper, we still + // want to forward these custom records. + encodedRecords, err := htlcAdd.CustomRecords.Serialize() + if err != nil { + return nil, fmt.Errorf("unable to encode first hop TLVs: %w", + err) + } + + // If a hook exists that may affect our outgoing message, we call it now + // and apply its side effects to the UpdateAddHTLC message. + err = fn.MapOptionZ( + p.router.cfg.TrafficShaper, + func(ts TlvTrafficShaper) error { + newAmt, newData, err := ts.ProduceHtlcExtraData( + rt.TotalAmount, encodedRecords, + ) + if err != nil { + return err + } + + customRecords, err := lnwire.ParseCustomRecords(newData) + if err != nil { + return err + } + + htlcAdd.CustomRecords = customRecords + htlcAdd.Amount = lnwire.NewMSatFromSatoshis(newAmt) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("traffic shaper failed to produce "+ + "extra data: %w", err) + } + // Generate the raw encoded sphinx packet to be included along // with the htlcAdd message that we send directly to the // switch. diff --git a/routing/payment_session.go b/routing/payment_session.go index 2d174244c8..e435b47d1e 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" ) @@ -138,7 +139,8 @@ type PaymentSession interface { // A noRouteError is returned if a non-critical error is encountered // during path finding. RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) + activeShards, height uint32, + firstHopTLVs record.CustomSet) (*route.Route, error) // UpdateAdditionalEdge takes an additional channel edge policy // (private channels) and applies the update from the message. Returns @@ -228,7 +230,8 @@ func newPaymentSession(p *LightningPayment, // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) { + activeShards, height uint32, + firstHopTLVs record.CustomSet) (*route.Route, error) { if p.empty { return nil, errEmptyPaySession @@ -250,16 +253,17 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // to our destination, respecting the recommendations from // MissionControl. restrictions := &RestrictParams{ - ProbabilitySource: p.missionControl.GetProbability, - FeeLimit: feeLimit, - OutgoingChannelIDs: p.payment.OutgoingChannelIDs, - LastHop: p.payment.LastHop, - CltvLimit: cltvLimit, - DestCustomRecords: p.payment.DestCustomRecords, - DestFeatures: p.payment.DestFeatures, - PaymentAddr: p.payment.PaymentAddr, - Amp: p.payment.amp, - Metadata: p.payment.Metadata, + ProbabilitySource: p.missionControl.GetProbability, + FeeLimit: feeLimit, + OutgoingChannelIDs: p.payment.OutgoingChannelIDs, + LastHop: p.payment.LastHop, + CltvLimit: cltvLimit, + DestCustomRecords: p.payment.DestCustomRecords, + DestFeatures: p.payment.DestFeatures, + PaymentAddr: p.payment.PaymentAddr, + Amp: p.payment.amp, + Metadata: p.payment.Metadata, + FirstHopCustomRecords: firstHopTLVs, } finalHtlcExpiry := int32(height) + int32(finalCltvDelta) diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index b96a2294ba..61a2d3d931 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -4,6 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/zpay32" @@ -42,6 +43,10 @@ type SessionSource struct { // PathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probability. PathFindingConfig PathFindingConfig + + // TrafficShaper is an optional traffic shaper that can be used to + // control the outgoing channel of a payment. + TrafficShaper fn.Option[TlvTrafficShaper] } // getRoutingGraph returns a routing graph and a clean-up function for @@ -63,12 +68,13 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the // payment's destination. -func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( - PaymentSession, error) { +func (m *SessionSource) NewPaymentSession(p *LightningPayment, + trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { getBandwidthHints := func(graph routingGraph) (bandwidthHints, error) { return newBandwidthManager( graph, m.SourceNode.PubKeyBytes, m.GetLink, + trafficShaper, ) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 75b84a51a3..c803909e08 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -238,7 +238,7 @@ func TestRequestRoute(t *testing.T) { } route, err := session.RequestRoute( - payment.Amount, payment.FeeLimit, 0, height, + payment.Amount, payment.FeeLimit, 0, height, nil, ) if err != nil { t.Fatal(err) diff --git a/routing/router.go b/routing/router.go index e8a59f829e..0bc2dee95e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -22,6 +22,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -243,7 +244,9 @@ type PaymentSessionSource interface { // routes to the given target. An optional set of routing hints can be // provided in order to populate additional edges to explore when // finding a path to the payment's destination. - NewPaymentSession(p *LightningPayment) (PaymentSession, error) + NewPaymentSession(p *LightningPayment, + trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, + error) // NewPaymentSessionEmpty creates a new paymentSession instance that is // empty, and will be exhausted immediately. Used for failure reporting @@ -409,6 +412,10 @@ type Config struct { // IsAlias returns whether a passed ShortChannelID is an alias. This is // only used for our local channels. IsAlias func(scid lnwire.ShortChannelID) bool + + // TrafficShaper is an optional traffic shaper that can be used to + // control the outgoing channel of a payment. + TrafficShaper fn.Option[TlvTrafficShaper] } // EdgeLocator is a struct used to identify a specific edge. @@ -2095,6 +2102,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( r.cachedGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, + r.cfg.TrafficShaper, ) if err != nil { return nil, 0, err @@ -2457,7 +2465,9 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( // Before starting the HTLC routing attempt, we'll create a fresh // payment session which will report our errors back to mission // control. - paySession, err := r.cfg.SessionSource.NewPaymentSession(payment) + paySession, err := r.cfg.SessionSource.NewPaymentSession( + payment, r.cfg.TrafficShaper, + ) if err != nil { return nil, nil, err } @@ -3106,6 +3116,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := newBandwidthManager( r.cachedGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, + r.cfg.TrafficShaper, ) if err != nil { return nil, err @@ -3202,7 +3213,9 @@ func getRouteUnifiers(source route.Vertex, hops []route.Vertex, } // Get an edge for the specific amount that we want to forward. - edge := edgeUnifier.getEdge(runningAmt, bandwidthHints, 0) + edge := edgeUnifier.getEdge( + runningAmt, bandwidthHints, 0, fn.Option[[]byte]{}, + ) if edge == nil { log.Errorf("Cannot find policy with amt=%v for node %v", runningAmt, fromNode) @@ -3240,7 +3253,9 @@ func getPathEdges(source route.Vertex, receiverAmt lnwire.MilliSatoshi, // amount ranges re-checked. var pathEdges []*unifiedEdge for i, unifier := range unifiers { - edge := unifier.getEdge(receiverAmt, bandwidthHints, 0) + edge := unifier.getEdge( + receiverAmt, bandwidthHints, 0, fn.Option[[]byte]{}, + ) if edge == nil { fromNode := source if i > 0 { diff --git a/routing/unified_edges.go b/routing/unified_edges.go index 44efc6314e..04e135a533 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -6,9 +6,11 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) // nodeEdgeUnifier holds all edge unifiers for connections towards a node. @@ -181,12 +183,12 @@ type edgeUnifier struct { // specific amount to send. It differentiates between local and network // channels. func (u *edgeUnifier) getEdge(netAmtReceived lnwire.MilliSatoshi, - bandwidthHints bandwidthHints, - nextOutFee lnwire.MilliSatoshi) *unifiedEdge { + bandwidthHints bandwidthHints, nextOutFee lnwire.MilliSatoshi, + htlcBlob fn.Option[tlv.Blob]) *unifiedEdge { if u.localChan { return u.getEdgeLocal( - netAmtReceived, bandwidthHints, nextOutFee, + netAmtReceived, bandwidthHints, nextOutFee, htlcBlob, ) } @@ -213,8 +215,8 @@ func calcCappedInboundFee(edge *unifiedEdge, amt lnwire.MilliSatoshi, // getEdgeLocal returns the optimal unified edge to use for this local // connection given a specific amount to send. func (u *edgeUnifier) getEdgeLocal(netAmtReceived lnwire.MilliSatoshi, - bandwidthHints bandwidthHints, - nextOutFee lnwire.MilliSatoshi) *unifiedEdge { + bandwidthHints bandwidthHints, nextOutFee lnwire.MilliSatoshi, + htlcBlob fn.Option[tlv.Blob]) *unifiedEdge { var ( bestEdge *unifiedEdge @@ -251,7 +253,7 @@ func (u *edgeUnifier) getEdgeLocal(netAmtReceived lnwire.MilliSatoshi, // channel. The bandwidth hint is expected to be // available. bandwidth, ok := bandwidthHints.availableChanBandwidth( - edge.policy.ChannelID, amt, + edge.policy.ChannelID, amt, htlcBlob, ) if !ok { log.Debugf("Cannot get bandwidth for edge %v, use max "+ diff --git a/routing/unified_edges_test.go b/routing/unified_edges_test.go index 7b1650c025..566b50517c 100644 --- a/routing/unified_edges_test.go +++ b/routing/unified_edges_test.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -230,6 +231,7 @@ func TestNodeEdgeUnifier(t *testing.T) { edge := test.unifier.edgeUnifiers[fromNode].getEdge( test.amount, bandwidthHints, test.nextOutFee, + fn.None[[]byte](), ) if test.expectNoPolicy { diff --git a/server.go b/server.go index bebf233f78..8f7ddce677 100644 --- a/server.go +++ b/server.go @@ -1001,6 +1001,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, Clock: clock.NewDefaultClock(), StrictZombiePruning: strictPruning, IsAlias: aliasmgr.IsAlias, + TrafficShaper: implCfg.TrafficShaper, }) if err != nil { return nil, fmt.Errorf("can't create router: %w", err)