Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge crons for expiring message and voice sessions #415

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/models/contact_fire.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
type ContactFireExtra struct {
SessionID SessionID `json:"session_id,omitempty"`
SessionModifiedOn time.Time `json:"session_modified_on,omitempty"`
CallID CallID `json:"call_id,omitempty"`
}

type ContactFire struct {
Expand Down
132 changes: 36 additions & 96 deletions core/tasks/expirations/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@ package expirations
import (
"context"
"fmt"
"log/slog"
"slices"
"time"

"github.com/nyaruka/mailroom/core/ivr"
"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/core/tasks"
"github.com/nyaruka/mailroom/core/tasks/contacts"
"github.com/nyaruka/mailroom/core/tasks/ivr"
"github.com/nyaruka/mailroom/runtime"
"github.com/nyaruka/redisx"
)

func init() {
tasks.RegisterCron("run_expirations", NewExpirationsCron(100))
tasks.RegisterCron("expire_ivr_calls", &VoiceExpirationsCron{})
}

type ExpirationsCron struct {
Expand Down Expand Up @@ -54,11 +52,12 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str

// scan and organize by org
byOrg := make(map[models.OrgID][]*ExpiredWait, 50)
callsByOrg := make(map[models.OrgID][]*ExpiredWait, 50)

rc := rt.RP.Get()
defer rc.Close()

numDupes, numQueued := 0, 0
numDupes, numExpires, numHangups := 0, 0, 0

for rows.Next() {
expiredWait := &ExpiredWait{}
Expand All @@ -78,7 +77,11 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
continue
}

byOrg[expiredWait.OrgID] = append(byOrg[expiredWait.OrgID], expiredWait)
if expiredWait.CallID != models.NilCallID {
callsByOrg[expiredWait.OrgID] = append(callsByOrg[expiredWait.OrgID], expiredWait)
} else {
byOrg[expiredWait.OrgID] = append(byOrg[expiredWait.OrgID], expiredWait)
}
}

for orgID, expirations := range byOrg {
Expand All @@ -91,7 +94,7 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
if err := tasks.Queue(rc, tasks.ThrottledQueue, orgID, &contacts.BulkSessionExpireTask{Expirations: exps}, true); err != nil {
return nil, fmt.Errorf("error queuing bulk expiration task to throttle queue: %w", err)
}
numQueued += len(batch)
numExpires += len(batch)

for _, exp := range batch {
// mark as queued
Expand All @@ -102,105 +105,42 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
}
}

return map[string]any{"dupes": numDupes, "queued": numQueued}, nil
for orgID, expirations := range callsByOrg {
for batch := range slices.Chunk(expirations, c.bulkBatchSize) {
hups := make([]*ivr.Hangup, len(batch))
for i, exp := range batch {
hups[i] = &ivr.Hangup{SessionID: exp.SessionID, CallID: exp.CallID}
}

if err := tasks.Queue(rc, tasks.BatchQueue, orgID, &ivr.BulkCallHangupTask{Hangups: hups}, true); err != nil {
return nil, fmt.Errorf("error queuing bulk hangup task to batch queue: %w", err)
}
numHangups += len(batch)

for _, exp := range batch {
// mark as queued
if err = c.marker.Add(rc, taskID(exp)); err != nil {
return nil, fmt.Errorf("error marking hangup task as queued: %w", err)
}
}
}
}

return map[string]any{"dupes": numDupes, "queued_expires": numExpires, "queued_hangups": numHangups}, nil
}

const sqlSelectExpiredWaits = `
SELECT id as session_id, org_id, wait_expires_on, contact_id, modified_on
SELECT id, org_id, contact_id, call_id, wait_expires_on, modified_on
FROM flows_flowsession
WHERE session_type = 'M' AND status = 'W' AND wait_expires_on <= NOW()
WHERE status = 'W' AND wait_expires_on <= NOW()
ORDER BY wait_expires_on ASC
LIMIT 25000`

type ExpiredWait struct {
SessionID models.SessionID `db:"session_id"`
SessionID models.SessionID `db:"id"`
OrgID models.OrgID `db:"org_id"`
WaitExpiresOn time.Time `db:"wait_expires_on"`
ContactID models.ContactID `db:"contact_id"`
CallID models.CallID `db:"call_id"`
WaitExpiresOn time.Time `db:"wait_expires_on"`
ModifiedOn time.Time `db:"modified_on"`
}

