diff --git a/accounts/checkers.go b/accounts/checkers.go index 8fdb8c9b3..e4eff8935 100644 --- a/accounts/checkers.go +++ b/accounts/checkers.go @@ -522,6 +522,7 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, } // The invoice is optional. + var paymentHash lntypes.Hash if len(invoice) > 0 { payReq, err := zpay32.Decode(invoice, chainParams) if err != nil { @@ -531,6 +532,10 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, if payReq.MilliSat != nil && *payReq.MilliSat > sendAmt { sendAmt = *payReq.MilliSat } + + if payReq.PaymentHash != nil { + paymentHash = *payReq.PaymentHash + } } // We also add the max fee to the amount to check. This might mean that @@ -549,6 +554,14 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, return fmt.Errorf("error validating account balance: %w", err) } + emptyHash := lntypes.Hash{} + if paymentHash != emptyHash { + err = service.AssociatePayment(acct.ID, paymentHash, sendAmt) + if err != nil { + return fmt.Errorf("error associating payment: %w", err) + } + } + return nil } diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 481609c91..2b37b9493 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -68,6 +68,12 @@ func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error { return nil } +func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash, + amt lnwire.MilliSatoshi) error { + + return nil +} + func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash, amt lnwire.MilliSatoshi) error { @@ -85,6 +91,10 @@ func (m *mockService) RemovePayment(hash lntypes.Hash) error { return nil } +func (*mockService) IsRunning() bool { + return true +} + var _ Service = (*mockService)(nil) // TestAccountChecker makes sure all round trip checkers can be instantiated diff --git a/accounts/interceptor.go b/accounts/interceptor.go index d6dff8dfe..836c4a2fa 100644 --- a/accounts/interceptor.go +++ b/accounts/interceptor.go @@ -52,6 +52,17 @@ func (s *InterceptorService) Intercept(ctx context.Context, s.requestMtx.Lock() defer s.requestMtx.Unlock() + // If the account service is not running, we reject all requests. + // Note that this is by no means a guarantee that the account service + // will be running throughout processing the request, but at least we + // can stop requests early if the service was already disabled when the + // request came in. + if !s.IsRunning() { + return mid.RPCErrString( + req, "the account service has been stopped", + ) + } + mac := &macaroon.Macaroon{} err := mac.UnmarshalBinary(req.RawMacaroon) if err != nil { diff --git a/accounts/interface.go b/accounts/interface.go index 879d9a51a..8fec489e2 100644 --- a/accounts/interface.go +++ b/accounts/interface.go @@ -160,6 +160,12 @@ var ( ErrNotSupportedWithAccounts = errors.New("this RPC call is not " + "supported with restricted account macaroons") + // ErrAccountServiceDisabled is the error that is returned when the + // account service has been disabled due to an error being thrown + // in the service that cannot be recovered from. + ErrAccountServiceDisabled = errors.New("the account service has been " + + "stopped") + // MacaroonPermissions are the permissions required for an account // macaroon. MacaroonPermissions = []bakery.Op{{ @@ -240,4 +246,10 @@ type Service interface { // longer needs to be tracked. The payment is certain to never succeed, // so we never need to debit the amount from the account. RemovePayment(hash lntypes.Hash) error + + // AssociatePayment associates a payment (hash) with the given account, + // ensuring that the payment will be tracked for a user when LiT is + // restarted. + AssociatePayment(id AccountID, paymentHash lntypes.Hash, + fullAmt lnwire.MilliSatoshi) error } diff --git a/accounts/service.go b/accounts/service.go index a229d831c..7712258bc 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -16,6 +16,12 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +// Config holds the configuration options for the accounts service. +type Config struct { + // Disable will disable the accounts service if set. + Disable bool `long:"disable" description:"disable the accounts service"` +} + // trackedPayment is a struct that holds all information that identifies a // payment that we are tracking in the service. type trackedPayment struct { @@ -60,14 +66,18 @@ type InterceptorService struct { invoiceToAccount map[lntypes.Hash]AccountID pendingPayments map[lntypes.Hash]*trackedPayment - mainErrChan chan<- error - wg sync.WaitGroup - quit chan struct{} + mainErrCallback func(error) + wg sync.WaitGroup + quit chan struct{} + + isEnabled bool } // NewService returns a service backed by the macaroon Bolt DB stored in the // passed-in directory. -func NewService(dir string, errChan chan<- error) (*InterceptorService, error) { +func NewService(dir string, + errCallback func(error)) (*InterceptorService, error) { + accountStore, err := NewBoltStore(dir, DBFilename) if err != nil { return nil, err @@ -81,8 +91,9 @@ func NewService(dir string, errChan chan<- error) (*InterceptorService, error) { contextCancel: contextCancel, invoiceToAccount: make(map[lntypes.Hash]AccountID), pendingPayments: make(map[lntypes.Hash]*trackedPayment), - mainErrChan: errChan, + mainErrCallback: errCallback, quit: make(chan struct{}), + isEnabled: false, }, nil } @@ -93,12 +104,15 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, s.routerClient = routerClient s.checkers = NewAccountChecker(s, params) + s.isEnabled = true + // Let's first fill our cache that maps invoices to accounts, which // allows us to credit an account easily once an invoice is settled. We // also track payments that aren't in a final state yet. existingAccounts, err := s.store.Accounts() if err != nil { - return fmt.Errorf("error querying existing accounts: %w", err) + return s.disableAndErrorf("error querying existing "+ + "accounts: %w", err) } for _, acct := range existingAccounts { acct := acct @@ -116,8 +130,8 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, acct.ID, hash, entry.FullAmount, ) if err != nil { - return fmt.Errorf("error tracking "+ - "payment: %w", err) + return s.disableAndErrorf("error "+ + "tracking payment: %w", err) } } } @@ -146,8 +160,8 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, s.currentSettleIndex = 0 default: - return fmt.Errorf("error determining last invoice indexes: %w", - err) + return s.disableAndErrorf("error determining last invoice "+ + "indexes: %w", err) } invoiceChan, invoiceErrChan, err := lightningClient.SubscribeInvoices( @@ -157,7 +171,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, }, ) if err != nil { - return fmt.Errorf("error subscribing invoices: %w", err) + return s.disableAndErrorf("error subscribing invoices: %w", err) } s.wg.Add(1) @@ -178,23 +192,18 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, log.Errorf("Error processing invoice "+ "update: %v", err) - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) return } case err := <-invoiceErrChan: - log.Errorf("Error in invoice subscription: %v", - err) + // If the invoice subscription errors out, we + // stop the service as we won't be able to + // process invoices. + err = s.disableAndErrorf("Error in invoice "+ + "subscription: %w", err) - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) return case <-s.mainCtx.Done(): @@ -211,6 +220,18 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, // Stop shuts down the account service. func (s *InterceptorService) Stop() error { + // We need to lock the request mutex to ensure that we don't stop the + // service while we're processing a request. + // This is especially important to ensure that we don't stop the service + // exactly after a user has made an rpc call to send a payment we can't + // know the payment hash for prior to the actual payment being sent + // (i.e. Keysend or SendToRoute). This is because if we stop the service + // after the send request has been sent to lnd, but before TrackPayment + // has been called, we won't be able to track the payment and debit the + // account. + s.requestMtx.Lock() + defer s.requestMtx.Unlock() + s.contextCancel() close(s.quit) @@ -219,6 +240,49 @@ func (s *InterceptorService) Stop() error { return s.store.Close() } +// IsRunning checks if the account service is running, and returns a boolean +// indicating whether it is running or not. +func (s *InterceptorService) IsRunning() bool { + s.RLock() + defer s.RUnlock() + + return s.isEnabled +} + +// isRunningUnsafe checks if the account service is running, and returns a +// boolean indicating whether it is running or not +// +// NOTE: The store lock MUST be held as either a read or write lock when calling +// this method. +func (s *InterceptorService) isRunningUnsafe() bool { + return s.isEnabled +} + +// disable disables the account service, and marks the service as not running. +// The function acquires the store write lock before disabling the service. +// The function returns an error with the given format and arguments. +func (s *InterceptorService) disableAndErrorf(format string, a ...any) error { + s.Lock() + defer s.Unlock() + + s.isEnabled = false + + return fmt.Errorf(format, a...) +} + +// disableAndErrorfUnsafe disables the account service, and marks the service as +// not running. The function returns an error with the given format and +// arguments. +// +// NOTE: The store lock MUST be held when calling this method. +func (s *InterceptorService) disableAndErrorfUnsafe(format string, + a ...any) error { + + s.isEnabled = false + + return fmt.Errorf(format, a...) +} + // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. func (s *InterceptorService) NewAccount(balance lnwire.MilliSatoshi, @@ -239,6 +303,14 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, s.Lock() defer s.Unlock() + // As this function updates account balances, we require that the + // service is running before we execute it. + if s.isRunningUnsafe() { + // This case can only happen if the service is disabled while + // we we're processing a request. + return nil, ErrAccountServiceDisabled + } + account, err := s.store.Account(accountID) if err != nil { return nil, fmt.Errorf("error fetching account: %w", err) @@ -362,12 +434,62 @@ func (s *InterceptorService) AssociateInvoice(id AccountID, return s.store.UpdateAccount(account) } +// AssociatePayment associates a payment (hash) with the given account, +// ensuring that the payment will be tracked for a user when LiT is +// restarted. +func (s *InterceptorService) AssociatePayment(id AccountID, + paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error { + + s.Lock() + defer s.Unlock() + + account, err := s.store.Account(id) + if err != nil { + return err + } + + // If the payment is already associated with the account, we don't need + // to associate it again. + _, ok := account.Payments[paymentHash] + if ok { + return nil + } + + // Associate the payment with the account and store it. + account.Payments[paymentHash] = &PaymentEntry{ + Status: lnrpc.Payment_UNKNOWN, + FullAmount: fullAmt, + } + + if err := s.store.UpdateAccount(account); err != nil { + return fmt.Errorf("error updating account: %w", err) + } + + return nil +} + // invoiceUpdate credits the account an invoice was registered with, in case the // invoice was settled. +// +// NOTE: Any code that errors in this function MUST call disableAndErrorfUnsafe +// while the store lock is held to ensure that the service is disabled under +// the same lock. Else we risk that other threads will try to update invoices +// while the service should be disabled, which could lead to us missing invoice +// updates on next startup. func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { s.Lock() defer s.Unlock() + // As this function updates account balances, and is called from the + // invoice subscription, we ensure that the service is running before we + // execute it. + if !s.isRunningUnsafe() { + // We will process the invoice update on next startup instead, + // once the error that caused the service to stop has been + // resolved. + return ErrAccountServiceDisabled + } + // We update our indexes each time we get a new invoice from our // subscription. This might be a bit inefficient but makes sure we don't // miss an update. @@ -386,7 +508,9 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { s.currentAddIndex, s.currentSettleIndex, ) if err != nil { - return err + return s.disableAndErrorfUnsafe( + "error storing last indexes: %w", err, + ) } } @@ -405,7 +529,9 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { account, err := s.store.Account(acctID) if err != nil { - return fmt.Errorf("error fetching account: %w", err) + return s.disableAndErrorfUnsafe( + "error fetching account: %w", err, + ) } // If we get here, the current account has the invoice associated with @@ -413,7 +539,9 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { // in the DB. account.CurrentBalance += int64(invoice.AmountPaid) if err := s.store.UpdateAccount(account); err != nil { - return fmt.Errorf("error updating account: %w", err) + return s.disableAndErrorfUnsafe( + "error updating account: %w", err, + ) } // We've now fully processed the invoice and don't need to keep it @@ -451,16 +579,45 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, return nil } - // Okay, we haven't tracked this payment before. So let's now associate - // the account with it. account.Payments[hash] = &PaymentEntry{ Status: lnrpc.Payment_UNKNOWN, FullAmount: fullAmt, } + if err := s.store.UpdateAccount(account); err != nil { + if !ok { + // In the rare case that the payment isn't associated + // with an account yet, and we fail to update the + // account we will not be tracking the payment, even if + // track the service is restarted. Therefore the node + // runner needs to manually check if the payment was + // made and debit the account if that's the case. + errStr := "critical error: failed to store the " + + "payment with hash %v for user with account " + + "id %v. Manual intervention required! " + + "Verify if the payment was executed, and " + + "manually update the user account balance by " + + "subtracting the payment amount if it was" + + mainChanErr := s.disableAndErrorfUnsafe( + errStr, hash, id, + ) + + s.mainErrCallback(mainChanErr) + } + return fmt.Errorf("error updating account: %w", err) } + // As this function updates account balances, we ensure that the service + // is running before we execute it. + if !s.isRunningUnsafe() { + // We will track the payment on next on next startup instead, + // once the error that caused the service to stop has been + // resolved. + return ErrAccountServiceDisabled + } + // And start the long-running TrackPayment RPC. ctxc, cancel := context.WithCancel(s.mainCtx) statusChan, errChan, err := s.routerClient.TrackPayment(ctxc, hash) @@ -490,11 +647,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, hash, paymentUpdate, ) if err != nil { - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) return } @@ -516,15 +669,14 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, return } - log.Errorf("Received error from TrackPayment "+ - "RPC for payment %v: %v", hash, err) - if err != nil { - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + // If we error when tracking the + // payment, we stop the service. + err = s.disableAndErrorf("received "+ + "error from TrackPayment RPC "+ + "for payment %v: %w", hash, err) + + s.mainErrCallback(err) } return @@ -544,6 +696,10 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // associated with, in case it is settled. The boolean value returned indicates // whether the status was terminal or not. If it's not terminal then further // updates are expected. +// +// NOTE: Any code that errors in this function MUST call disableAndErrorfUnsafe +// while the store lock is held to ensure that the service is disabled under +// the same lock. func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, status lndclient.PaymentStatus) (bool, error) { @@ -563,21 +719,40 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, s.Lock() defer s.Unlock() + // As this function updates account balances, we ensure that the service + // is running before we execute it. + if !s.isRunningUnsafe() { + // We will update the payment on next startup instead, once the + // error that caused the service to stop has been resolved. + return false, ErrAccountServiceDisabled + } + pendingPayment, ok := s.pendingPayments[hash] if !ok { - return terminalState, fmt.Errorf("payment %x not mapped to "+ - "any account", hash[:]) + err := s.disableAndErrorfUnsafe("payment %x not mapped to any "+ + "account", hash[:]) + + return terminalState, err } // A failed payment can just be removed, no further action needed. if status.State == lnrpc.Payment_FAILED { - return terminalState, s.removePayment(hash, status.State) + err := s.removePayment(hash, status.State) + if err != nil { + err = s.disableAndErrorfUnsafe("error removing "+ + "payment: %w", err) + } + + return terminalState, err } // The payment went through! We now need to debit the full amount from // the account. account, err := s.store.Account(pendingPayment.accountID) if err != nil { + err = s.disableAndErrorfUnsafe("error fetching account: %w", + err) + return terminalState, err } @@ -590,13 +765,21 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, FullAmount: fullAmount, } if err := s.store.UpdateAccount(account); err != nil { - return terminalState, fmt.Errorf("error updating account: %w", + err = s.disableAndErrorfUnsafe("error updating account: %w", err) + + return terminalState, err } // We've now fully processed the payment and don't need to keep it // mapped or tracked anymore. - return terminalState, s.removePayment(hash, lnrpc.Payment_SUCCEEDED) + err = s.removePayment(hash, lnrpc.Payment_SUCCEEDED) + if err != nil { + err = s.disableAndErrorfUnsafe("error removing payment: %w", + err) + } + + return terminalState, err } // RemovePayment removes a failed payment from the service because it no longer diff --git a/accounts/service_test.go b/accounts/service_test.go index 3b0b604ae..f1cf120d8 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -32,10 +32,12 @@ type mockLnd struct { invoiceReq chan lndclient.InvoiceSubscriptionRequest paymentReq chan lntypes.Hash - callErr error - errChan chan error - invoiceChan chan *lndclient.Invoice - paymentChans map[lntypes.Hash]chan lndclient.PaymentStatus + invoiceSubscriptionErr error + trackPaymentErr error + invoiceErrChan chan error + paymentErrChan chan error + invoiceChan chan *lndclient.Invoice + paymentChans map[lntypes.Hash]chan lndclient.PaymentStatus } func newMockLnd() *mockLnd { @@ -44,9 +46,10 @@ func newMockLnd() *mockLnd { invoiceReq: make( chan lndclient.InvoiceSubscriptionRequest, 10, ), - paymentReq: make(chan lntypes.Hash, 10), - errChan: make(chan error, 10), - invoiceChan: make(chan *lndclient.Invoice), + paymentReq: make(chan lntypes.Hash, 10), + invoiceErrChan: make(chan error, 10), + paymentErrChan: make(chan error, 10), + invoiceChan: make(chan *lndclient.Invoice), paymentChans: make( map[lntypes.Hash]chan lndclient.PaymentStatus, ), @@ -72,6 +75,18 @@ func (m *mockLnd) assertMainErr(t *testing.T, expectedErr error) { } } +// assertMainErrContains asserts that the main error contains the expected error +// string. +func (m *mockLnd) assertMainErrContains(t *testing.T, expectedStr string) { + select { + case err := <-m.mainErrChan: + require.ErrorContains(t, err, expectedStr) + + case <-time.After(testTimeout): + t.Fatalf("Did not get expected main err before timeout") + } +} + func (m *mockLnd) assertNoInvoiceRequest(t *testing.T) { select { case req := <-m.invoiceReq: @@ -132,13 +147,13 @@ func (m *mockLnd) SubscribeInvoices(_ context.Context, req lndclient.InvoiceSubscriptionRequest) (<-chan *lndclient.Invoice, <-chan error, error) { - if m.callErr != nil { - return nil, nil, m.callErr + if m.invoiceSubscriptionErr != nil { + return nil, nil, m.invoiceSubscriptionErr } m.invoiceReq <- req - return m.invoiceChan, m.errChan, nil + return m.invoiceChan, m.invoiceErrChan, nil } // TrackPayment picks up a previously started payment and returns a payment @@ -146,14 +161,14 @@ func (m *mockLnd) SubscribeInvoices(_ context.Context, func (m *mockLnd) TrackPayment(_ context.Context, hash lntypes.Hash) (chan lndclient.PaymentStatus, chan error, error) { - if m.callErr != nil { - return nil, nil, m.callErr + if m.trackPaymentErr != nil { + return nil, nil, m.trackPaymentErr } m.paymentReq <- hash m.paymentChans[hash] = make(chan lndclient.PaymentStatus, 1) - return m.paymentChans[hash], m.errChan, nil + return m.paymentChans[hash], m.paymentErrChan, nil } // TestAccountService tests that the account service can track payments and @@ -169,15 +184,92 @@ func TestAccountService(t *testing.T) { validate func(t *testing.T, lnd *mockLnd, s *InterceptorService) }{{ - name: "startup err on tracking payment", + name: "startup err on invoice subscription", setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { - lnd.callErr = testErr + lnd.invoiceSubscriptionErr = testErr }, startupErr: testErr.Error(), validate: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { lnd.assertNoInvoiceRequest(t) + require.False(t, s.IsRunning()) + }, + }, { + name: "err on invoice update", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Invoices: AccountInvoices{ + testHash: {}, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + + // Start by closing the store. This should cause an + // error once we make an invoice update, as the service + // will fail when persisting the invoice update. + s.store.Close() + + // Ensure that the service was started successfully and + // still running though, despite the closing of the + // db store. + require.True(t, s.IsRunning()) + + // Now let's send the invoice update, which should fail. + lnd.invoiceChan <- &lndclient.Invoice{ + AddIndex: 12, + SettleIndex: 12, + Hash: testHash, + AmountPaid: 777, + State: invpkg.ContractSettled, + } + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + lnd.assertMainErrContains(t, "database not open") + }, + }, { + name: "err in invoice err channel", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Invoices: AccountInvoices{ + testHash: {}, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + // Ensure that the service was started successfully. + require.True(t, s.IsRunning()) + + // Now let's send an error over the invoice error + // channel. This should disable the service. + lnd.invoiceErrChan <- testErr + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + + lnd.assertMainErrContains(t, testErr.Error()) }, }, { name: "goroutine err sent on main err chan", @@ -195,13 +287,13 @@ func TestAccountService(t *testing.T) { err := s.store.UpdateAccount(acct) require.NoError(t, err) - lnd.errChan <- testErr + s.mainErrCallback(testErr) }, validate: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { lnd.assertInvoiceRequest(t, 0, 0) - lnd.assertMainErr(t, testErr) + lnd.assertMainErrContains(t, testErr.Error()) }, }, { name: "startup do not track completed payments", @@ -227,6 +319,135 @@ func TestAccountService(t *testing.T) { lnd.assertNoPaymentRequest(t) lnd.assertInvoiceRequest(t, 0, 0) lnd.assertNoMainErr(t) + require.True(t, s.IsRunning()) + }, + }, { + name: "startup err on payment tracking", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Invoices: AccountInvoices{ + testHash: {}, + }, + Payments: AccountPayments{ + testHash: { + Status: lnrpc.Payment_IN_FLIGHT, + FullAmount: 1234, + }, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + + lnd.trackPaymentErr = testErr + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + + // Assert that the invoice subscription succeeded. + require.Contains(t, s.invoiceToAccount, testHash) + + // But setting up the payment tracking should have failed. + require.False(t, s.IsRunning()) + + // Finally let's assert that we didn't successfully add the + // payment to pending payment, and that lnd isn't awaiting + // the payment request. + require.NotContains(t, s.pendingPayments, testHash) + lnd.assertNoPaymentRequest(t) + }, + }, { + name: "err on payment update", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Payments: AccountPayments{ + testHash: { + Status: lnrpc.Payment_IN_FLIGHT, + FullAmount: 1234, + }, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + // Ensure that the service was started successfully, + // and lnd contains the payment request. + require.True(t, s.IsRunning()) + lnd.assertPaymentRequests(t, map[lntypes.Hash]struct{}{ + testHash: {}, + }) + + // Now let's wipe the service's pending payments. + // This will cause an error send an update over + // the payment channel, which should disable the + // service. + s.pendingPayments = make(map[lntypes.Hash]*trackedPayment) + + // Send an invalid payment over the payment chan + // which should error and disable the service + lnd.paymentChans[testHash] <- lndclient.PaymentStatus{ + State: lnrpc.Payment_SUCCEEDED, + Fee: 234, + Value: 1000, + } + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + lnd.assertMainErrContains( + t, "not mapped to any account", + ) + + }, + }, { + name: "err in payment update chan", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Payments: AccountPayments{ + testHash: { + Status: lnrpc.Payment_IN_FLIGHT, + FullAmount: 1234, + }, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + // Ensure that the service was started successfully, + // and lnd contains the payment request. + require.True(t, s.IsRunning()) + lnd.assertPaymentRequests(t, map[lntypes.Hash]struct{}{ + testHash: {}, + }) + + // Now let's send an error over the payment error + // channel. This should disable the service. + lnd.paymentErrChan <- testErr + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + + lnd.assertMainErrContains(t, testErr.Error()) }, }, { name: "startup track in-flight payments", @@ -451,9 +672,10 @@ func TestAccountService(t *testing.T) { tt.Parallel() lndMock := newMockLnd() - service, err := NewService( - t.TempDir(), lndMock.mainErrChan, - ) + errFunc := func(err error) { + lndMock.mainErrChan <- err + } + service, err := NewService(t.TempDir(), errFunc) require.NoError(t, err) // Is a setup call required to initialize initial diff --git a/config.go b/config.go index 107e2a4f8..65b9f54c5 100644 --- a/config.go +++ b/config.go @@ -17,6 +17,7 @@ import ( "github.com/lightninglabs/faraday" "github.com/lightninglabs/faraday/chain" "github.com/lightninglabs/faraday/frdrpcserver" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/autopilotserver" "github.com/lightninglabs/lightning-terminal/firewall" mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware" @@ -205,6 +206,8 @@ type Config struct { Firewall *firewall.Config `group:"Firewall options" namespace:"firewall"` + Accounts *accounts.Config `group:"Accounts options" namespace:"accounts"` + // faradayRpcConfig is a subset of faraday's full configuration that is // passed into faraday's RPC server. faradayRpcConfig *frdrpcserver.Config @@ -320,6 +323,7 @@ func defaultConfig() *Config { PingCadence: time.Hour, }, Firewall: firewall.DefaultConfig(), + Accounts: &accounts.Config{}, } } diff --git a/itest/litd_mode_integrated_test.go b/itest/litd_mode_integrated_test.go index 2d5aeecdf..d75433619 100644 --- a/itest/litd_mode_integrated_test.go +++ b/itest/litd_mode_integrated_test.go @@ -213,10 +213,21 @@ var ( } endpoints = []struct { - name string - macaroonFn macaroonFn - requestFn requestFn - successPattern string + name string + macaroonFn macaroonFn + requestFn requestFn + successPattern string + + // disabledPattern represents a substring that is expected to be + // part of the error returned when a gRPC request is made to the + // disabled endpoint. + // TODO: once we have a subsystem manager, we can unify the + // returned for disabled endpoints for both subsystems and + // subservers by not registering the subsystem URIs to the + // permsMgr if it has been disabled. This field will then be + // unnecessary and can be removed. + disabledPattern string + allowedThroughLNC bool grpcWebURI string restWebURI string @@ -269,6 +280,7 @@ var ( macaroonFn: faradayMacaroonFn, requestFn: faradayRequestFn, successPattern: "\"reports\":[]", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/frdrpc.FaradayServer/RevenueReport", restWebURI: "/v1/faraday/revenue", @@ -278,6 +290,7 @@ var ( macaroonFn: loopMacaroonFn, requestFn: loopRequestFn, successPattern: "\"swaps\":[]", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/looprpc.SwapClient/ListSwaps", restWebURI: "/v1/loop/swaps", @@ -287,6 +300,7 @@ var ( macaroonFn: poolMacaroonFn, requestFn: poolRequestFn, successPattern: "\"accounts_active\":0", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/poolrpc.Trader/GetInfo", restWebURI: "/v1/pool/info", @@ -296,6 +310,7 @@ var ( macaroonFn: tapMacaroonFn, requestFn: tapRequestFn, successPattern: "\"assets\":[]", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/taprpc.TaprootAssets/ListAssets", restWebURI: "/v1/taproot-assets/assets", @@ -305,6 +320,7 @@ var ( macaroonFn: emptyMacaroonFn, requestFn: tapUniverseRequestFn, successPattern: "\"num_assets\":", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/universerpc.Universe/Info", restWebURI: "/v1/taproot-assets/universe/info", @@ -326,9 +342,11 @@ var ( macaroonFn: litMacaroonFn, requestFn: litAccountRequestFn, successPattern: "\"accounts\":[", + disabledPattern: "accounts has been disabled", allowedThroughLNC: false, grpcWebURI: "/litrpc.Accounts/ListAccounts", restWebURI: "/v1/accounts", + canDisable: true, }, { name: "litrpc-autopilot", macaroonFn: litMacaroonFn, @@ -384,6 +402,7 @@ func testDisablingSubServers(ctx context.Context, net *NetworkHarness, WithLitArg("loop-mode", "disable"), WithLitArg("pool-mode", "disable"), WithLitArg("faraday-mode", "disable"), + WithLitArg("accounts.disable", ""), }, ) require.NoError(t, err) @@ -494,7 +513,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointDisabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -532,7 +551,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, shouldFailWithoutMacaroon, endpoint.successPattern, endpointDisabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -557,7 +576,8 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, ttt, cfg.LitAddr(), cfg.UIPassword, endpoint.grpcWebURI, withoutUIPassword, endpointDisabled, - "unknown request", endpoint.noAuth, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -596,7 +616,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointDisabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -649,7 +669,9 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, endpoint.allowedThroughLNC, "unknown service", - endpointDisabled, endpoint.noAuth, + endpointDisabled, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -658,6 +680,12 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, t.Run("gRPC super macaroon account system test", func(tt *testing.T) { cfg := net.Alice.Cfg + // If the accounts service is disabled, we skip this test as it + // will fail due to the accounts service being disabled. + if subServersDisabled { + return + } + superMacFile, err := bakeSuperMacaroon(cfg, false) require.NoError(tt, err) @@ -722,6 +750,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, allowed, expectedErr, endpointDisabled, + endpoint.disabledPattern, endpoint.noAuth, ) }) @@ -1169,7 +1198,8 @@ func runRESTAuthTest(t *testing.T, hostPort, uiPassword, macaroonPath, restURI, // through Lightning Node Connect. func runLNCAuthTest(t *testing.T, rawLNCConn grpc.ClientConnInterface, makeRequest requestFn, successContent string, callAllowed bool, - expectErrContains string, disabled, noMac bool) { + expectErrContains string, disabled bool, disabledPattern string, + noMac bool) { ctxt, cancel := context.WithTimeout( context.Background(), defaultTimeout, @@ -1186,7 +1216,7 @@ func runLNCAuthTest(t *testing.T, rawLNCConn grpc.ClientConnInterface, // The call should be allowed, so we expect no error unless this is // for a disabled sub-server. case disabled: - require.ErrorContains(t, err, "unknown request") + require.ErrorContains(t, err, disabledPattern) return case noMac: diff --git a/itest/litd_mode_remote_test.go b/itest/litd_mode_remote_test.go index 41f34c13f..67e0aa14f 100644 --- a/itest/litd_mode_remote_test.go +++ b/itest/litd_mode_remote_test.go @@ -67,7 +67,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointEnabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -94,7 +94,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, shouldFailWithoutMacaroon, endpoint.successPattern, endpointEnabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -117,7 +117,8 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, ttt, cfg.LitAddr(), cfg.UIPassword, endpoint.grpcWebURI, withoutUIPassword, endpointEnabled, - "unknown request", endpoint.noAuth, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -145,7 +146,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointEnabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -197,7 +198,9 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, endpoint.allowedThroughLNC, "unknown service", - endpointDisabled, endpoint.noAuth, + endpointDisabled, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -248,6 +251,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, allowed, expectedErr, endpointDisabled, + endpoint.disabledPattern, endpoint.noAuth, ) }) @@ -257,6 +261,12 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, t.Run("gRPC super macaroon account system test", func(tt *testing.T) { cfg := net.Bob.Cfg + // If the accounts service is disabled, we skip this test as it + // will fail due to the accounts service being disabled. + if subServersDisabled { + return + } + superMacFile, err := bakeSuperMacaroon(cfg, false) require.NoError(tt, err) diff --git a/rpc_proxy.go b/rpc_proxy.go index 798f8592e..f1be75934 100644 --- a/rpc_proxy.go +++ b/rpc_proxy.go @@ -623,6 +623,9 @@ func (p *rpcProxy) checkSubSystemStarted(requestURI string) error { switch { case handled: + case isAccountsReq(requestURI): + system = subservers.ACCOUNTS + case p.permsMgr.IsSubServerURI(subservers.LIT, requestURI): system = subservers.LIT @@ -694,3 +697,13 @@ func isProxyReq(uri string) bool { uri, fmt.Sprintf("/%s", litrpc.Proxy_ServiceDesc.ServiceName), ) } + +// isAccountsReq returns true if the given request is intended for the +// litrpc.Accounts service. +func isAccountsReq(uri string) bool { + return strings.HasPrefix( + uri, fmt.Sprintf( + "/%s", litrpc.Accounts_ServiceDesc.ServiceName, + ), + ) +} diff --git a/subservers/subserver.go b/subservers/subserver.go index 8605eb422..82ba22292 100644 --- a/subservers/subserver.go +++ b/subservers/subserver.go @@ -11,12 +11,13 @@ import ( ) const ( - LND string = "lnd" - LIT string = "lit" - LOOP string = "loop" - POOL string = "pool" - TAP string = "taproot-assets" - FARADAY string = "faraday" + LND string = "lnd" + LIT string = "lit" + LOOP string = "loop" + POOL string = "pool" + TAP string = "taproot-assets" + FARADAY string = "faraday" + ACCOUNTS string = "accounts" ) // subServerWrapper is a wrapper around the SubServer interface and is used by diff --git a/terminal.go b/terminal.go index a4ac4afdf..f0c1e9350 100644 --- a/terminal.go +++ b/terminal.go @@ -232,9 +232,15 @@ func (g *LightningTerminal) Run() error { return fmt.Errorf("could not create permissions manager") } - // Register LND and LiT with the status manager. + // Register LND, LiT and Accounts with the status manager. g.statusMgr.RegisterAndEnableSubServer(subservers.LND) g.statusMgr.RegisterAndEnableSubServer(subservers.LIT) + g.statusMgr.RegisterSubServer(subservers.ACCOUNTS) + + // Also enable the accounts subserver if it's not disabled. + if !g.cfg.Accounts.Disable { + g.statusMgr.SetEnabled(subservers.ACCOUNTS) + } // Create the instances of our subservers now so we can hook them up to // lnd once it's fully started. @@ -305,8 +311,19 @@ func (g *LightningTerminal) Run() error { func (g *LightningTerminal) start() error { var err error + accountServiceErrCallback := func(err error) { + g.statusMgr.SetErrored( + subservers.ACCOUNTS, + err.Error(), + ) + + log.Errorf("Error thrown in the accounts service, keeping "+ + "litd running: %v", err, + ) + } + g.accountService, err = accounts.NewService( - filepath.Dir(g.cfg.MacaroonPath), g.errQueue.ChanIn(), + filepath.Dir(g.cfg.MacaroonPath), accountServiceErrCallback, ) if err != nil { return fmt.Errorf("error creating account service: %v", err) @@ -837,16 +854,41 @@ func (g *LightningTerminal) startInternalSubServers( return nil } + // Even if the accounts service fails on the Start function, or the + // accounts service is disabled, we still want to call Stop function as + // this closes the contexts and the db store which were opened with the + // accounts.NewService function call in the LightningTerminal start + // function above. + closeAccountService := func() { + if err := g.accountService.Stop(); err != nil { + // We only log the error if we fail to stop the service, + // as it's not critical that this succeeds in order to + // keep litd running + log.Errorf("Error stopping account service: %v", err) + } + } + log.Infof("Starting LiT account service") - err = g.accountService.Start( - g.lndClient.Client, g.lndClient.Router, - g.lndClient.ChainParams, - ) - if err != nil { - return fmt.Errorf("error starting account service: %v", - err) + if !g.cfg.Accounts.Disable { + err = g.accountService.Start( + g.lndClient.Client, g.lndClient.Router, + g.lndClient.ChainParams, + ) + if err != nil { + log.Errorf("error starting account service: %v, "+ + "disabling account service", err) + + g.statusMgr.SetErrored(subservers.ACCOUNTS, err.Error()) + + closeAccountService() + } else { + g.statusMgr.SetRunning(subservers.ACCOUNTS) + + g.accountServiceStarted = true + } + } else { + closeAccountService() } - g.accountServiceStarted = true requestLogger, err := firewall.NewRequestLogger( g.cfg.Firewall.RequestLogger, g.firewallDB, @@ -933,7 +975,12 @@ func (g *LightningTerminal) registerSubDaemonGrpcServers(server *grpc.Server, litrpc.RegisterStatusServer(server, g.statusMgr) } else { litrpc.RegisterSessionsServer(server, g.sessionRpcServer) - litrpc.RegisterAccountsServer(server, g.accountRpcServer) + + if !g.cfg.Accounts.Disable { + litrpc.RegisterAccountsServer( + server, g.accountRpcServer, + ) + } } litrpc.RegisterFirewallServer(server, g.sessionRpcServer) @@ -960,11 +1007,13 @@ func (g *LightningTerminal) RegisterRestSubserver(ctx context.Context, return err } - err = litrpc.RegisterAccountsHandlerFromEndpoint( - ctx, mux, endpoint, dialOpts, - ) - if err != nil { - return err + if !g.cfg.Accounts.Disable { + err = litrpc.RegisterAccountsHandlerFromEndpoint( + ctx, mux, endpoint, dialOpts, + ) + if err != nil { + return err + } } err = litrpc.RegisterFirewallHandlerFromEndpoint(