diff --git a/lndclient_mock.go b/lndclient_mock.go index b626510..f10c6b3 100644 --- a/lndclient_mock.go +++ b/lndclient_mock.go @@ -27,16 +27,18 @@ type lndclientMock struct { htlcInterceptorRequests chan *interceptedEvent htlcInterceptorResponses chan *interceptResponse - channels map[uint64]*channel + channels map[uint64]*channel + closedChannels map[uint64]*channel } -func newLndclientMock(channels map[uint64]*channel) *lndclientMock { +func newLndclientMock(channels, closedChannels map[uint64]*channel) *lndclientMock { return &lndclientMock{ htlcEvents: make(chan *resolvedEvent), htlcInterceptorRequests: make(chan *interceptedEvent), htlcInterceptorResponses: make(chan *interceptResponse), - channels: channels, + channels: channels, + closedChannels: closedChannels, } } @@ -51,7 +53,7 @@ func (l *lndclientMock) listChannels() (map[uint64]*channel, error) { } func (l *lndclientMock) listClosedChannels() (map[uint64]*channel, error) { - return make(map[uint64]*channel), nil + return l.closedChannels, nil } func (l *lndclientMock) subscribeHtlcEvents(ctx context.Context) ( diff --git a/process.go b/process.go index c80621b..a88ad14 100644 --- a/process.go +++ b/process.go @@ -526,6 +526,23 @@ func (p *process) getChanInfo(channel uint64) (*channel, error) { return ch, nil } + // If the channel is not open, fall back to checking our closed + // channels. + closedChannels, err := p.client.listClosedChannels() + if err != nil { + return nil, err + } + + // Add to cache and try again. + for chanId, ch := range closedChannels { + p.chanMap[chanId] = ch + } + + ch, ok = p.chanMap[channel] + if ok { + return ch, nil + } + // Channel not found. return nil, fmt.Errorf("%w: %v", errChannelNotFound, channel) } diff --git a/process_test.go b/process_test.go index 0feb159..3706470 100644 --- a/process_test.go +++ b/process_test.go @@ -33,7 +33,7 @@ const ( ) func testProcess(t *testing.T, event resolveEvent) { - client := newLndclientMock(testChannels) + client := newLndclientMock(testChannels, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -130,7 +130,7 @@ func testRateLimit(t *testing.T, mode Mode) { }, } - client := newLndclientMock(testChannels) + client := newLndclientMock(testChannels, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -222,7 +222,7 @@ func testMaxPending(t *testing.T, mode Mode) { }, } - client := newLndclientMock(testChannels) + client := newLndclientMock(testChannels, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -278,7 +278,7 @@ func testMaxPending(t *testing.T, mode Mode) { func TestNewPeer(t *testing.T) { // Initialize lnd with test channels. - client := newLndclientMock(testChannels) + client := newLndclientMock(testChannels, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -335,7 +335,7 @@ func TestBlocked(t *testing.T) { }, } - client := newLndclientMock(testChannels) + client := newLndclientMock(testChannels, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -370,7 +370,7 @@ func TestBlocked(t *testing.T) { // TestChannelNotFound tests that we'll successfully exit when we cannot lookup the // channel that a htlc belongs to. func TestChannelNotFound(t *testing.T) { - client := newLndclientMock(testChannels) + client := newLndclientMock(testChannels, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -407,3 +407,47 @@ func TestChannelNotFound(t *testing.T) { t.Fatalf("timeout on process error") } } + +// TestClosedChannelHtlc tests that we can handle intercepted htlcs that are associated +// with closed channels. +func TestClosedChannelHtlc(t *testing.T) { + // Initialize lnd with a closed channel. + var testClosedChannels = map[uint64]*channel{ + 5: {peer: route.Vertex{2}}, + } + client := newLndclientMock(testChannels, testClosedChannels) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, cleanup := setupTestDb(t, defaultFwdHistoryLimit) + defer cleanup() + + log := zaptest.NewLogger(t).Sugar() + + cfg := &Limits{} + + p := NewProcess(client, log, cfg, db) + + exit := make(chan error) + + go func() { + exit <- p.Run(ctx) + }() + + // Send a htlc that is from a closed channel, it should be given the go-ahead to + // resume. + key := circuitKey{ + channel: 5, + htlc: 3, + } + client.htlcInterceptorRequests <- &interceptedEvent{ + circuitKey: key, + } + + resp := <-client.htlcInterceptorResponses + require.Equal(t, key, resp.key) + + cancel() + require.ErrorIs(t, <-exit, context.Canceled) +}