type VoiceExpirationsCron struct{}

func (c *VoiceExpirationsCron) Next(last time.Time) time.Time {
return tasks.CronNext(last, time.Minute)
}

func (c *VoiceExpirationsCron) AllInstances() bool {
return false
}

// looks for voice sessions that should be expired and ends them
func (c *VoiceExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[string]any, error) {
log := slog.With("comp", "ivr_cron_expirer")

ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel()

// select voice sessions with expired waits
rows, err := rt.DB.QueryxContext(ctx, sqlSelectExpiredVoiceWaits)
if err != nil {
return nil, fmt.Errorf("error querying voice sessions with expired waits: %w", err)
}
defer rows.Close()

expiredSessions := make([]models.SessionID, 0, 100)
clogs := make([]*models.ChannelLog, 0, 100)

for rows.Next() {
expiredWait := &ExpiredVoiceWait{}
err := rows.StructScan(expiredWait)
if err != nil {
return nil, fmt.Errorf("error scanning expired wait: %w", err)
}

// add the session to those we need to expire
expiredSessions = append(expiredSessions, expiredWait.SessionID)

// load our call
conn, err := models.GetCallByID(ctx, rt.DB, expiredWait.OrgID, expiredWait.CallID)
if err != nil {
log.Error("unable to load call", "error", err, "call_id", expiredWait.CallID)
continue
}

// hang up our call
clog, err := ivr.HangupCall(ctx, rt, conn)
if err != nil {
// log error but carry on with other calls
log.Error("error hanging up call", "error", err, "call_id", conn.ID())
}

if clog != nil {
clogs = append(clogs, clog)
}
}

// now expire our runs and sessions
if len(expiredSessions) > 0 {
err := models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired)
if err != nil {
log.Error("error expiring sessions for expired calls", "error", err)
}
}

if err := models.InsertChannelLogs(ctx, rt, clogs); err != nil {
return nil, fmt.Errorf("error inserting channel logs: %w", err)
}

return map[string]any{"expired": len(expiredSessions)}, nil
}

const sqlSelectExpiredVoiceWaits = `
SELECT id, org_id, call_id, wait_expires_on
FROM flows_flowsession
WHERE session_type = 'V' AND status = 'W' AND wait_expires_on <= NOW()
ORDER BY wait_expires_on ASC
LIMIT 100`

type ExpiredVoiceWait struct {
SessionID models.SessionID `db:"id"`
OrgID models.OrgID `db:"org_id"`
CallID models.CallID `db:"call_id"`
ExpiresOn time.Time `db:"wait_expires_on"`
}
85 changes: 26 additions & 59 deletions core/tasks/expirations/cron_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"
"time"

"github.com/nyaruka/gocommon/dbutil/assertdb"
"github.com/nyaruka/gocommon/i18n"
"github.com/nyaruka/gocommon/jsonx"
"github.com/nyaruka/gocommon/uuids"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/nyaruka/mailroom/core/tasks"
"github.com/nyaruka/mailroom/core/tasks/contacts"
"github.com/nyaruka/mailroom/core/tasks/expirations"
"github.com/nyaruka/mailroom/core/tasks/ivr"
"github.com/nyaruka/mailroom/testsuite"
"github.com/nyaruka/mailroom/testsuite/testdata"
"github.com/stretchr/testify/assert"
Expand All @@ -38,7 +38,7 @@ func TestExpirations(t *testing.T) {

// create an IVR session for Alexandria
call := testdata.InsertCall(rt, testdata.Org1, testdata.TwilioChannel, testdata.Alexandria)
testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Alexandria, models.FlowTypeVoice, testdata.IVRFlow, call, time.Now(), time.Now(), false, nil)
ivrID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Alexandria, models.FlowTypeVoice, testdata.IVRFlow, call, time.Now(), time.Now(), false, nil)

// for other org create 6 waiting sessions that will expire
for i := range 6 {
Expand All @@ -50,78 +50,45 @@ func TestExpirations(t *testing.T) {
cron := expirations.NewExpirationsCron(5)
res, err := cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"dupes": 0, "queued": 7}, res)
assert.Equal(t, map[string]any{"dupes": 0, "queued_expires": 7, "queued_hangups": 1}, res)

