Skip to content

Commit

Permalink
Merge pull request lightninglabs#1244 from GeorgeTsagk/strict-forward…
Browse files Browse the repository at this point in the history
…ing-p2

Strict forwarding pt.2
  • Loading branch information
GeorgeTsagk authored Dec 13, 2024
2 parents 35d8f1c + ffc6e68 commit 7358c1b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
7 changes: 7 additions & 0 deletions tapchannel/aux_invoice_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ func (s *AuxInvoiceManager) handleInvoiceAccept(_ context.Context,
resp.CancelSet = true
}

return resp, nil
} else if !isAssetInvoice(req.Invoice, s) && !req.Invoice.IsKeysend {
// If we do have custom records, but the invoice does not
// correspond to an asset invoice, we do not settle the invoice.
// Since we requested btc we should be receiving btc.
resp.CancelSet = true

return resp, nil
}

Expand Down
57 changes: 46 additions & 11 deletions tapchannel/aux_invoice_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ import (
)

const (
// The test channel ID to use across the test cases.
testChanID = 1234

// maxRandomInvoiceValueMSat is the maximum invoice value in mSAT to be
// generated by the property based tests.
maxRandomInvoiceValueMSat = 100_000_000_000
Expand All @@ -52,12 +49,20 @@ var (
// The node ID to be used for the RFQ peer.
testNodeID = route.Vertex{1, 2, 3}

// The asset rate value to use across tests.
assetRate = big.NewInt(100_000)

// The asset rate struct based on the assetRate value.
testAssetRate = rfqmath.FixedPoint[rfqmath.BigInt]{
Coefficient: rfqmath.NewBigInt(assetRate),
Scale: 0,
}

// The test RFQ ID to use across tests.
testRfqID = dummyRfqID(31)

// The test RFQ SCID that is derived from testRfqID.
testScid = testRfqID.Scid()
)

// mockRfqManager mocks the interface of the rfq manager required by the aux
Expand Down Expand Up @@ -184,6 +189,11 @@ func (m *mockHtlcModifierProperty) HtlcModifier(ctx context.Context,
if r.ExitHtlcAmt != res.AmtPaid {
m.t.Errorf("AmtPaid != ExitHtlcAmt")
}
} else if !isAssetInvoice(r.Invoice, m) {
if !res.CancelSet {
m.t.Errorf("expected cancel set flag")
}
continue
}

htlcBlob, err := r.WireCustomRecords.Serialize()
Expand Down Expand Up @@ -293,7 +303,7 @@ func TestAuxInvoiceManager(t *testing.T) {
},
},
buyQuotes: map[rfq.SerialisedScid]rfqmsg.BuyAccept{
testChanID: {
testScid: {
Peer: testNodeID,
},
},
Expand All @@ -315,7 +325,7 @@ func TestAuxInvoiceManager(t *testing.T) {
},
},
buyQuotes: map[rfq.SerialisedScid]rfqmsg.BuyAccept{
testChanID: {
testScid: {
Peer: testNodeID,
},
},
Expand All @@ -335,7 +345,7 @@ func TestAuxInvoiceManager(t *testing.T) {
dummyAssetID(1),
3,
),
}, fn.Some(dummyRfqID(31)),
}, fn.Some(testRfqID),
),
},
},
Expand All @@ -345,7 +355,7 @@ func TestAuxInvoiceManager(t *testing.T) {
},
},
buyQuotes: rfq.BuyAcceptMap{
fn.Ptr(dummyRfqID(31)).Scid(): {
testScid: {
Peer: testNodeID,
AssetRate: rfqmsg.NewAssetRate(
testAssetRate, time.Now(),
Expand All @@ -368,7 +378,7 @@ func TestAuxInvoiceManager(t *testing.T) {
dummyAssetID(1),
4,
),
}, fn.Some(dummyRfqID(31)),
}, fn.Some(testRfqID),
),
ExitHtlcAmt: 1234,
},
Expand All @@ -379,14 +389,39 @@ func TestAuxInvoiceManager(t *testing.T) {
},
},
buyQuotes: rfq.BuyAcceptMap{
fn.Ptr(dummyRfqID(31)).Scid(): {
testScid: {
Peer: testNodeID,
AssetRate: rfqmsg.NewAssetRate(
testAssetRate, time.Now(),
),
},
},
},
{
name: "btc invoice, custom records",
requests: []lndclient.InvoiceHtlcModifyRequest{
{
Invoice: &lnrpc.Invoice{
ValueMsat: 10_000_000,
PaymentAddr: []byte{1, 1, 1},
},
WireCustomRecords: newWireCustomRecords(
t, []*rfqmsg.AssetBalance{
rfqmsg.NewAssetBalance(
dummyAssetID(1),
4,
),
}, fn.Some(testRfqID),
),
ExitHtlcAmt: 1234,
},
},
responses: []lndclient.InvoiceHtlcModifyResponse{
{
CancelSet: true,
},
},
},
}

for _, testCase := range testCases {
Expand Down Expand Up @@ -761,8 +796,8 @@ func testRouteHints() []*lnrpc.RouteHint {
NodeId: route.Vertex{1, 1, 1}.String(),
},
{
ChanId: 1234,
NodeId: route.Vertex{1, 2, 3}.String(),
ChanId: uint64(testScid),
NodeId: testNodeID.String(),
},
},
},
Expand Down

0 comments on commit 7358c1b

Please sign in to comment.