Skip to content

Commit

Permalink
Merge the crons for expiring message and voice sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Jan 23, 2025
1 parent b676eb2 commit a24388e
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 155 deletions.
1 change: 1 addition & 0 deletions core/models/contact_fire.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
type ContactFireExtra struct {
SessionID SessionID `json:"session_id,omitempty"`
SessionModifiedOn time.Time `json:"session_modified_on,omitempty"`
CallID CallID `json:"call_id,omitempty"`
}

type ContactFire struct {
Expand Down
132 changes: 36 additions & 96 deletions core/tasks/expirations/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@ package expirations
import (
"context"
"fmt"
"log/slog"
"slices"
"time"

"github.com/nyaruka/mailroom/core/ivr"
"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/core/tasks"
"github.com/nyaruka/mailroom/core/tasks/contacts"
"github.com/nyaruka/mailroom/core/tasks/ivr"
"github.com/nyaruka/mailroom/runtime"
"github.com/nyaruka/redisx"
)

func init() {
tasks.RegisterCron("run_expirations", NewExpirationsCron(100))
tasks.RegisterCron("expire_ivr_calls", &VoiceExpirationsCron{})
}

type ExpirationsCron struct {
Expand Down Expand Up @@ -54,11 +52,12 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str

// scan and organize by org
byOrg := make(map[models.OrgID][]*ExpiredWait, 50)
callsByOrg := make(map[models.OrgID][]*ExpiredWait, 50)

rc := rt.RP.Get()
defer rc.Close()

numDupes, numQueued := 0, 0
numDupes, numExpires, numHangups := 0, 0, 0

for rows.Next() {
expiredWait := &ExpiredWait{}
Expand All @@ -78,7 +77,11 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
continue
}

byOrg[expiredWait.OrgID] = append(byOrg[expiredWait.OrgID], expiredWait)
if expiredWait.CallID != models.NilCallID {
callsByOrg[expiredWait.OrgID] = append(callsByOrg[expiredWait.OrgID], expiredWait)
} else {
byOrg[expiredWait.OrgID] = append(byOrg[expiredWait.OrgID], expiredWait)
}
}

for orgID, expirations := range byOrg {
Expand All @@ -91,7 +94,7 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
if err := tasks.Queue(rc, tasks.ThrottledQueue, orgID, &contacts.BulkSessionExpireTask{Expirations: exps}, true); err != nil {
return nil, fmt.Errorf("error queuing bulk expiration task to throttle queue: %w", err)
}
numQueued += len(batch)
numExpires += len(batch)

for _, exp := range batch {
// mark as queued
Expand All @@ -102,105 +105,42 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
}
}

return map[string]any{"dupes": numDupes, "queued": numQueued}, nil
for orgID, expirations := range callsByOrg {
for batch := range slices.Chunk(expirations, c.bulkBatchSize) {
hups := make([]*ivr.Hangup, len(batch))
for i, exp := range batch {
hups[i] = &ivr.Hangup{SessionID: exp.SessionID, CallID: exp.CallID}
}

if err := tasks.Queue(rc, tasks.BatchQueue, orgID, &ivr.BulkCallHangupTask{Hangups: hups}, true); err != nil {
return nil, fmt.Errorf("error queuing bulk hangup task to batch queue: %w", err)
}
numHangups += len(batch)

for _, exp := range batch {
// mark as queued
if err = c.marker.Add(rc, taskID(exp)); err != nil {
return nil, fmt.Errorf("error marking hangup task as queued: %w", err)
}
}
}
}

return map[string]any{"dupes": numDupes, "queued_expires": numExpires, "queued_hangups": numHangups}, nil
}

const sqlSelectExpiredWaits = `
SELECT id as session_id, org_id, wait_expires_on, contact_id, modified_on
SELECT id, org_id, contact_id, call_id, wait_expires_on, modified_on
FROM flows_flowsession
WHERE session_type = 'M' AND status = 'W' AND wait_expires_on <= NOW()
WHERE status = 'W' AND wait_expires_on <= NOW()
ORDER BY wait_expires_on ASC
LIMIT 25000`

type ExpiredWait struct {
SessionID models.SessionID `db:"session_id"`
SessionID models.SessionID `db:"id"`
OrgID models.OrgID `db:"org_id"`
WaitExpiresOn time.Time `db:"wait_expires_on"`
ContactID models.ContactID `db:"contact_id"`
CallID models.CallID `db:"call_id"`
WaitExpiresOn time.Time `db:"wait_expires_on"`
ModifiedOn time.Time `db:"modified_on"`
}

