From c8d3a9810966dffdeb2c9aaddb198d3e3bcc01b7 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Thu, 27 Apr 2023 13:36:08 -0500 Subject: [PATCH] Refactor how we lock and unlock contacts --- core/models/contacts.go | 60 ++++++++++-- core/models/contacts_test.go | 30 ++++++ core/runner/runner.go | 102 ++++++++------------- core/tasks/handler/handle_contact_event.go | 13 ++- 4 files changed, 130 insertions(+), 75 deletions(-) diff --git a/core/models/contacts.go b/core/models/contacts.go index db8281580..74a161fb8 100644 --- a/core/models/contacts.go +++ b/core/models/contacts.go @@ -18,6 +18,7 @@ import ( "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/excellent/types" "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/null/v2" "github.com/nyaruka/redisx" "github.com/pkg/errors" @@ -1315,12 +1316,6 @@ func (i ContactID) Value() (driver.Value, error) { return null.IntValue(i) } func (i *ContactID) UnmarshalJSON(b []byte) error { return null.UnmarshalInt(b, i) } func (i ContactID) MarshalJSON() ([]byte, error) { return null.MarshalInt(i) } -// GetContactLocker returns the locker for a particular contact -func GetContactLocker(orgID OrgID, contactID ContactID) *redisx.Locker { - key := fmt.Sprintf("lock:c:%d:%d", orgID, contactID) - return redisx.NewLocker(key, time.Minute*5) -} - // ContactStatusChange struct used for our contact status change type ContactStatusChange struct { ContactID ContactID @@ -1384,3 +1379,56 @@ FROM ( WHERE c.id = r.id::int ` + +// LockContacts tries to grab locks for the given contacts, returning the locks and the skipped contacts +func LockContacts(rt *runtime.Runtime, orgID OrgID, ids []ContactID, retry time.Duration) (map[ContactID]string, []ContactID, error) { + locks := make(map[ContactID]string, len(ids)) + skipped := make([]ContactID, 0, 5) + + success := false + + for _, contactID := range ids { + locker := getContactLocker(orgID, contactID) + + lock, err := locker.Grab(rt.RP, retry) + if err != nil { + return nil, nil, errors.Wrapf(err, "error attempting to grab lock") + } + + // no error but we didn't get the lock + if lock == "" { + skipped = append(skipped, contactID) + continue + } + + locks[contactID] = lock + + // if we error we want to release all locks on way out + defer func() { + if !success { + locker.Release(rt.RP, lock) + } + }() + } + + success = true + return locks, skipped, nil +} + +// UnlockContacts unlocks the given contacts using the given lock values +func UnlockContacts(rt *runtime.Runtime, orgID OrgID, locks map[ContactID]string) error { + for contactID, lock := range locks { + locker := getContactLocker(orgID, contactID) + + err := locker.Release(rt.RP, lock) + if err != nil { + return err + } + } + return nil +} + +// returns the locker for a particular contact +func getContactLocker(orgID OrgID, contactID ContactID) *redisx.Locker { + return redisx.NewLocker(fmt.Sprintf("lock:c:%d:%d", orgID, contactID), time.Minute*5) +} diff --git a/core/models/contacts_test.go b/core/models/contacts_test.go index ed30efbe2..33fe4775e 100644 --- a/core/models/contacts_test.go +++ b/core/models/contacts_test.go @@ -15,9 +15,11 @@ import ( "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" "github.com/nyaruka/mailroom/utils/test" + "github.com/nyaruka/redisx/assertredis" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" ) func TestContacts(t *testing.T) { @@ -607,3 +609,31 @@ func TestUpdateContactURNs(t *testing.T) { assertdb.Query(t, rt.DB, `SELECT count(*) FROM contacts_contacturn`).Returns(numInitialURNs + 3) } + +func TestLockContacts(t *testing.T) { + _, rt := testsuite.Runtime() + + defer testsuite.Reset(testsuite.ResetRedis) + + // grab lock for contact #102 + models.LockContacts(rt, testdata.Org1.ID, []models.ContactID{102}, time.Second) + + assertredis.Exists(t, rt.RP, "lock:c:1:102") + + // try to get locks for #101, #102, #103 + locks, skipped, err := models.LockContacts(rt, testdata.Org1.ID, []models.ContactID{101, 102, 103}, time.Second) + assert.NoError(t, err) + assert.ElementsMatch(t, []models.ContactID{101, 103}, maps.Keys(locks)) + assert.Equal(t, []models.ContactID{102}, skipped) // because it's already locked + + assertredis.Exists(t, rt.RP, "lock:c:1:101") + assertredis.Exists(t, rt.RP, "lock:c:1:102") + assertredis.Exists(t, rt.RP, "lock:c:1:103") + + err = models.UnlockContacts(rt, testdata.Org1.ID, locks) + assert.NoError(t, err) + + assertredis.NotExists(t, rt.RP, "lock:c:1:101") + assertredis.Exists(t, rt.RP, "lock:c:1:102") + assertredis.NotExists(t, rt.RP, "lock:c:1:103") +} diff --git a/core/runner/runner.go b/core/runner/runner.go index 2f18c4893..7ce949d68 100644 --- a/core/runner/runner.go +++ b/core/runner/runner.go @@ -14,9 +14,9 @@ import ( "github.com/nyaruka/mailroom/core/goflow" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/redisx" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -248,11 +248,8 @@ func StartFlowBatch( return sessions, nil } -// StartFlow runs the passed in flow for the passed in contact -func StartFlow( - ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, - flow *models.Flow, contactIDs []models.ContactID, options *StartOptions) ([]*models.Session, error) { - +// StartFlow runs the passed in flow for the passed in contacts +func StartFlow(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, flow *models.Flow, contactIDs []models.ContactID, options *StartOptions) ([]*models.Session, error) { if len(contactIDs) == 0 { return nil, nil } @@ -304,74 +301,55 @@ func StartFlow( remaining := includedContacts start := time.Now() - // map of locks we've released - released := make(map[*redisx.Locker]bool) - for len(remaining) > 0 && time.Since(start) < time.Minute*5 { - locked := make([]models.ContactID, 0, len(remaining)) - locks := make([]string, 0, len(remaining)) - skipped := make([]models.ContactID, 0, 5) + ss, skipped, err := tryToStartWithLock(ctx, rt, oa, flow, remaining, options) + if err != nil { + return nil, err + } - // try up to a second to get a lock for a contact - for _, contactID := range remaining { - locker := models.GetContactLocker(oa.OrgID(), contactID) + sessions = append(sessions, ss...) + remaining = skipped // skipped are now our remaining + } - lock, err := locker.Grab(rt.RP, time.Second) - if err != nil { - return nil, errors.Wrapf(err, "error attempting to grab lock") - } - if lock == "" { - skipped = append(skipped, contactID) - continue - } - locked = append(locked, contactID) - locks = append(locks, lock) + return sessions, nil +} - // defer unlocking if we exit due to error - defer func() { - if !released[locker] { - locker.Release(rt.RP, lock) - } - }() - } +// tries to start the given contacts, returning sessions for those we could, and the ids that were skipped because we +// couldn't get their locks +func tryToStartWithLock(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, flow *models.Flow, ids []models.ContactID, options *StartOptions) ([]*models.Session, []models.ContactID, error) { + // try to get locks for these contacts, waiting for up to a second for each contact + locks, skipped, err := models.LockContacts(rt, oa.OrgID(), ids, time.Second) + if err != nil { + return nil, nil, err + } + locked := maps.Keys(locks) - // load our locked contacts - contacts, err := models.LoadContacts(ctx, rt.ReadonlyDB, oa, locked) - if err != nil { - return nil, errors.Wrapf(err, "error loading contacts to start") - } + // whatever happens, we need to unlock the contacts + defer models.UnlockContacts(rt, oa.OrgID(), locks) - // ok, we've filtered our contacts, build our triggers - triggers := make([]flows.Trigger, 0, len(locked)) - for _, c := range contacts { - contact, err := c.FlowContact(oa) - if err != nil { - return nil, errors.Wrapf(err, "error creating flow contact") - } - trigger := options.TriggerBuilder(contact) - triggers = append(triggers, trigger) - } + // load our locked contacts + contacts, err := models.LoadContacts(ctx, rt.ReadonlyDB, oa, locked) + if err != nil { + return nil, nil, errors.Wrapf(err, "error loading contacts to start") + } - ss, err := StartFlowForContacts(ctx, rt, oa, flow, contacts, triggers, options.CommitHook, options.Interrupt) + // build our triggers + triggers := make([]flows.Trigger, 0, len(locked)) + for _, c := range contacts { + contact, err := c.FlowContact(oa) if err != nil { - return nil, errors.Wrapf(err, "error starting flow for contacts") - } - - // append all the sessions that were started - sessions = append(sessions, ss...) - - // release all our locks - for i := range locked { - locker := models.GetContactLocker(oa.OrgID(), locked[i]) - locker.Release(rt.RP, locks[i]) - released[locker] = true + return nil, nil, errors.Wrapf(err, "error creating flow contact") } + trigger := options.TriggerBuilder(contact) + triggers = append(triggers, trigger) + } - // skipped are now our remaining - remaining = skipped + ss, err := StartFlowForContacts(ctx, rt, oa, flow, contacts, triggers, options.CommitHook, options.Interrupt) + if err != nil { + return nil, nil, errors.Wrapf(err, "error starting flow for contacts") } - return sessions, nil + return ss, skipped, nil } // StartFlowForContacts runs the passed in flow for the passed in contact diff --git a/core/tasks/handler/handle_contact_event.go b/core/tasks/handler/handle_contact_event.go index ecd36d721..18884595f 100644 --- a/core/tasks/handler/handle_contact_event.go +++ b/core/tasks/handler/handle_contact_event.go @@ -41,16 +41,14 @@ func (t *HandleContactEventTask) Timeout() time.Duration { // Perform is called when an event comes in for a contact. To make sure we don't get into a situation of being off by one, // this task ingests and handles all the events for a contact, one by one. func (t *HandleContactEventTask) Perform(ctx context.Context, rt *runtime.Runtime, orgID models.OrgID) error { - // acquire the lock for this contact - locker := models.GetContactLocker(orgID, t.ContactID) - - lock, err := locker.Grab(rt.RP, time.Second*10) + // try to get the lock for this contact, waiting up to 10 seconds + locks, _, err := models.LockContacts(rt, orgID, []models.ContactID{t.ContactID}, time.Second*10) if err != nil { return errors.Wrapf(err, "error acquiring lock for contact %d", t.ContactID) } - // we didn't get the lock within our timeout, skip and requeue for later - if lock == "" { + // we didn't get the lock.. requeue for later + if len(locks) == 0 { rc := rt.RP.Get() defer rc.Close() err = tasks.Queue(rc, queue.HandlerQueue, orgID, &HandleContactEventTask{ContactID: t.ContactID}, queue.DefaultPriority) @@ -60,7 +58,8 @@ func (t *HandleContactEventTask) Perform(ctx context.Context, rt *runtime.Runtim logrus.WithFields(logrus.Fields{"org_id": orgID, "contact_id": t.ContactID}).Info("failed to get lock for contact, requeued and skipping") return nil } - defer locker.Release(rt.RP, lock) + + defer models.UnlockContacts(rt, orgID, locks) // read all the events for this contact, one by one contactQ := fmt.Sprintf("c:%d:%d", orgID, t.ContactID)