diff --git a/macaroon_recipes.go b/macaroon_recipes.go index 6a83fef..9551ece 100644 --- a/macaroon_recipes.go +++ b/macaroon_recipes.go @@ -27,23 +27,24 @@ var ( // implemented in lndclient and the value is the original name of the // RPC method defined in the proto. renames = map[string]string{ - "ChannelBackup": "ExportChannelBackup", - "ChannelBackups": "ExportAllChannelBackups", - "ConfirmedWalletBalance": "WalletBalance", - "Connect": "ConnectPeer", - "DecodePaymentRequest": "DecodePayReq", - "ListTransactions": "GetTransactions", - "PayInvoice": "SendPaymentSync", - "UpdateChanPolicy": "UpdateChannelPolicy", - "NetworkInfo": "GetNetworkInfo", - "SubscribeGraph": "SubscribeChannelGraph", - "InterceptHtlcs": "HtlcInterceptor", - "ImportMissionControl": "XImportMissionControl", - "EstimateFeeRate": "EstimateFee", - "EstimateFeeToP2WSH": "EstimateFee", - "OpenChannelStream": "OpenChannel", - "ListSweepsVerbose": "ListSweeps", - "MinRelayFee": "EstimateFee", + "ChannelBackup": "ExportChannelBackup", + "ChannelBackups": "ExportAllChannelBackups", + "ConfirmedWalletBalance": "WalletBalance", + "Connect": "ConnectPeer", + "DecodePaymentRequest": "DecodePayReq", + "ListTransactions": "GetTransactions", + "PayInvoice": "SendPaymentSync", + "UpdateChanPolicy": "UpdateChannelPolicy", + "NetworkInfo": "GetNetworkInfo", + "SubscribeGraph": "SubscribeChannelGraph", + "InterceptHtlcs": "HtlcInterceptor", + "ImportMissionControl": "XImportMissionControl", + "EstimateFeeRate": "EstimateFee", + "EstimateFeeToP2WSH": "EstimateFee", + "OpenChannelStream": "OpenChannel", + "ListSweepsVerbose": "ListSweeps", + "MinRelayFee": "EstimateFee", + "SignOutputRawKeyLocator": "SignOutputRaw", } // ignores is a list of method names on the client implementations that diff --git a/signer_client.go b/signer_client.go index 7fa0b84..d7de4a7 100644 --- a/signer_client.go +++ b/signer_client.go @@ -28,6 +28,16 @@ type SignerClient interface { signDescriptors []*SignDescriptor, prevOutputs []*wire.TxOut) ([][]byte, error) + // SignOutputRawKeyLocator is a copy of the SignOutputRaw that fixes a + // specific issue around how the key locator is populated in the sign + // descriptor. We copy this method instead of fixing the original to + // make sure we don't break any existing applications that have already + // adjusted themselves to use the specific behavior of the original + // SignOutputRaw method. + SignOutputRawKeyLocator(ctx context.Context, tx *wire.MsgTx, + signDescriptors []*SignDescriptor, + prevOutputs []*wire.TxOut) ([][]byte, error) + // ComputeInputScript generates the proper input script for P2WPKH // output and NP2WPKH outputs. This method only requires that the // `Output`, `HashType`, `SigHashes` and `InputIndex` fields are @@ -215,26 +225,70 @@ func (s *signerClient) RawClientWithMacAuth( return s.signerMac.WithMacaroonAuth(parentCtx), s.timeout, s.client } -func marshallSignDescriptors( - signDescriptors []*SignDescriptor) []*signrpc.SignDescriptor { +func marshallSignDescriptors(signDescriptors []*SignDescriptor, + fullDescriptors bool) []*signrpc.SignDescriptor { - rpcSignDescs := make([]*signrpc.SignDescriptor, len(signDescriptors)) - for i, signDesc := range signDescriptors { - var keyBytes []byte - var keyLocator *signrpc.KeyLocator - if signDesc.KeyDesc.PubKey != nil { - keyBytes = signDesc.KeyDesc.PubKey.SerializeCompressed() + // partialDescriptor is a helper method that creates a partially + // populated sign descriptor that is backward compatible with the way + // some applications like Loop expect the call to lnd to be made. This + // function only populates _either_ the public key or the key locator in + // the descriptor, but not both. + partialDescriptor := func( + d keychain.KeyDescriptor) *signrpc.KeyDescriptor { + + keyDesc := &signrpc.KeyDescriptor{} + if d.PubKey != nil { + keyDesc.RawKeyBytes = d.PubKey.SerializeCompressed() } else { - keyLocator = &signrpc.KeyLocator{ - KeyFamily: int32( - signDesc.KeyDesc.KeyLocator.Family, - ), - KeyIndex: int32( - signDesc.KeyDesc.KeyLocator.Index, - ), + keyDesc.KeyLoc = &signrpc.KeyLocator{ + KeyFamily: int32(d.KeyLocator.Family), + KeyIndex: int32(d.KeyLocator.Index), } } + return keyDesc + } + + // fullDescriptor is a helper method that creates a fully populated sign + // descriptor that includes both the public key and the key locator (if + // available). For the locator we explicitly check that both the family + // _and_ the index is non-zero. In some applications it's possible that + // the family is always set (because only a specific family is used), + // but the index might be zero because it's the first key, or because it + // isn't known at that particular moment. + // We aim to be compatible with this method in lnd's wallet: + // https://github.com/lightningnetwork/lnd/blob/master/lnwallet/btcwallet/signer.go#L286 + // Because we know all custom families (0 to 255) are derived at wallet + // creation, and the very first index of each family/account is always + // derived, we know that only using the public key for that very first + // index will work. But for a freshly initialized wallet (e.g. restored + // from seed), we won't know any indexes greater than 0, so we _need_ to + // also specify the key locator and not just the public key. + fullDescriptor := func( + d keychain.KeyDescriptor) *signrpc.KeyDescriptor { + + keyDesc := &signrpc.KeyDescriptor{} + if d.PubKey != nil { + keyDesc.RawKeyBytes = d.PubKey.SerializeCompressed() + } + + if d.KeyLocator.Family != 0 && d.KeyLocator.Index != 0 { + keyDesc.KeyLoc = &signrpc.KeyLocator{ + KeyFamily: int32(d.KeyLocator.Family), + KeyIndex: int32(d.KeyLocator.Index), + } + } + + return keyDesc + } + + rpcSignDescs := make([]*signrpc.SignDescriptor, len(signDescriptors)) + for i, signDesc := range signDescriptors { + keyDesc := partialDescriptor(signDesc.KeyDesc) + if fullDescriptors { + keyDesc = fullDescriptor(signDesc.KeyDesc) + } + var doubleTweak []byte if signDesc.DoubleTweak != nil { doubleTweak = signDesc.DoubleTweak.Serialize() @@ -247,12 +301,9 @@ func marshallSignDescriptors( PkScript: signDesc.Output.PkScript, Value: signDesc.Output.Value, }, - Sighash: uint32(signDesc.HashType), - InputIndex: int32(signDesc.InputIndex), - KeyDesc: &signrpc.KeyDescriptor{ - RawKeyBytes: keyBytes, - KeyLoc: keyLocator, - }, + Sighash: uint32(signDesc.HashType), + InputIndex: int32(signDesc.InputIndex), + KeyDesc: keyDesc, SingleTweak: signDesc.SingleTweak, DoubleTweak: doubleTweak, TapTweak: signDesc.TapTweak, @@ -283,11 +334,32 @@ func (s *signerClient) SignOutputRaw(ctx context.Context, tx *wire.MsgTx, signDescriptors []*SignDescriptor, prevOutputs []*wire.TxOut) ([][]byte, error) { + return s.signOutputRaw(ctx, tx, signDescriptors, prevOutputs, false) +} + +// SignOutputRawKeyLocator is a copy of the SignOutputRaw that fixes a specific +// issue around how the key locator is populated in the sign descriptor. We copy +// this method instead of fixing the original to make sure we don't break any +// existing applications that have already adjusted themselves to use the +// specific behavior of the original SignOutputRaw method. +func (s *signerClient) SignOutputRawKeyLocator(ctx context.Context, + tx *wire.MsgTx, signDescriptors []*SignDescriptor, + prevOutputs []*wire.TxOut) ([][]byte, error) { + + return s.signOutputRaw(ctx, tx, signDescriptors, prevOutputs, true) +} + +// signOutputRaw is a helper method that performs the actual signing of the +// transaction. +func (s *signerClient) signOutputRaw(ctx context.Context, tx *wire.MsgTx, + signDescriptors []*SignDescriptor, prevOutputs []*wire.TxOut, + fullDescriptor bool) ([][]byte, error) { + txRaw, err := encodeTx(tx) if err != nil { return nil, err } - rpcSignDescs := marshallSignDescriptors(signDescriptors) + rpcSignDescs := marshallSignDescriptors(signDescriptors, fullDescriptor) rpcPrevOutputs := marshallTxOut(prevOutputs) rpcCtx, cancel := context.WithTimeout(ctx, s.timeout) @@ -321,7 +393,7 @@ func (s *signerClient) ComputeInputScript(ctx context.Context, tx *wire.MsgTx, if err != nil { return nil, err } - rpcSignDescs := marshallSignDescriptors(signDescriptors) + rpcSignDescs := marshallSignDescriptors(signDescriptors, false) rpcPrevOutputs := marshallTxOut(prevOutputs) rpcCtx, cancel := context.WithTimeout(ctx, s.timeout)