Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance update_add_htlc with remote peer's custom records #8660

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions channeldb/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2550,6 +2550,12 @@ type HTLC struct {
// HTLC. It is stored in the ExtraData field, which is used to store
// a TLV stream of additional information associated with the HTLC.
BlindingPoint lnwire.BlindingPointRecord

// CustomRecords is a set of custom TLV records that are associated with
// this HTLC. These records are used to store additional information
// about the HTLC that is not part of the standard HTLC fields. This
// field is encoded within the ExtraData field.
CustomRecords lnwire.CustomRecords
GeorgeTsagk marked this conversation as resolved.
Show resolved Hide resolved
}

// serializeExtraData encodes a TLV stream of extra data to be stored with a
Expand All @@ -2568,6 +2574,11 @@ func (h *HTLC) serializeExtraData() error {
records = append(records, &b)
})

records, err := h.CustomRecords.ExtendRecordProducers(records)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test coverage for serialization change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added that to the tracking issue.

if err != nil {
return err
}

return h.ExtraData.PackRecords(records...)
}

Expand All @@ -2589,7 +2600,18 @@ func (h *HTLC) deserializeExtraData() error {

if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil {
h.BlindingPoint = tlv.SomeRecordT(blindingPoint)

// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(tlvMap, h.BlindingPoint.TlvType())
}

// Set the custom records field to the remaining TLV records.
customRecords, err := lnwire.NewCustomRecordsFromTlvTypeMap(tlvMap)
if err != nil {
return err
}
h.CustomRecords = customRecords

return nil
}
Expand Down
84 changes: 47 additions & 37 deletions htlcswitch/interceptable_switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"sync"

"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb/models"
Expand Down Expand Up @@ -622,15 +623,16 @@ func (f *interceptedForward) Packet() InterceptedPacket {
ChanID: f.packet.incomingChanID,
HtlcID: f.packet.incomingHTLCID,
},
OutgoingChanID: f.packet.outgoingChanID,
Hash: f.htlc.PaymentHash,
OutgoingExpiry: f.htlc.Expiry,
OutgoingAmount: f.htlc.Amount,
IncomingAmount: f.packet.incomingAmount,
IncomingExpiry: f.packet.incomingTimeout,
CustomRecords: f.packet.customRecords,
OnionBlob: f.htlc.OnionBlob,
AutoFailHeight: f.autoFailHeight,
OutgoingChanID: f.packet.outgoingChanID,
Hash: f.htlc.PaymentHash,
OutgoingExpiry: f.htlc.Expiry,
OutgoingAmount: f.htlc.Amount,
IncomingAmount: f.packet.incomingAmount,
IncomingExpiry: f.packet.incomingTimeout,
CustomRecords: f.packet.customRecords,
OnionBlob: f.htlc.OnionBlob,
AutoFailHeight: f.autoFailHeight,
IncomingWireCustomRecords: f.packet.incomingCustomRecords,
}
}

Expand Down Expand Up @@ -659,50 +661,58 @@ func (f *interceptedForward) ResumeModified(
htlc.Amount = amount
})

//nolint:lll
err := fn.MapOptionZ(customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}
err := fn.MapOptionZ(
customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}

// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(records)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate custom "+
"records: %w", err)
}
// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(
records,
)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate "+
"custom records: %w", err)
}

return nil
})
return nil
},
)
if err != nil {
return fmt.Errorf("failed to encode custom records: %w",
err)
}

case *lnwire.UpdateFulfillHTLC:
//nolint:lll
err := fn.MapOptionZ(customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}
err := fn.MapOptionZ(
customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}

// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(records)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate custom "+
"records: %w", err)
}
// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(
records,
)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate "+
"custom records: %w", err)
}

return nil
})
return nil
},
)
if err != nil {
return fmt.Errorf("failed to encode custom records: %w",
err)
}
}

log.Tracef("Forwarding packet %v", spew.Sdump(f.packet))

// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
Expand Down
4 changes: 4 additions & 0 deletions htlcswitch/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ type InterceptedPacket struct {
// OnionBlob is the onion packet for the next hop
OnionBlob [lnwire.OnionPacketSize]byte

// IncomingWireCustomRecords are user-defined records that were defined
// by the peer that forwarded this htlc to us.
IncomingWireCustomRecords record.CustomSet

// AutoFailHeight is the block height at which this intercept will be
// failed back automatically.
AutoFailHeight int32
Expand Down
77 changes: 52 additions & 25 deletions htlcswitch/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tlv"
)

func init() {
Expand Down Expand Up @@ -3354,6 +3356,27 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
continue
}

var customRecords record.CustomSet
err = fn.MapOptionZ(
pd.CustomRecords, func(b tlv.Blob) error {
r, err := lnwire.ParseCustomRecords(b)
if err != nil {
return err
}

customRecords = record.CustomSet(r)

return nil
},
)
if err != nil {
l.fail(LinkFailureError{
code: ErrInternalError,
}, err.Error())

return
}

switch fwdPkg.State {
case channeldb.FwdStateProcessed:
// This add was not forwarded on the previous
Expand All @@ -3367,7 +3390,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
}

// Otherwise, it was already processed, we can
// can collect it and continue.
// collect it and continue.
addMsg := &lnwire.UpdateAddHTLC{
Expiry: fwdInfo.OutgoingCTLV,
Amount: fwdInfo.AmountToForward,
Expand All @@ -3387,19 +3410,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,

inboundFee := l.cfg.FwrdingPolicy.InboundFee

//nolint:lll
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingCustomRecords: customRecords,
}
switchPackets = append(
switchPackets, updatePacket,
Expand Down Expand Up @@ -3455,19 +3480,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
if fwdPkg.State == channeldb.FwdStateLockedIn {
inboundFee := l.cfg.FwrdingPolicy.InboundFee

//nolint:lll
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingCustomRecords: customRecords,
}

fwdPkg.FwdFilter.Set(idx)
Expand Down
4 changes: 4 additions & 0 deletions htlcswitch/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ type htlcPacket struct {
// were included in the payload.
customRecords record.CustomSet

// incomingCustomRecords are custom type range TLVs that are included
// in the incoming update_add_htlc.
incomingCustomRecords record.CustomSet

// originalOutgoingChanID is used when sending back failure messages.
// It is only used for forwarded Adds on option_scid_alias channels.
// This is to avoid possible confusion if a payer uses the public SCID
Expand Down
4 changes: 4 additions & 0 deletions itest/list_on_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,10 @@ var allTestCases = []*lntest.TestCase{
Name: "forward interceptor modified htlc",
TestFunc: testForwardInterceptorModifiedHtlc,
},
{
Name: "forward interceptor wire records",
TestFunc: testForwardInterceptorWireRecords,
},
{
Name: "zero conf channel open",
TestFunc: testZeroConfChannelOpen,
Expand Down
Loading
Loading