type VoiceExpirationsCron struct{}

func (c *VoiceExpirationsCron) Next(last time.Time) time.Time {
return tasks.CronNext(last, time.Minute)
}

func (c *VoiceExpirationsCron) AllInstances() bool {
return false
}

// looks for voice sessions that should be expired and ends them
func (c *VoiceExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[string]any, error) {
log := slog.With("comp", "ivr_cron_expirer")

ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel()

// select voice sessions with expired waits
rows, err := rt.DB.QueryxContext(ctx, sqlSelectExpiredVoiceWaits)
if err != nil {
return nil, fmt.Errorf("error querying voice sessions with expired waits: %w", err)
}
defer rows.Close()

expiredSessions := make([]models.SessionID, 0, 100)
clogs := make([]*models.ChannelLog, 0, 100)

for rows.Next() {
expiredWait := &ExpiredVoiceWait{}
err := rows.StructScan(expiredWait)
if err != nil {
return nil, fmt.Errorf("error scanning expired wait: %w", err)
}

// add the session to those we need to expire
expiredSessions = append(expiredSessions, expiredWait.SessionID)

// load our call
conn, err := models.GetCallByID(ctx, rt.DB, expiredWait.OrgID, expiredWait.CallID)
if err != nil {
log.Error("unable to load call", "error", err, "call_id", expiredWait.CallID)
continue
}

// hang up our call
clog, err := ivr.HangupCall(ctx, rt, conn)
if err != nil {
// log error but carry on with other calls
log.Error("error hanging up call", "error", err, "call_id", conn.ID())
}

if clog != nil {
clogs = append(clogs, clog)
}
}

// now expire our runs and sessions
if len(expiredSessions) > 0 {
err := models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired)
if err != nil {
log.Error("error expiring sessions for expired calls", "error", err)
}
}

if err := models.InsertChannelLogs(ctx, rt, clogs); err != nil {
return nil, fmt.Errorf("error inserting channel logs: %w", err)
}

return map[string]any{"expired": len(expiredSessions)}, nil
}

const sqlSelectExpiredVoiceWaits = `
SELECT id, org_id, call_id, wait_expires_on
FROM flows_flowsession
WHERE session_type = 'V' AND status = 'W' AND wait_expires_on <= NOW()
ORDER BY wait_expires_on ASC
LIMIT 100`

type ExpiredVoiceWait struct {
SessionID models.SessionID `db:"id"`
OrgID models.OrgID `db:"org_id"`
CallID models.CallID `db:"call_id"`
ExpiresOn time.Time `db:"wait_expires_on"`
}
85 changes: 26 additions & 59 deletions core/tasks/expirations/cron_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"
"time"

"github.com/nyaruka/gocommon/dbutil/assertdb"
"github.com/nyaruka/gocommon/i18n"
"github.com/nyaruka/gocommon/jsonx"
"github.com/nyaruka/gocommon/uuids"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/nyaruka/mailroom/core/tasks"
"github.com/nyaruka/mailroom/core/tasks/contacts"
"github.com/nyaruka/mailroom/core/tasks/expirations"
"github.com/nyaruka/mailroom/core/tasks/ivr"
"github.com/nyaruka/mailroom/testsuite"
"github.com/nyaruka/mailroom/testsuite/testdata"
"github.com/stretchr/testify/assert"
Expand All @@ -38,7 +38,7 @@ func TestExpirations(t *testing.T) {

// create an IVR session for Alexandria
call := testdata.InsertCall(rt, testdata.Org1, testdata.TwilioChannel, testdata.Alexandria)
testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Alexandria, models.FlowTypeVoice, testdata.IVRFlow, call, time.Now(), time.Now(), false, nil)
ivrID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Alexandria, models.FlowTypeVoice, testdata.IVRFlow, call, time.Now(), time.Now(), false, nil)

// for other org create 6 waiting sessions that will expire
for i := range 6 {
Expand All @@ -50,78 +50,45 @@ func TestExpirations(t *testing.T) {
cron := expirations.NewExpirationsCron(5)
res, err := cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"dupes": 0, "queued": 7}, res)
assert.Equal(t, map[string]any{"dupes": 0, "queued_expires": 7, "queued_hangups": 1}, res)

