From 5f3c446b09787f2eddc0f9e8fff339c08d0eeb9e Mon Sep 17 00:00:00 2001 From: George Tsagkarelis Date: Mon, 11 Nov 2024 14:31:15 +0100 Subject: [PATCH] rfq: policies track accepted htlcs --- rfq/order.go | 227 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 217 insertions(+), 10 deletions(-) diff --git a/rfq/order.go b/rfq/order.go index 8b5bcf9b9..966088fec 100644 --- a/rfq/order.go +++ b/rfq/order.go @@ -13,6 +13,7 @@ import ( "github.com/lightninglabs/taproot-assets/fn" "github.com/lightninglabs/taproot-assets/rfqmath" "github.com/lightninglabs/taproot-assets/rfqmsg" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lnutils" @@ -71,6 +72,14 @@ type Policy interface { // which the policy applies. Scid() uint64 + // TrackAcceptedHtlc makes the policy aware of this new accepted HTLC. + // This is important in cases where the set of existing HTLCs may affect + // whether the next compliance check passes. + TrackAcceptedHtlc(circuitKey models.CircuitKey, amt lnwire.MilliSatoshi) + + // UntrackHtlc stops tracking the uniquely identified HTLC. + UntrackHtlc(circuitKey models.CircuitKey) + // GenerateInterceptorResponse generates an interceptor response for the // HTLC interceptor from the policy. GenerateInterceptorResponse( @@ -95,9 +104,22 @@ type AssetSalePolicy struct { // the policy. MaxOutboundAssetAmount uint64 + // CurrentAssetAmountMsat is the total amount that is held currently in + // accepted HTLCs. + CurrentAmountMsat lnwire.MilliSatoshi + + // stateMutex is a mutex that locks access to this policy's internal + // state. This is needed as state is updated asynchronously by each + // routine that handles an intercepted HTLC. + stateMutex sync.RWMutex + // AskAssetRate is the quote's asking asset unit to BTC conversion rate. AskAssetRate rfqmath.BigIntFixedPoint + // htlcToAmt maps the unique HTLC identifiers to the effective amount + // that they carry. + htlcToAmt map[models.CircuitKey]lnwire.MilliSatoshi + // expiry is the policy's expiry unix timestamp after which the policy // is no longer valid. expiry uint64 @@ -111,6 +133,7 @@ func NewAssetSalePolicy(quote rfqmsg.BuyAccept) *AssetSalePolicy { MaxOutboundAssetAmount: quote.Request.AssetMaxAmt, AskAssetRate: quote.AssetRate.Rate, expiry: uint64(quote.AssetRate.Expiry.Unix()), + htlcToAmt: make(map[models.CircuitKey]lnwire.MilliSatoshi), } } @@ -128,7 +151,7 @@ func (c *AssetSalePolicy) CheckHtlcCompliance( // Check that the channel SCID is as expected. htlcScid := SerialisedScid(htlc.OutgoingChannelID.ToUint64()) if htlcScid != c.AcceptedQuoteId.Scid() { - return fmt.Errorf("htlc outgoing channel ID does not match "+ + return fmt.Errorf("HTLC outgoing channel ID does not match "+ "policy's SCID (htlc_scid=%d, policy_scid=%d)", htlcScid, c.AcceptedQuoteId.Scid()) } @@ -152,8 +175,13 @@ func (c *AssetSalePolicy) CheckHtlcCompliance( maxAssetAmount, c.AskAssetRate, ) - if htlc.AmountOutMsat > policyMaxOutMsat { - return fmt.Errorf("htlc out amount is greater than the policy "+ + // Since we will be reading CurrentAmountMsat value we acquire a read + // lock. + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + + if (c.CurrentAmountMsat + htlc.AmountOutMsat) > policyMaxOutMsat { + return fmt.Errorf("HTLC out amount is greater than the policy "+ "maximum (htlc_out_msat=%d, policy_max_out_msat=%d)", htlc.AmountOutMsat, policyMaxOutMsat) } @@ -167,6 +195,34 @@ func (c *AssetSalePolicy) CheckHtlcCompliance( return nil } +// TrackAcceptedHtlc accounts for the newly accepted HTLC. This may affect the +// acceptance of future HTLCs. +func (c *AssetSalePolicy) TrackAcceptedHtlc(circuitKey models.CircuitKey, + amt lnwire.MilliSatoshi) { + + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + c.CurrentAmountMsat += amt + + c.htlcToAmt[circuitKey] = amt +} + +// UntrackHtlc stops tracking the uniquely identified HTLC. +func (c *AssetSalePolicy) UntrackHtlc(circuitKey models.CircuitKey) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + amt, found := c.htlcToAmt[circuitKey] + if !found { + return + } + + delete(c.htlcToAmt, circuitKey) + + c.CurrentAmountMsat -= amt +} + // Expiry returns the policy's expiry time as a unix timestamp. func (c *AssetSalePolicy) Expiry() uint64 { return c.expiry @@ -246,12 +302,25 @@ type AssetPurchasePolicy struct { // AcceptedQuoteId is the ID of the accepted quote. AcceptedQuoteId rfqmsg.ID + // CurrentAssetAmountMsat is the total amount that is held currently in + // accepted HTLCs. + CurrentAmountMsat lnwire.MilliSatoshi + + // stateMutex is a mutex that locks access to this policy's internal + // state. This is needed as state is updated asynchronously by each + // routine that handles an intercepted HTLC. + stateMutex sync.RWMutex + // BidAssetRate is the quote's asset to BTC conversion rate. BidAssetRate rfqmath.BigIntFixedPoint // PaymentMaxAmt is the maximum agreed BTC payment. PaymentMaxAmt lnwire.MilliSatoshi + // htlcToAmt maps the unique HTLC identifiers to the effective amount + // that they carry. + htlcToAmt map[models.CircuitKey]lnwire.MilliSatoshi + // expiry is the policy's expiry unix timestamp in seconds after which // the policy is no longer valid. expiry uint64 @@ -266,6 +335,7 @@ func NewAssetPurchasePolicy(quote rfqmsg.SellAccept) *AssetPurchasePolicy { BidAssetRate: quote.AssetRate.Rate, PaymentMaxAmt: quote.Request.PaymentMaxAmt, expiry: uint64(quote.AssetRate.Expiry.Unix()), + htlcToAmt: make(map[models.CircuitKey]lnwire.MilliSatoshi), } } @@ -288,7 +358,7 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance( if rfqID != c.AcceptedQuoteId { return fmt.Errorf("HTLC contains a custom record, but it does "+ - "not contain the accepted quote ID (htlc=%v, "+ + "not contain the accepted quote ID (HTLC=%v, "+ "accepted_quote_id=%v)", htlc, c.AcceptedQuoteId) } @@ -313,17 +383,22 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance( ) if inboundAmountMSat < htlc.AmountOutMsat { - return fmt.Errorf("htlc out amount is more than inbound "+ + return fmt.Errorf("HTLC out amount is more than inbound "+ "asset amount in millisatoshis (htlc_out_msat=%d, "+ "inbound_asset_amount=%s, "+ "inbound_asset_amount_msat=%v)", htlc.AmountOutMsat, assetAmt.String(), inboundAmountMSat) } + // Since we will be reading CurrentAmountMsat value we acquire a read + // lock. + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + // Ensure that the outbound HTLC amount is less than the maximum agreed // BTC payment. - if htlc.AmountOutMsat > c.PaymentMaxAmt { - return fmt.Errorf("htlc out amount is more than the maximum "+ + if (c.CurrentAmountMsat + htlc.AmountOutMsat) > c.PaymentMaxAmt { + return fmt.Errorf("HTLC out amount is more than the maximum "+ "agreed BTC payment (htlc_out_msat=%d, "+ "payment_max_amt=%d)", htlc.AmountOutMsat, c.PaymentMaxAmt) @@ -338,6 +413,34 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance( return nil } +// TrackAcceptedHtlc accounts for the newly accepted HTLC. This may affect the +// acceptance of future HTLCs. +func (c *AssetPurchasePolicy) TrackAcceptedHtlc(circuitKey models.CircuitKey, + amt lnwire.MilliSatoshi) { + + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + c.CurrentAmountMsat += amt + + c.htlcToAmt[circuitKey] = amt +} + +// UntrackHtlc stops tracking the uniquely identified HTLC. +func (c *AssetPurchasePolicy) UntrackHtlc(circuitKey models.CircuitKey) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + amt, found := c.htlcToAmt[circuitKey] + if !found { + return + } + + delete(c.htlcToAmt, circuitKey) + + c.CurrentAmountMsat -= amt +} + // Expiry returns the policy's expiry time as a unix timestamp in seconds. func (c *AssetPurchasePolicy) Expiry() uint64 { return c.expiry @@ -436,6 +539,27 @@ func (a *AssetForwardPolicy) CheckHtlcCompliance( return nil } +// TrackAcceptedHtlc accounts for the newly accepted HTLC. This may affect the +// acceptance of future HTLCs. +func (a *AssetForwardPolicy) TrackAcceptedHtlc(circuitKey models.CircuitKey, + amt lnwire.MilliSatoshi) { + + // Track accepted HTLC in the incoming policy. + a.incomingPolicy.TrackAcceptedHtlc(circuitKey, amt) + + // Track accepted HTLC in the outgoing policy. + a.outgoingPolicy.TrackAcceptedHtlc(circuitKey, amt) +} + +// UntrackHtlc stops tracking the uniquely identified HTLC. +func (a *AssetForwardPolicy) UntrackHtlc(circuitKey models.CircuitKey) { + // Untrack HTLC in the incoming policy. + a.incomingPolicy.UntrackHtlc(circuitKey) + + // Untrack HTLC in the outgoing policy. + a.outgoingPolicy.UntrackHtlc(circuitKey) +} + // Expiry returns the policy's expiry time as a unix timestamp in seconds. The // returned expiry time is the earliest expiry time of the incoming and outgoing // policies. @@ -514,6 +638,10 @@ type OrderHandlerCfg struct { // AcceptHtlcEvents is a channel that receives accepted HTLCs. AcceptHtlcEvents chan<- *AcceptHtlcEvent + + // HtlcSubscriber is a subscriber that is used to retrieve live HTLC + // event updates. + HtlcSubscriber HtlcSubscriber } // OrderHandler orchestrates management of accepted quote bundles. It monitors @@ -530,6 +658,11 @@ type OrderHandler struct { // associated asset transaction policies. policies lnutils.SyncMap[SerialisedScid, Policy] + // htlcToPolicy maps an HTLC circuit key to the policy that applies to + // it. We need this map because for failed HTLCs we don't have the RFQ + // data available, so we need to cache this info. + htlcToPolicy lnutils.SyncMap[models.CircuitKey, Policy] + // ContextGuard provides a wait group and main quit channel that can be // used to create guarded contexts. *fn.ContextGuard @@ -586,13 +719,19 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context, err = policy.CheckHtlcCompliance(htlc) if err != nil { log.Warnf("HTLC does not comply with policy: %v "+ - "(htlc=%v, policy=%v)", err, htlc, policy) + "(HTLC=%v, policy=%v)", err, htlc, policy) return &lndclient.InterceptedHtlcResponse{ Action: lndclient.InterceptorActionFail, }, nil } + h.htlcToPolicy.Store(htlc.IncomingCircuitKey, policy) + + // The HTLC passed the compliance checks, so now we keep track of the + // accepted HTLC. + policy.TrackAcceptedHtlc(htlc.IncomingCircuitKey, htlc.AmountOutMsat) + log.Debug("HTLC complies with policy. Broadcasting accept event.") h.cfg.AcceptHtlcEvents <- NewAcceptHtlcEvent(htlc, policy) @@ -640,12 +779,66 @@ func (h *OrderHandler) mainEventLoop() { } } +// subscribeHtlcs subscribes the OrderHandler to HTLC events provided by the lnd +// RPC interface. We use this subscription to track HTLC forwarding failures, +// which we use to performn a live update of our policies. +func (h *OrderHandler) subscribeHtlcs(ctx context.Context) error { + events, chErr, err := h.cfg.HtlcSubscriber.SubscribeHtlcEvents(ctx) + if err != nil { + return err + } + + for { + select { + case event := <-events: + // We only care about forwarding events. + if event.GetEventType() != routerrpc.HtlcEvent_FORWARD { + continue + } + + // Retrieve the two instances that may be relevant. + failEvent := event.GetForwardFailEvent() + linkFail := event.GetLinkFailEvent() + + // Craft the circuit key that identifies this HTLC. + circuitKey := models.CircuitKey{ + ChanID: lnwire.NewShortChanIDFromInt( + event.IncomingChannelId, + ), + HtlcID: event.IncomingHtlcId, + } + + switch { + case failEvent != nil: + fallthrough + case linkFail != nil: + // Fetch the policy that is related to this + // HTLC. + policy, found := h.htlcToPolicy.LoadAndDelete( + circuitKey, + ) + + if !found { + continue + } + + // Stop tracking this HTLC as it failed. + policy.UntrackHtlc(circuitKey) + } + + case err := <-chErr: + return err + + case <-ctx.Done(): + return ctx.Err() + } + } +} + // Start starts the service. func (h *OrderHandler) Start() error { var startErr error h.startOnce.Do(func() { - log.Info("Starting subsystem: order handler") - // Start the main event loop in a separate goroutine. h.Wg.Add(1) go func() { @@ -663,6 +856,20 @@ func (h *OrderHandler) Start() error { h.mainEventLoop() }() + + // Start the HTLC event subscription loop. + h.Wg.Add(1) + go func() { + defer h.Wg.Done() + + ctx, cancel := h.WithCtxQuitNoTimeout() + defer cancel() + + err := h.subscribeHtlcs(ctx) + if err != nil { + log.Errorf("HTLC subscriber error: %v", err) + } + }() }) return startErr