diff --git a/aperture.go b/aperture.go index 2463eb5..6878ea7 100644 --- a/aperture.go +++ b/aperture.go @@ -112,7 +112,7 @@ func run() error { } errChan := make(chan error) challenger, err := NewLndChallenger( - cfg.Authenticator, genInvoiceReq, errChan, + cfg.Authenticator, genInvoiceReq, nil, errChan, ) if err != nil { return err @@ -452,7 +452,7 @@ func createProxy(cfg *config, challenger *LndChallenger, minter := mint.New(&mint.Config{ Challenger: challenger, Secrets: newSecretStore(etcdClient), - ServiceLimiter: newStaticServiceLimiter(cfg.Services), + ServiceLimiter: NewStaticServiceLimiter(cfg.Services), }) authenticator := auth.NewLsatAuthenticator(minter, challenger) return proxy.New( diff --git a/challenger.go b/challenger.go index f35a6ad..9c4ecd6 100644 --- a/challenger.go +++ b/challenger.go @@ -21,6 +21,9 @@ import ( // lnrpc.AddInvoice call. type InvoiceRequestGenerator func(price int64) (*lnrpc.Invoice, error) +type VerifyInvoiceStatusFunc func(hash lntypes.Hash, + state lnrpc.Invoice_InvoiceState, timeout time.Duration) error + // InvoiceClient is an interface that only implements part of a full lnd client, // namely the part around the invoices we need for the challenger to work. type InvoiceClient interface { @@ -41,8 +44,9 @@ type InvoiceClient interface { // LndChallenger is a challenger that uses an lnd backend to create new LSAT // payment challenges. type LndChallenger struct { - client InvoiceClient - genInvoiceReq InvoiceRequestGenerator + Client InvoiceClient + GenInvoiceReq InvoiceRequestGenerator + VerifyInvoiceStatusFunc VerifyInvoiceStatusFunc invoiceStates map[lntypes.Hash]lnrpc.Invoice_InvoiceState invoicesMtx *sync.Mutex @@ -69,7 +73,7 @@ const ( // NewLndChallenger creates a new challenger that uses the given connection // details to connect to an lnd backend to create payment challenges. func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator, - errChan chan<- error) (*LndChallenger, error) { + verifyInvoiceStatus VerifyInvoiceStatusFunc, errChan chan<- error) (*LndChallenger, error) { if genInvoiceReq == nil { return nil, fmt.Errorf("genInvoiceReq cannot be nil") @@ -84,15 +88,20 @@ func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator, } invoicesMtx := &sync.Mutex{} - return &LndChallenger{ - client: client, - genInvoiceReq: genInvoiceReq, + c := &LndChallenger{ + Client: client, + GenInvoiceReq: genInvoiceReq, invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), invoicesMtx: invoicesMtx, invoicesCond: sync.NewCond(invoicesMtx), quit: make(chan struct{}), errChan: errChan, - }, nil + } + if verifyInvoiceStatus == nil { + c.VerifyInvoiceStatusFunc = c.DefaultVerifyInvoiceStatus + } + + return c, nil } // Start starts the challenger's main work which is to keep track of all @@ -111,7 +120,7 @@ func (l *LndChallenger) Start() error { // cache. We need to keep track of all invoices, even quite old ones to // make sure tokens are valid. But to save space we only keep track of // an invoice's state. - invoiceResp, err := l.client.ListInvoices( + invoiceResp, err := l.Client.ListInvoices( context.Background(), &lnrpc.ListInvoiceRequest{ NumMaxInvoices: math.MaxUint64, }, @@ -148,7 +157,7 @@ func (l *LndChallenger) Start() error { ctxc, cancel := context.WithCancel(context.Background()) l.invoicesCancel = cancel - subscriptionResp, err := l.client.SubscribeInvoices( + subscriptionResp, err := l.Client.SubscribeInvoices( ctxc, &lnrpc.InvoiceSubscription{ AddIndex: addIndex, SettleIndex: settleIndex, @@ -264,13 +273,13 @@ func (l *LndChallenger) Stop() { func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, error) { // Obtain a new invoice from lnd first. We need to know the payment hash // so we can add it as a caveat to the macaroon. - invoice, err := l.genInvoiceReq(price) + invoice, err := l.GenInvoiceReq(price) if err != nil { log.Errorf("Error generating invoice request: %v", err) return "", lntypes.ZeroHash, err } ctx := context.Background() - response, err := l.client.AddInvoice(ctx, invoice) + response, err := l.Client.AddInvoice(ctx, invoice) if err != nil { log.Errorf("Error adding invoice: %v", err) return "", lntypes.ZeroHash, err @@ -285,14 +294,24 @@ func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, error) } // VerifyInvoiceStatus checks that an invoice identified by a payment -// hash has the desired status. To make sure we don't fail while the -// invoice update is still on its way, we try several times until either -// the desired status is set or the given timeout is reached. +// hash has the desired status. An optional invoice checker which could +// be customized by implementer or using the default implementation +// `DefaultVerifyInvoiceStatus`. // // NOTE: This is part of the auth.InvoiceChecker interface. func (l *LndChallenger) VerifyInvoiceStatus(hash lntypes.Hash, state lnrpc.Invoice_InvoiceState, timeout time.Duration) error { + return l.VerifyInvoiceStatusFunc(hash, state, timeout) +} + +// DefaultVerifyInvoiceStatus checks that an invoice identified by a payment +// hash has the desired status. To make sure we don't fail while the +// invoice update is still on its way, we try several times until either +// the desired status is set or the given timeout is reached. +func (l *LndChallenger) DefaultVerifyInvoiceStatus(hash lntypes.Hash, + state lnrpc.Invoice_InvoiceState, timeout time.Duration) error { + // Prevent the challenger to be shut down while we're still waiting for // status updates. l.wg.Add(1) diff --git a/challenger_test.go b/challenger_test.go index f0bf7e0..49ee57f 100644 --- a/challenger_test.go +++ b/challenger_test.go @@ -100,15 +100,18 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) { } invoicesMtx := &sync.Mutex{} mainErrChan := make(chan error) - return &LndChallenger{ - client: mockClient, - genInvoiceReq: genInvoiceReq, + c := &LndChallenger{ + Client: mockClient, + GenInvoiceReq: genInvoiceReq, invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), quit: make(chan struct{}), invoicesMtx: invoicesMtx, invoicesCond: sync.NewCond(invoicesMtx), errChan: mainErrChan, - }, mockClient, mainErrChan + } + c.VerifyInvoiceStatusFunc = c.DefaultVerifyInvoiceStatus + + return c, mockClient, mainErrChan } func newInvoice(hash lntypes.Hash, addIndex uint64, @@ -130,7 +133,7 @@ func TestLndChallenger(t *testing.T) { // First of all, test that the NewLndChallenger doesn't allow a nil // invoice generator function. errChan := make(chan error) - _, err := NewLndChallenger(nil, nil, errChan) + _, err := NewLndChallenger(nil, nil, nil, errChan) require.Error(t, err) // Now mock the lnd backend and create a challenger instance that we can diff --git a/services.go b/services.go index 0f5a51d..7701b06 100644 --- a/services.go +++ b/services.go @@ -8,21 +8,21 @@ import ( "github.com/lightninglabs/aperture/proxy" ) -// staticServiceLimiter provides static restrictions for services. +// StaticServiceLimiter provides static restrictions for services. // // TODO(wilmer): use etcd instead. -type staticServiceLimiter struct { - capabilities map[lsat.Service]lsat.Caveat - constraints map[lsat.Service][]lsat.Caveat +type StaticServiceLimiter struct { + Capabilities map[lsat.Service]lsat.Caveat + Constraints map[lsat.Service][]lsat.Caveat } -// A compile-time constraint to ensure staticServiceLimiter implements +// A compile-time constraint to ensure StaticServiceLimiter implements // mint.ServiceLimiter. -var _ mint.ServiceLimiter = (*staticServiceLimiter)(nil) +var _ mint.ServiceLimiter = (*StaticServiceLimiter)(nil) -// newStaticServiceLimiter instantiates a new static service limiter backed by +// NewStaticServiceLimiter instantiates a new static service limiter backed by // the given restrictions. -func newStaticServiceLimiter(proxyServices []*proxy.Service) *staticServiceLimiter { +func NewStaticServiceLimiter(proxyServices []*proxy.Service) *StaticServiceLimiter { capabilities := make(map[lsat.Service]lsat.Caveat) constraints := make(map[lsat.Service][]lsat.Caveat) @@ -41,20 +41,20 @@ func newStaticServiceLimiter(proxyServices []*proxy.Service) *staticServiceLimit } } - return &staticServiceLimiter{ - capabilities: capabilities, - constraints: constraints, + return &StaticServiceLimiter{ + Capabilities: capabilities, + Constraints: constraints, } } // ServiceCapabilities returns the capabilities caveats for each service. This // determines which capabilities of each service can be accessed. -func (l *staticServiceLimiter) ServiceCapabilities(ctx context.Context, +func (l *StaticServiceLimiter) ServiceCapabilities(ctx context.Context, services ...lsat.Service) ([]lsat.Caveat, error) { res := make([]lsat.Caveat, 0, len(services)) for _, service := range services { - capabilities, ok := l.capabilities[service] + capabilities, ok := l.Capabilities[service] if !ok { continue } @@ -66,12 +66,12 @@ func (l *staticServiceLimiter) ServiceCapabilities(ctx context.Context, // ServiceConstraints returns the constraints for each service. This enforces // additional constraints on a particular service/service capability. -func (l *staticServiceLimiter) ServiceConstraints(ctx context.Context, +func (l *StaticServiceLimiter) ServiceConstraints(ctx context.Context, services ...lsat.Service) ([]lsat.Caveat, error) { res := make([]lsat.Caveat, 0, len(services)) for _, service := range services { - constraints, ok := l.constraints[service] + constraints, ok := l.Constraints[service] if !ok { continue }