Skip to content

Commit

Permalink
routing: use first hop records on path finding
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeTsagk authored and guggero committed May 22, 2024
1 parent a584b75 commit 53e9f28
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 15 deletions.
41 changes: 37 additions & 4 deletions routing/pathfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,10 @@ type RestrictParams struct {
// BlindedPayment is necessary to determine the hop size of the
// last/exit hop.
BlindedPayment *BlindedPayment

// FirstHopCustomRecords includes any records that should be included in
// the update_add_htlc message towards our peer.
FirstHopCustomRecords record.CustomSet
}

// PathFindingConfig defines global parameters that control the trade-off in
Expand All @@ -461,9 +465,11 @@ type PathFindingConfig struct {
// available balance.
func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
bandwidthHints bandwidthHints,
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
g routingGraph, htlcBlob fn.Option[tlv.Blob]) (lnwire.MilliSatoshi,
lnwire.MilliSatoshi, error) {

var max, total lnwire.MilliSatoshi

cb := func(channel *channeldb.DirectedChannel) error {
if !channel.OutPolicySet {
return nil
Expand All @@ -479,7 +485,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
}

bandwidth, ok := bandwidthHints.availableChanBandwidth(
chanID, 0, fn.None[tlv.Blob](),
chanID, 0, htlcBlob,
)

// If the bandwidth is not available, use the channel capacity.
Expand All @@ -493,7 +499,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
max = bandwidth
}

total += bandwidth
total = overflowSafeAdd(total, bandwidth)

return nil
}
Expand Down Expand Up @@ -601,8 +607,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
self := g.graph.sourceNode()

if source == self {
customRecords := lnwire.CustomRecords(r.FirstHopCustomRecords)
firstHopData, err := customRecords.Serialize()
if err != nil {
return nil, 0, err
}

max, total, err := getOutgoingBalance(
self, outgoingChanMap, g.bandwidthHints, g.graph,
fn.Some[tlv.Blob](firstHopData),
)
if err != nil {
return nil, 0, err
Expand Down Expand Up @@ -1031,9 +1044,18 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
continue
}

firstHopTLVs := lnwire.CustomRecords(
r.FirstHopCustomRecords,
)
firstHopData, err := firstHopTLVs.Serialize()
if err != nil {
return nil, 0, err
}

edge := edgeUnifier.getEdge(
netAmountReceived, g.bandwidthHints,
partialPath.outboundFee, fn.None[tlv.Blob](),
partialPath.outboundFee,
fn.Some[tlv.Blob](firstHopData),
)

if edge == nil {
Expand Down Expand Up @@ -1225,3 +1247,14 @@ func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32,
// The final hop does not have a short chanID set.
return finalHop.PayloadSize(0)
}

// overflowSafeAdd adds two MilliSatoshi values and returns the result. If an
// overflow could occur, the maximum uint64 value is returned instead.
func overflowSafeAdd(x, y lnwire.MilliSatoshi) lnwire.MilliSatoshi {
if y > math.MaxUint64-x {
// Overflow would occur, return maximum uint64 value.
return math.MaxUint64
}

return x + y
}
42 changes: 41 additions & 1 deletion routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -359,7 +360,8 @@ func (p *paymentLifecycle) requestRoute(
// Query our payment session to construct a route.
rt, err := p.paySession.RequestRoute(
ps.RemainingAmt, remainingFees,
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), nil,
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight),
p.firstHopTLVs,
)

// Exit early if there's no error.
Expand Down Expand Up @@ -677,6 +679,44 @@ func (p *paymentLifecycle) sendAttempt(
CustomRecords: lnwire.CustomRecords(p.firstHopTLVs),
}

// If we had custom records in the HTLC, then we'll encode that here
// now. We allow the traffic shaper (if there is one) to overwrite the
// custom records below. But if there is no traffic shaper, we still
// want to forward these custom records.
encodedRecords, err := htlcAdd.CustomRecords.Serialize()
if err != nil {
return nil, fmt.Errorf("unable to encode first hop TLVs: %w",
err)
}

// If a hook exists that may affect our outgoing message, we call it now
// and apply its side effects to the UpdateAddHTLC message.
err = fn.MapOptionZ(
p.router.cfg.TrafficShaper,
func(ts TlvTrafficShaper) error {
newAmt, newData, err := ts.ProduceHtlcExtraData(
rt.TotalAmount, encodedRecords,
)
if err != nil {
return err
}

customRecords, err := lnwire.ParseCustomRecords(newData)
if err != nil {
return err
}

htlcAdd.CustomRecords = customRecords
htlcAdd.Amount = lnwire.NewMSatFromSatoshis(newAmt)

return nil
},
)
if err != nil {
return nil, fmt.Errorf("traffic shaper failed to produce "+
"extra data: %w", err)
}

// Generate the raw encoded sphinx packet to be included along
// with the htlcAdd message that we send directly to the
// switch.
Expand Down
21 changes: 11 additions & 10 deletions routing/payment_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,17 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// to our destination, respecting the recommendations from
// MissionControl.
restrictions := &RestrictParams{
ProbabilitySource: p.missionControl.GetProbability,
FeeLimit: feeLimit,
OutgoingChannelIDs: p.payment.OutgoingChannelIDs,
LastHop: p.payment.LastHop,
CltvLimit: cltvLimit,
DestCustomRecords: p.payment.DestCustomRecords,
DestFeatures: p.payment.DestFeatures,
PaymentAddr: p.payment.PaymentAddr,
Amp: p.payment.amp,
Metadata: p.payment.Metadata,
ProbabilitySource: p.missionControl.GetProbability,
FeeLimit: feeLimit,
OutgoingChannelIDs: p.payment.OutgoingChannelIDs,
LastHop: p.payment.LastHop,
CltvLimit: cltvLimit,
DestCustomRecords: p.payment.DestCustomRecords,
DestFeatures: p.payment.DestFeatures,
PaymentAddr: p.payment.PaymentAddr,
Amp: p.payment.amp,
Metadata: p.payment.Metadata,
FirstHopCustomRecords: firstHopTLVs,
}

finalHtlcExpiry := int32(height) + int32(finalCltvDelta)
Expand Down

0 comments on commit 53e9f28

Please sign in to comment.