Skip to content

Commit

Permalink
Merge pull request #62 from nyaruka/easier_locks
Browse files Browse the repository at this point in the history
🔐 Refactor how we lock and unlock contacts
  • Loading branch information
rowanseymour authored Apr 27, 2023
2 parents f280131 + c8d3a98 commit 4dfb9a2
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 75 deletions.
60 changes: 54 additions & 6 deletions core/models/contacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/nyaruka/goflow/envs"
"github.com/nyaruka/goflow/excellent/types"
"github.com/nyaruka/goflow/flows"
"github.com/nyaruka/mailroom/runtime"
"github.com/nyaruka/null/v2"
"github.com/nyaruka/redisx"
"github.com/pkg/errors"
Expand Down Expand Up @@ -1315,12 +1316,6 @@ func (i ContactID) Value() (driver.Value, error) { return null.IntValue(i) }
func (i *ContactID) UnmarshalJSON(b []byte) error { return null.UnmarshalInt(b, i) }
func (i ContactID) MarshalJSON() ([]byte, error) { return null.MarshalInt(i) }

// GetContactLocker returns the locker for a particular contact
func GetContactLocker(orgID OrgID, contactID ContactID) *redisx.Locker {
key := fmt.Sprintf("lock:c:%d:%d", orgID, contactID)
return redisx.NewLocker(key, time.Minute*5)
}

