diff --git a/routing/pathfind.go b/routing/pathfind.go index f026e0ce93..4aa726d130 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -435,6 +435,10 @@ type RestrictParams struct { // BlindedPayment is necessary to determine the hop size of the // last/exit hop. BlindedPayment *BlindedPayment + + // FirstHopCustomRecords includes any records that should be included in + // the update_add_htlc message towards our peer. + FirstHopCustomRecords record.CustomSet } // PathFindingConfig defines global parameters that control the trade-off in @@ -461,9 +465,11 @@ type PathFindingConfig struct { // available balance. func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, - g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { + g routingGraph, htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi, + lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi + cb := func(channel *channeldb.DirectedChannel) error { if !channel.OutPolicySet { return nil @@ -479,7 +485,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } bandwidth, ok := bandwidthHints.availableChanBandwidth( - chanID, 0, fn.None[tlv.Blob](), + chanID, 0, htlcBlob, ) // If the bandwidth is not available, use the channel capacity. @@ -493,7 +499,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, max = bandwidth } - total += bandwidth + total = overflowSafeAdd(total, bandwidth) return nil } @@ -601,8 +607,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self := g.graph.sourceNode() if source == self { + customRecords := lnwire.CustomRecords(r.FirstHopCustomRecords) + firstHopData, err := customRecords.Serialize() + if err != nil { + return nil, 0, err + } + max, total, err := getOutgoingBalance( self, outgoingChanMap, g.bandwidthHints, g.graph, + fn.Some[tlv.Blob](firstHopData), ) if err != nil { return nil, 0, err @@ -1031,9 +1044,18 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, continue } + firstHopTLVs := lnwire.CustomRecords( + r.FirstHopCustomRecords, + ) + firstHopData, err := firstHopTLVs.Serialize() + if err != nil { + return nil, 0, err + } + edge := edgeUnifier.getEdge( netAmountReceived, g.bandwidthHints, - partialPath.outboundFee, fn.None[tlv.Blob](), + partialPath.outboundFee, + fn.Some[tlv.Blob](firstHopData), ) if edge == nil { @@ -1225,3 +1247,14 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, // The final hop does not have a short chanID set. return finalHop.PayloadSize(0) } + +// overflowSafeAdd adds two MilliSatoshi values and returns the result. If an +// overflow could occur, the maximum uint64 value is returned instead. +func overflowSafeAdd(x, y lnwire.MilliSatoshi) lnwire.MilliSatoshi { + if y > math.MaxUint64-x { + // Overflow would occur, return maximum uint64 value. + return math.MaxUint64 + } + + return x + y +} diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index c440525034..9b4a0d60c2 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -10,6 +10,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -359,7 +360,8 @@ func (p *paymentLifecycle) requestRoute( // Query our payment session to construct a route. rt, err := p.paySession.RequestRoute( ps.RemainingAmt, remainingFees, - uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), nil, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), + p.firstHopTLVs, ) // Exit early if there's no error. @@ -677,6 +679,44 @@ func (p *paymentLifecycle) sendAttempt( CustomRecords: lnwire.CustomRecords(p.firstHopTLVs), } + // If we had custom records in the HTLC, then we'll encode that here + // now. We allow the traffic shaper (if there is one) to overwrite the + // custom records below. But if there is no traffic shaper, we still + // want to forward these custom records. + encodedRecords, err := htlcAdd.CustomRecords.Serialize() + if err != nil { + return nil, fmt.Errorf("unable to encode first hop TLVs: %w", + err) + } + + // If a hook exists that may affect our outgoing message, we call it now + // and apply its side effects to the UpdateAddHTLC message. + err = fn.MapOptionZ( + p.router.cfg.TrafficShaper, + func(ts TlvTrafficShaper) error { + newAmt, newData, err := ts.ProduceHtlcExtraData( + rt.TotalAmount, encodedRecords, + ) + if err != nil { + return err + } + + customRecords, err := lnwire.ParseCustomRecords(newData) + if err != nil { + return err + } + + htlcAdd.CustomRecords = customRecords + htlcAdd.Amount = lnwire.NewMSatFromSatoshis(newAmt) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("traffic shaper failed to produce "+ + "extra data: %w", err) + } + // Generate the raw encoded sphinx packet to be included along // with the htlcAdd message that we send directly to the // switch. diff --git a/routing/payment_session.go b/routing/payment_session.go index 1fb4d63214..e435b47d1e 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -253,16 +253,17 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // to our destination, respecting the recommendations from // MissionControl. restrictions := &RestrictParams{ - ProbabilitySource: p.missionControl.GetProbability, - FeeLimit: feeLimit, - OutgoingChannelIDs: p.payment.OutgoingChannelIDs, - LastHop: p.payment.LastHop, - CltvLimit: cltvLimit, - DestCustomRecords: p.payment.DestCustomRecords, - DestFeatures: p.payment.DestFeatures, - PaymentAddr: p.payment.PaymentAddr, - Amp: p.payment.amp, - Metadata: p.payment.Metadata, + ProbabilitySource: p.missionControl.GetProbability, + FeeLimit: feeLimit, + OutgoingChannelIDs: p.payment.OutgoingChannelIDs, + LastHop: p.payment.LastHop, + CltvLimit: cltvLimit, + DestCustomRecords: p.payment.DestCustomRecords, + DestFeatures: p.payment.DestFeatures, + PaymentAddr: p.payment.PaymentAddr, + Amp: p.payment.amp, + Metadata: p.payment.Metadata, + FirstHopCustomRecords: firstHopTLVs, } finalHtlcExpiry := int32(height) + int32(finalCltvDelta)