Skip to content

Commit

Permalink
Feat: Add optional retry_for argument for task queue (#731)
Browse files Browse the repository at this point in the history
The `retry_for` argument can be used to provide a list of exceptions. If
any of those exceptions are encountered, the task will be retried.

It can be used like this: 
```python
from beta9 import task_queue

@task_queue(retries=2, retry_for=[AttributeError])
def hello():
    import random
    import time
    time.sleep(2)

    raise AttributeError("This is a test exception")
```

The resulting output will look something like this: 
```bash
Starting task worker[0]
Worker[0] ready
Running task <8b96b5c9-3614-4121-aaec-f0b08cc4e108>
Retrying task <8b96b5c9-3614-4121-aaec-f0b08cc4e108> after AttributeError exception
Running task <8b96b5c9-3614-4121-aaec-f0b08cc4e108>
Retrying task <8b96b5c9-3614-4121-aaec-f0b08cc4e108> after AttributeError exception
Running task <8b96b5c9-3614-4121-aaec-f0b08cc4e108>
Retry limit of 2 exceeded for task <8b96b5c9-3614-4121-aaec-f0b08cc4e108>
```
  • Loading branch information
dleviminzi authored Nov 27, 2024
1 parent fcd4e22 commit 03539fa
Show file tree
Hide file tree
Showing 15 changed files with 263 additions and 148 deletions.
4 changes: 4 additions & 0 deletions pkg/abstractions/endpoint/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ func (t *EndpointTask) Metadata() types.TaskMetadata {
TaskId: t.msg.TaskId,
}
}

func (t *EndpointTask) Message() *types.TaskMessage {
return t.msg
}
4 changes: 4 additions & 0 deletions pkg/abstractions/experimental/bot/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,7 @@ func (t *BotTask) Metadata() types.TaskMetadata {
WorkspaceName: t.msg.WorkspaceName,
}
}

func (t *BotTask) Message() *types.TaskMessage {
return t.msg
}
4 changes: 4 additions & 0 deletions pkg/abstractions/function/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,7 @@ func (t *FunctionTask) Metadata() types.TaskMetadata {
ContainerId: t.containerId,
}
}

func (t *FunctionTask) Message() *types.TaskMessage {
return t.msg
}
56 changes: 56 additions & 0 deletions pkg/abstractions/taskqueue/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package taskqueue

import "fmt"

// Redis keys
var (
taskQueueList string = "taskqueue:%s:%s"
taskQueueServeLock string = "taskqueue:%s:%s:serve_lock"
taskQueueInstanceLock string = "taskqueue:%s:%s:instance_lock"
taskQueueTaskDuration string = "taskqueue:%s:%s:task_duration"
taskQueueAverageTaskDuration string = "taskqueue:%s:%s:avg_task_duration"
taskQueueTaskHeartbeat string = "taskqueue:%s:%s:task:heartbeat:%s"
taskQueueProcessingLock string = "taskqueue:%s:%s:processing_lock:%s"
taskQueueKeepWarmLock string = "taskqueue:%s:%s:keep_warm_lock:%s"
taskQueueTaskRunningLock string = "taskqueue:%s:%s:task_running:%s:%s"
)

var Keys = &keys{}

type keys struct{}

func (k *keys) taskQueueInstanceLock(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueInstanceLock, workspaceName, stubId)
}

func (k *keys) taskQueueServeLock(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueServeLock, workspaceName, stubId)
}

func (k *keys) taskQueueList(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueList, workspaceName, stubId)
}

func (k *keys) taskQueueTaskHeartbeat(workspaceName, stubId, taskId string) string {
return fmt.Sprintf(taskQueueTaskHeartbeat, workspaceName, stubId, taskId)
}

func (k *keys) taskQueueTaskDuration(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueTaskDuration, workspaceName, stubId)
}

func (k *keys) taskQueueAverageTaskDuration(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueAverageTaskDuration, workspaceName, stubId)
}

func (k *keys) taskQueueTaskRunningLock(workspaceName, stubId, containerId, taskId string) string {
return fmt.Sprintf(taskQueueTaskRunningLock, workspaceName, stubId, containerId, taskId)
}

func (k *keys) taskQueueProcessingLock(workspaceName, stubId, containerId string) string {
return fmt.Sprintf(taskQueueProcessingLock, workspaceName, stubId, containerId)
}

func (k *keys) taskQueueKeepWarmLock(workspaceName, stubId, containerId string) string {
return fmt.Sprintf(taskQueueKeepWarmLock, workspaceName, stubId, containerId)
}
4 changes: 4 additions & 0 deletions pkg/abstractions/taskqueue/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ func (t *TaskQueueTask) Metadata() types.TaskMetadata {
WorkspaceName: t.msg.WorkspaceName,
}
}