// should have created one throttled task for org 1
// should have created one throttled expire task for org 1
task1, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org1.ID), task1.OwnerID)
assert.Equal(t, "bulk_session_expire", task1.Type)

decoded := &contacts.BulkSessionExpireTask{}
jsonx.MustUnmarshal(task1.Task, decoded)
assert.Len(t, decoded.Expirations, 1)
assert.Equal(t, s2ID, decoded.Expirations[0].SessionID)
decoded1 := &contacts.BulkSessionExpireTask{}
jsonx.MustUnmarshal(task1.Task, decoded1)
assert.Len(t, decoded1.Expirations, 1)
assert.Equal(t, s2ID, decoded1.Expirations[0].SessionID)

// and two for org 2
task2, err := tasks.ThrottledQueue.Pop(rc)
// and one batch hangup task for the IVR session
task2, err := tasks.BatchQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task2.OwnerID)
assert.Equal(t, "bulk_session_expire", task2.Type)
assert.Equal(t, int(testdata.Org1.ID), task2.OwnerID)
assert.Equal(t, "bulk_call_hangup", task2.Type)

decoded2 := &ivr.BulkCallHangupTask{}
jsonx.MustUnmarshal(task2.Task, decoded2)
assert.Len(t, decoded2.Hangups, 1)
assert.Equal(t, ivrID, decoded2.Hangups[0].SessionID)

// and two expire tasks for org 2
task3, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task3.OwnerID)
assert.Equal(t, "bulk_session_expire", task2.Type)
assert.Equal(t, "bulk_session_expire", task3.Type)
task4, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task4.OwnerID)
assert.Equal(t, "bulk_session_expire", task4.Type)

// no other
task, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Nil(t, task)
assert.Equal(t, map[string]int{}, testsuite.FlushTasks(t, rt, "batch", "throttled"))

// if task runs again, these tasks won't be re-queued
res, err = cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"dupes": 7, "queued": 0}, res)
}

func TestExpireVoiceSessions(t *testing.T) {
ctx, rt := testsuite.Runtime()
rc := rt.RP.Get()
defer rc.Close()

defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis)

// create voice session for Cathy
conn1ID := testdata.InsertCall(rt, testdata.Org1, testdata.TwilioChannel, testdata.Cathy)
s1ID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Cathy, models.FlowTypeVoice, testdata.IVRFlow, conn1ID, time.Now(), time.Now(), false, nil)
r1ID := testdata.InsertFlowRun(rt, testdata.Org1, s1ID, testdata.Cathy, testdata.Favorites, models.RunStatusWaiting, "")

// create voice session for Bob with expiration in future
conn2ID := testdata.InsertCall(rt, testdata.Org1, testdata.TwilioChannel, testdata.Bob)
s2ID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Bob, models.FlowTypeMessaging, testdata.IVRFlow, conn2ID, time.Now(), time.Now().Add(time.Hour), false, nil)
r2ID := testdata.InsertFlowRun(rt, testdata.Org1, s2ID, testdata.Bob, testdata.IVRFlow, models.RunStatusWaiting, "")

// create a messaging session for Alexandria
s3ID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Alexandria, models.FlowTypeMessaging, testdata.Favorites, models.NilCallID, time.Now(), time.Now(), false, nil)
r3ID := testdata.InsertFlowRun(rt, testdata.Org1, s3ID, testdata.Alexandria, testdata.Favorites, models.RunStatusWaiting, "")

time.Sleep(5 * time.Millisecond)

// expire our sessions...
cron := &expirations.VoiceExpirationsCron{}
res, err := cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"expired": 1}, res)

// Cathy's session should be expired along with its runs
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s1ID).Columns(map[string]any{"status": "X"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r1ID).Columns(map[string]any{"status": "X"})

// Bob's session and run should be unchanged
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s2ID).Columns(map[string]any{"status": "W"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r2ID).Columns(map[string]any{"status": "W"})

// Alexandria's session and run should be unchanged because message expirations are handled separately
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s3ID).Columns(map[string]any{"status": "W"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r3ID).Columns(map[string]any{"status": "W"})
assert.Equal(t, map[string]any{"dupes": 8, "queued_expires": 0, "queued_hangups": 0}, res)
}
Loading

0 comments on commit a24388e

Please sign in to comment.