diff --git a/examples/remote-signing-server/server.go b/examples/remote-signing-server/server.go index dc77d75..9bd08df 100644 --- a/examples/remote-signing-server/server.go +++ b/examples/remote-signing-server/server.go @@ -14,7 +14,9 @@ import ( func getValidator(config *Config) remotesigning.Validator { if config.ValidationEnabled { - return &remotesigning.HashValidator{} + return remotesigning.NewMultiValidator( + remotesigning.HashValidator{}, + remotesigning.DestinationValidator{}) } else { return &remotesigning.PositiveValidator{} } @@ -75,7 +77,7 @@ func main() { case objects.WebhookEventTypeRemoteSigning: if config.RespondDirectly { resp, err := remotesigning.GraphQLResponseForRemoteSigningWebhook( - remotesigning.HashValidator{}, *event, config.MasterSeed) + validator, *event, config.MasterSeed) if err != nil { log.Printf("ERROR: Unable to handle remote signing webhook: %s", err) c.AbortWithStatus(http.StatusInternalServerError) diff --git a/remotesigning/validator.go b/remotesigning/validator.go index fe53caf..2fdc543 100644 --- a/remotesigning/validator.go +++ b/remotesigning/validator.go @@ -18,6 +18,23 @@ func (v PositiveValidator) ShouldSign(webhook webhooks.WebhookEvent, xpubs []str return true } +type MultiValidator struct{ + validators []Validator +} + +func NewMultiValidator(validators ...Validator) MultiValidator { + return MultiValidator{validators: validators} +} + +func (v MultiValidator) ShouldSign(webhookEvent webhooks.WebhookEvent, xpubs []string) bool { + for _, validator := range v.validators { + if !validator.ShouldSign(webhookEvent, xpubs) { + return false + } + } + return true +} + type HashValidator struct{} func (v HashValidator) ShouldSign(webhookEvent webhooks.WebhookEvent, xpubs []string) bool { @@ -26,21 +43,15 @@ func (v HashValidator) ShouldSign(webhookEvent webhooks.WebhookEvent, xpubs []st // Only validate DeriveAndSignRequest events return true } - for i, signing := range request.SigningJobs { - if strings.HasPrefix(signing.DerivationPath, "m/84") { - if !ValidateL1Transaction(&signing, xpubs[i]) { - return false - } - } else { - if !ValidateLightningTransaction(&signing) { - return false - } + for _, signing := range request.SigningJobs { + if !ValidateWitnessHash(&signing) { + return false } } return true } -func ValidateLightningTransaction(signing *SigningJob) bool { +func ValidateWitnessHash(signing *SigningJob) bool { if strings.HasSuffix(signing.DerivationPath, "/2") || strings.HasSuffix(signing.DerivationPath, "/3") { msg, err := CalculateWitnessHashPSBT(*signing.Transaction) if err != nil { @@ -61,23 +72,25 @@ func ValidateLightningTransaction(signing *SigningJob) bool { return true } -func ValidateL1Transaction(signing *SigningJob, xpub string) bool { - // 1. Address Validation - isValid, err := ValidateScript(signing, xpub) - if err != nil { - return false - } - if !isValid { - return false - } +// A validator that checks that the outputs of a transaction pass restrictions +// for where we allow sending funds. This is used to ensure that when signing +// transactions spending L1 wallet funds, we are only sending funds to certain +// addresses. +type DestinationValidator struct{} - // 2. Witness Hash Validation - msg, err := CalculateWitnessHash(*signing.Amount, *signing.Script, *signing.Transaction) +func (v DestinationValidator) ShouldSign(webhookEvent webhooks.WebhookEvent, xpubs []string) bool { + request, err := ParseDeriveAndSignRequest(webhookEvent) if err != nil { - return false + // Only validate DeriveAndSignRequest events + return true } - if strings.Compare(*msg, signing.Message) != 0 { - return false + for i, signing := range request.SigningJobs { + if strings.HasPrefix(signing.DerivationPath, "m/84") { + validScript, err := ValidateScript(&signing, xpubs[i]) + if err != nil || !validScript { + return false + } + } } return true }