func (t *TaskQueueTask) Message() *types.TaskMessage {
return t.msg
}
73 changes: 25 additions & 48 deletions pkg/abstractions/taskqueue/taskqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ func (tq *RedisTaskQueue) TaskQueueComplete(ctx context.Context, in *pb.TaskQueu
task.EndedAt = sql.NullTime{Time: time.Now(), Valid: true}
task.Status = types.TaskStatus(in.TaskStatus)

if task.Status == types.TaskStatusRetry {
return tq.retryTask(ctx, authInfo, in), nil
}

err = tq.taskDispatcher.Complete(ctx, authInfo.Workspace.Name, in.StubId, in.TaskId)
if err != nil {
return &pb.TaskQueueCompleteResponse{
Expand Down Expand Up @@ -603,55 +607,28 @@ func (tq *RedisTaskQueue) getOrCreateQueueInstance(stubId string, options ...fun
return instance, nil
}

// Redis keys
var (
taskQueueList string = "taskqueue:%s:%s"
taskQueueServeLock string = "taskqueue:%s:%s:serve_lock"
taskQueueInstanceLock string = "taskqueue:%s:%s:instance_lock"
taskQueueTaskDuration string = "taskqueue:%s:%s:task_duration"
taskQueueAverageTaskDuration string = "taskqueue:%s:%s:avg_task_duration"
taskQueueTaskHeartbeat string = "taskqueue:%s:%s:task:heartbeat:%s"
taskQueueProcessingLock string = "taskqueue:%s:%s:processing_lock:%s"
taskQueueKeepWarmLock string = "taskqueue:%s:%s:keep_warm_lock:%s"
taskQueueTaskRunningLock string = "taskqueue:%s:%s:task_running:%s:%s"
)

var Keys = &keys{}

type keys struct{}

func (k *keys) taskQueueInstanceLock(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueInstanceLock, workspaceName, stubId)
}

func (k *keys) taskQueueServeLock(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueServeLock, workspaceName, stubId)
}

func (k *keys) taskQueueList(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueList, workspaceName, stubId)
}

func (k *keys) taskQueueTaskHeartbeat(workspaceName, stubId, taskId string) string {
return fmt.Sprintf(taskQueueTaskHeartbeat, workspaceName, stubId, taskId)
}

func (k *keys) taskQueueTaskDuration(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueTaskDuration, workspaceName, stubId)
}

func (k *keys) taskQueueAverageTaskDuration(workspaceName, stubId string) string {
return fmt.Sprintf(taskQueueAverageTaskDuration, workspaceName, stubId)
}
func (tq *RedisTaskQueue) retryTask(ctx context.Context, authInfo *auth.AuthInfo, in *pb.TaskQueueCompleteRequest) *pb.TaskQueueCompleteResponse {
task, err := tq.taskDispatcher.Retrieve(ctx, authInfo.Workspace.Name, in.StubId, in.TaskId)
if err != nil {
return &pb.TaskQueueCompleteResponse{
Ok: false,
}
}

func (k *keys) taskQueueTaskRunningLock(workspaceName, stubId, containerId, taskId string) string {
return fmt.Sprintf(taskQueueTaskRunningLock, workspaceName, stubId, containerId, taskId)
}
msg := ""
if task.Message().Retries >= task.Message().Policy.MaxRetries {
msg = fmt.Sprintf("Exceeded retry limit of %d for task <%s>", task.Message().Policy.MaxRetries, task.Message().TaskId)
}

func (k *keys) taskQueueProcessingLock(workspaceName, stubId, containerId string) string {
return fmt.Sprintf(taskQueueProcessingLock, workspaceName, stubId, containerId)
}
err = tq.taskDispatcher.RetryTask(ctx, task)
if err != nil {
return &pb.TaskQueueCompleteResponse{
Ok: false,
}
}

func (k *keys) taskQueueKeepWarmLock(workspaceName, stubId, containerId string) string {
return fmt.Sprintf(taskQueueKeepWarmLock, workspaceName, stubId, containerId)
return &pb.TaskQueueCompleteResponse{
Ok: true,
Message: msg,
}
}
5 changes: 4 additions & 1 deletion pkg/abstractions/taskqueue/taskqueue.proto
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ message TaskQueueCompleteRequest {
float keep_warm_seconds = 7;
}

message TaskQueueCompleteResponse { bool ok = 1; }
message TaskQueueCompleteResponse {
bool ok = 1;
string message = 2;
}

message TaskQueueMonitorRequest {
string task_id = 1;
Expand Down
6 changes: 4 additions & 2 deletions pkg/task/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,17 @@ func (d *Dispatcher) monitor(ctx context.Context) {
}

if !heartbeat {
d.retryTask(ctx, task, taskMessage)
d.RetryTask(ctx, task)
continue
}
}
}
}
}

func (d *Dispatcher) retryTask(ctx context.Context, task types.TaskInterface, taskMessage *types.TaskMessage) error {
func (d *Dispatcher) RetryTask(ctx context.Context, task types.TaskInterface) error {
taskMessage := task.Message()

err := d.taskRepo.SetTaskRetryLock(ctx, taskMessage.WorkspaceName, taskMessage.StubId, taskMessage.TaskId)
if err != nil {
return err
Expand Down
1 change: 1 addition & 0 deletions pkg/types/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type TaskInterface interface {
Retry(ctx context.Context) error
HeartBeat(ctx context.Context) (bool, error)
Metadata() TaskMetadata
Message() *TaskMessage
}

type TaskExecutor string
Expand Down
Loading

0 comments on commit 03539fa

Please sign in to comment.