// ContactStatusChange struct used for our contact status change
type ContactStatusChange struct {
ContactID ContactID
Expand Down Expand Up @@ -1384,3 +1379,56 @@ FROM (
WHERE
c.id = r.id::int
`

// LockContacts tries to grab locks for the given contacts, returning the locks and the skipped contacts
func LockContacts(rt *runtime.Runtime, orgID OrgID, ids []ContactID, retry time.Duration) (map[ContactID]string, []ContactID, error) {
locks := make(map[ContactID]string, len(ids))
skipped := make([]ContactID, 0, 5)

success := false

for _, contactID := range ids {
locker := getContactLocker(orgID, contactID)

lock, err := locker.Grab(rt.RP, retry)
if err != nil {
return nil, nil, errors.Wrapf(err, "error attempting to grab lock")
}

// no error but we didn't get the lock
if lock == "" {
skipped = append(skipped, contactID)
continue
}

locks[contactID] = lock

// if we error we want to release all locks on way out
defer func() {
if !success {
locker.Release(rt.RP, lock)
}
}()
}

success = true
return locks, skipped, nil
}

// UnlockContacts unlocks the given contacts using the given lock values
func UnlockContacts(rt *runtime.Runtime, orgID OrgID, locks map[ContactID]string) error {
for contactID, lock := range locks {
locker := getContactLocker(orgID, contactID)

err := locker.Release(rt.RP, lock)
if err != nil {
return err
}
}
return nil
}

// returns the locker for a particular contact
func getContactLocker(orgID OrgID, contactID ContactID) *redisx.Locker {
return redisx.NewLocker(fmt.Sprintf("lock:c:%d:%d", orgID, contactID), time.Minute*5)
}
30 changes: 30 additions & 0 deletions core/models/contacts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (
"github.com/nyaruka/mailroom/testsuite"
"github.com/nyaruka/mailroom/testsuite/testdata"
"github.com/nyaruka/mailroom/utils/test"
"github.com/nyaruka/redisx/assertredis"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)

func TestContacts(t *testing.T) {
Expand Down Expand Up @@ -607,3 +609,31 @@ func TestUpdateContactURNs(t *testing.T) {

assertdb.Query(t, rt.DB, `SELECT count(*) FROM contacts_contacturn`).Returns(numInitialURNs + 3)
}

func TestLockContacts(t *testing.T) {
_, rt := testsuite.Runtime()

defer testsuite.Reset(testsuite.ResetRedis)

// grab lock for contact #102
models.LockContacts(rt, testdata.Org1.ID, []models.ContactID{102}, time.Second)

assertredis.Exists(t, rt.RP, "lock:c:1:102")

// try to get locks for #101, #102, #103
locks, skipped, err := models.LockContacts(rt, testdata.Org1.ID, []models.ContactID{101, 102, 103}, time.Second)
assert.NoError(t, err)
assert.ElementsMatch(t, []models.ContactID{101, 103}, maps.Keys(locks))
assert.Equal(t, []models.ContactID{102}, skipped) // because it's already locked

assertredis.Exists(t, rt.RP, "lock:c:1:101")
assertredis.Exists(t, rt.RP, "lock:c:1:102")
assertredis.Exists(t, rt.RP, "lock:c:1:103")

err = models.UnlockContacts(rt, testdata.Org1.ID, locks)
assert.NoError(t, err)

assertredis.NotExists(t, rt.RP, "lock:c:1:101")
assertredis.Exists(t, rt.RP, "lock:c:1:102")
assertredis.NotExists(t, rt.RP, "lock:c:1:103")
}
102 changes: 40 additions & 62 deletions core/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
"github.com/nyaruka/mailroom/core/goflow"
"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/runtime"
"github.com/nyaruka/redisx"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)

const (
Expand Down Expand Up @@ -248,11 +248,8 @@ func StartFlowBatch(
return sessions, nil
}

// StartFlow runs the passed in flow for the passed in contact
func StartFlow(
ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets,
flow *models.Flow, contactIDs []models.ContactID, options *StartOptions) ([]*models.Session, error) {

// StartFlow runs the passed in flow for the passed in contacts
func StartFlow(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, flow *models.Flow, contactIDs []models.ContactID, options *StartOptions) ([]*models.Session, error) {
if len(contactIDs) == 0 {
return nil, nil
}
Expand Down Expand Up @@ -304,74 +301,55 @@ func StartFlow(
remaining := includedContacts
start := time.Now()

// map of locks we've released
released := make(map[*redisx.Locker]bool)

for len(remaining) > 0 && time.Since(start) < time.Minute*5 {
locked := make([]models.ContactID, 0, len(remaining))
locks := make([]string, 0, len(remaining))
skipped := make([]models.ContactID, 0, 5)
ss, skipped, err := tryToStartWithLock(ctx, rt, oa, flow, remaining, options)
if err != nil {
return nil, err
}

// try up to a second to get a lock for a contact
for _, contactID := range remaining {
locker := models.GetContactLocker(oa.OrgID(), contactID)
sessions = append(sessions, ss...)
remaining = skipped // skipped are now our remaining
}

lock, err := locker.Grab(rt.RP, time.Second)
if err != nil {
return nil, errors.Wrapf(err, "error attempting to grab lock")
}
if lock == "" {
skipped = append(skipped, contactID)
continue
}
locked = append(locked, contactID)
locks = append(locks, lock)
return sessions, nil
}

// defer unlocking if we exit due to error
defer func() {
if !released[locker] {
locker.Release(rt.RP, lock)
}
}()
}
// tries to start the given contacts, returning sessions for those we could, and the ids that were skipped because we
// couldn't get their locks
func tryToStartWithLock(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, flow *models.Flow, ids []models.ContactID, options *StartOptions) ([]*models.Session, []models.ContactID, error) {
// try to get locks for these contacts, waiting for up to a second for each contact
locks, skipped, err := models.LockContacts(rt, oa.OrgID(), ids, time.Second)
if err != nil {
return nil, nil, err
}
locked := maps.Keys(locks)

// load our locked contacts
contacts, err := models.LoadContacts(ctx, rt.ReadonlyDB, oa, locked)
if err != nil {
return nil, errors.Wrapf(err, "error loading contacts to start")
}
// whatever happens, we need to unlock the contacts
defer models.UnlockContacts(rt, oa.OrgID(), locks)

// ok, we've filtered our contacts, build our triggers
triggers := make([]flows.Trigger, 0, len(locked))
for _, c := range contacts {
contact, err := c.FlowContact(oa)
if err != nil {
return nil, errors.Wrapf(err, "error creating flow contact")
}
trigger := options.TriggerBuilder(contact)
triggers = append(triggers, trigger)
}
// load our locked contacts
contacts, err := models.LoadContacts(ctx, rt.ReadonlyDB, oa, locked)
if err != nil {
return nil, nil, errors.Wrapf(err, "error loading contacts to start")
}

ss, err := StartFlowForContacts(ctx, rt, oa, flow, contacts, triggers, options.CommitHook, options.Interrupt)
// build our triggers
triggers := make([]flows.Trigger, 0, len(locked))
for _, c := range contacts {
contact, err := c.FlowContact(oa)
if err != nil {
return nil, errors.Wrapf(err, "error starting flow for contacts")
}

// append all the sessions that were started
sessions = append(sessions, ss...)

// release all our locks
for i := range locked {
locker := models.GetContactLocker(oa.OrgID(), locked[i])
locker.Release(rt.RP, locks[i])
released[locker] = true
return nil, nil, errors.Wrapf(err, "error creating flow contact")
}
trigger := options.TriggerBuilder(contact)
triggers = append(triggers, trigger)
}

// skipped are now our remaining
remaining = skipped
ss, err := StartFlowForContacts(ctx, rt, oa, flow, contacts, triggers, options.CommitHook, options.Interrupt)
if err != nil {
return nil, nil, errors.Wrapf(err, "error starting flow for contacts")
}

return sessions, nil
return ss, skipped, nil
}

// StartFlowForContacts runs the passed in flow for the passed in contact
Expand Down
13 changes: 6 additions & 7 deletions core/tasks/handler/handle_contact_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,14 @@ func (t *HandleContactEventTask) Timeout() time.Duration {
// Perform is called when an event comes in for a contact. To make sure we don't get into a situation of being off by one,
// this task ingests and handles all the events for a contact, one by one.
func (t *HandleContactEventTask) Perform(ctx context.Context, rt *runtime.Runtime, orgID models.OrgID) error {
// acquire the lock for this contact
locker := models.GetContactLocker(orgID, t.ContactID)

lock, err := locker.Grab(rt.RP, time.Second*10)
// try to get the lock for this contact, waiting up to 10 seconds
locks, _, err := models.LockContacts(rt, orgID, []models.ContactID{t.ContactID}, time.Second*10)
if err != nil {
return errors.Wrapf(err, "error acquiring lock for contact %d", t.ContactID)
}

// we didn't get the lock within our timeout, skip and requeue for later
if lock == "" {
// we didn't get the lock.. requeue for later
if len(locks) == 0 {
rc := rt.RP.Get()
defer rc.Close()
err = tasks.Queue(rc, queue.HandlerQueue, orgID, &HandleContactEventTask{ContactID: t.ContactID}, queue.DefaultPriority)
Expand All @@ -60,7 +58,8 @@ func (t *HandleContactEventTask) Perform(ctx context.Context, rt *runtime.Runtim
logrus.WithFields(logrus.Fields{"org_id": orgID, "contact_id": t.ContactID}).Info("failed to get lock for contact, requeued and skipping")
return nil
}
defer locker.Release(rt.RP, lock)

defer models.UnlockContacts(rt, orgID, locks)

// read all the events for this contact, one by one
contactQ := fmt.Sprintf("c:%d:%d", orgID, t.ContactID)
Expand Down

0 comments on commit 4dfb9a2

Please sign in to comment.