Skip to content

Commit

Permalink
Use a callback to receive secret response
Browse files Browse the repository at this point in the history
To properly receive and store a requested secret, we usually need to
validate it against something like a public key to ensure we got the
correct one.

This changes the API so that we instead use a callback to receive any
incoming secret matching our request but we'll fail when we hit the
specified timeout if we never receive anything that is accepted.
  • Loading branch information
hifi committed Mar 15, 2024
1 parent a7bf485 commit fad4448
Showing 1 changed file with 39 additions and 23 deletions.
62 changes: 39 additions & 23 deletions crypto/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@ import (
"maunium.net/go/mautrix/id"
)

func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) {
secret, err = mach.CryptoStore.GetSecret(ctx, name)
if err != nil || secret != "" {
return
// Callback function to process a received secret.
//
// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses.
type SecretReceiverFunc func(string) (bool, error)

func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// always offer our stored secret first, if any
secret, err := mach.CryptoStore.GetSecret(ctx, name)
if err != nil {
return err
} else if secret != "" {
if ok, err := receiver(secret); ok || err != nil {
return err
}
}

requestID, secretChan := random.String(64), make(chan string, 1)
requestID, secretChan := random.String(64), make(chan string, 5)
mach.secretLock.Lock()
mach.secretListeners[requestID] = secretChan
mach.secretLock.Unlock()
Expand All @@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret,
return
}

select {
case <-ctx.Done():
err = ctx.Err()
case <-time.After(timeout):
case secret = <-secretChan:
}
// best effort cancel request from all devices when returning
defer func() {
go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: requestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()

if secret != "" {
err = mach.CryptoStore.PutSecret(ctx, name, secret)
for {
select {
case <-ctx.Done():
return ctx.Err()
case secret = <-secretChan:
if ok, err := receiver(secret); err != nil {
return err
} else if ok {
return mach.CryptoStore.PutSecret(ctx, name, secret)
}
}
}
return
}

func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) {
Expand Down Expand Up @@ -159,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven
return
}

// secret channel is buffered and we don't want to block
// at worst we drop _some_ of the responses
select {
case secretChan <- content.Secret:
default:
}

// best effort cancel this for all other targets
go func() {
mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: content.RequestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()
}

0 comments on commit fad4448

Please sign in to comment.