// should have created one throttled task for org 1
// should have created one throttled expire task for org 1
task1, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org1.ID), task1.OwnerID)
assert.Equal(t, "bulk_session_expire", task1.Type)

decoded := &contacts.BulkSessionExpireTask{}
jsonx.MustUnmarshal(task1.Task, decoded)
assert.Len(t, decoded.Expirations, 1)
assert.Equal(t, s2ID, decoded.Expirations[0].SessionID)
decoded1 := &contacts.BulkSessionExpireTask{}
jsonx.MustUnmarshal(task1.Task, decoded1)
assert.Len(t, decoded1.Expirations, 1)
assert.Equal(t, s2ID, decoded1.Expirations[0].SessionID)

// and two for org 2
task2, err := tasks.ThrottledQueue.Pop(rc)
// and one batch hangup task for the IVR session
task2, err := tasks.BatchQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task2.OwnerID)
assert.Equal(t, "bulk_session_expire", task2.Type)
assert.Equal(t, int(testdata.Org1.ID), task2.OwnerID)
assert.Equal(t, "bulk_call_hangup", task2.Type)

decoded2 := &ivr.BulkCallHangupTask{}
jsonx.MustUnmarshal(task2.Task, decoded2)
assert.Len(t, decoded2.Hangups, 1)
assert.Equal(t, ivrID, decoded2.Hangups[0].SessionID)

// and two expire tasks for org 2
task3, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task3.OwnerID)
assert.Equal(t, "bulk_session_expire", task2.Type)
assert.Equal(t, "bulk_session_expire", task3.Type)
task4, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task4.OwnerID)
assert.Equal(t, "bulk_session_expire", task4.Type)

// no other
task, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Nil(t, task)
assert.Equal(t, map[string]int{}, testsuite.FlushTasks(t, rt, "batch", "throttled"))

// if task runs again, these tasks won't be re-queued
res, err = cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"dupes": 7, "queued": 0}, res)
}

func TestExpireVoiceSessions(t *testing.T) {
ctx, rt := testsuite.Runtime()
rc := rt.RP.Get()
defer rc.Close()

defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis)

// create voice session for Cathy
conn1ID := testdata.InsertCall(rt, testdata.Org1, testdata.TwilioChannel, testdata.Cathy)
s1ID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Cathy, models.FlowTypeVoice, testdata.IVRFlow, conn1ID, time.Now(), time.Now(), false, nil)
r1ID := testdata.InsertFlowRun(rt, testdata.Org1, s1ID, testdata.Cathy, testdata.Favorites, models.RunStatusWaiting, "")

// create voice session for Bob with expiration in future
conn2ID := testdata.InsertCall(rt, testdata.Org1, testdata.TwilioChannel, testdata.Bob)
s2ID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Bob, models.FlowTypeMessaging, testdata.IVRFlow, conn2ID, time.Now(), time.Now().Add(time.Hour), false, nil)
r2ID := testdata.InsertFlowRun(rt, testdata.Org1, s2ID, testdata.Bob, testdata.IVRFlow, models.RunStatusWaiting, "")

// create a messaging session for Alexandria
s3ID := testdata.InsertWaitingSession(rt, testdata.Org1, testdata.Alexandria, models.FlowTypeMessaging, testdata.Favorites, models.NilCallID, time.Now(), time.Now(), false, nil)
r3ID := testdata.InsertFlowRun(rt, testdata.Org1, s3ID, testdata.Alexandria, testdata.Favorites, models.RunStatusWaiting, "")

time.Sleep(5 * time.Millisecond)

// expire our sessions...
cron := &expirations.VoiceExpirationsCron{}
res, err := cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"expired": 1}, res)

// Cathy's session should be expired along with its runs
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s1ID).Columns(map[string]any{"status": "X"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r1ID).Columns(map[string]any{"status": "X"})

// Bob's session and run should be unchanged
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s2ID).Columns(map[string]any{"status": "W"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r2ID).Columns(map[string]any{"status": "W"})

// Alexandria's session and run should be unchanged because message expirations are handled separately
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s3ID).Columns(map[string]any{"status": "W"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r3ID).Columns(map[string]any{"status": "W"})
assert.Equal(t, map[string]any{"dupes": 8, "queued_expires": 0, "queued_hangups": 0}, res)
}
Loading
Loading