forked from marselester/gopher-celery
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcelery.go
224 lines (192 loc) · 6.3 KB
/
celery.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// Package celery helps to work with Celery (place tasks in queues and execute them).
package celery
import (
"context"
"fmt"
"runtime/debug"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"github.com/marselester/gopher-celery/protocol"
"github.com/marselester/gopher-celery/redis"
)
type (
	// TaskF is the signature of a Celery task implemented by the client.
	// The returned error has no influence on task processing; the app
	// only logs it.
	TaskF func(ctx context.Context, p *TaskParam) error

	// Middleware wraps a TaskF to add behavior around it, e.g., to collect
	// task metrics. Middlewares can be chained since each returns a TaskF.
	Middleware func(next TaskF) TaskF
)
// Broker moves raw task messages between the app and a message store.
// For example, a Redis implementation knows how to read a message from a
// given Redis queue. Message encodings differ depending on the Celery
// protocol version in use.
type Broker interface {
	// Observe sets the queues that Receive takes messages from.
	// It must not be called concurrently.
	Observe(queues []string)
	// Receive returns one raw message taken from one of the observed queues.
	// It waits for a message to become available. App.Run treats a nil
	// message returned with a nil error as "no messages yet" and retries.
	// It must not be called concurrently.
	Receive() ([]byte, error)
	// Send places msg onto the named queue.
	// It is safe to call concurrently.
	Send(msg []byte, queue string) error
}
// NewApp builds a Celery app configured by the given options.
//
// Unless an option says otherwise:
//   - the broker is a Redis broker from redis.NewBroker (assumed to run on localhost),
//   - produced tasks are serialized as JSON using Celery protocol v2,
//   - no more than DefaultMaxWorkers tasks run at the same time,
//   - logging is disabled (a no-op logger).
func NewApp(options ...Option) *App {
	app := &App{
		task:      make(map[string]TaskF),
		taskQueue: make(map[string]string),
	}
	app.conf = Config{
		logger:     log.NewNopLogger(),
		registry:   protocol.NewSerializerRegistry(),
		mime:       protocol.MimeJSON,
		protocol:   protocol.V2,
		maxWorkers: DefaultMaxWorkers,
	}
	for _, apply := range options {
		apply(&app.conf)
	}

	// Fall back to Redis only when no option supplied a broker.
	if app.conf.broker == nil {
		app.conf.broker = redis.NewBroker()
	}
	// The semaphore's capacity is the worker limit decided above.
	app.sem = make(chan struct{}, app.conf.maxWorkers)

	return app
}
// App is a Celery app that produces and consumes tasks asynchronously.
type App struct {
	// conf holds the app settings assembled by NewApp.
	conf Config
	// sem caps the number of concurrently running workers: Run sends a
	// token before starting a worker and the worker takes it back when done.
	sem chan struct{}
	// task maps a Celery task's Python path to its Go implementation,
	// e.g., "myproject.apps.myapp.tasks.mytask" -> TaskF.
	task map[string]TaskF
	// taskQueue maps a Celery task's Python path to its queue name,
	// e.g., "myproject.apps.myapp.tasks.mytask" -> "important".
	taskQueue map[string]string
}
// Register binds task to the Celery task name path and records queue as the
// queue that task belongs to. For instance, after registering
// "myproject.apps.myapp.tasks.mytask" with the "important" queue, Run
// observes "important" and runs task for messages carrying that name.
// Registering the same path again replaces the earlier task and queue.
//
// Note, the method is not concurrency safe.
// The tasks mustn't be registered after the app starts processing tasks.
func (a *App) Register(path, queue string, task TaskF) {
	a.taskQueue[path] = queue
	a.task[path] = task
}
// Delay publishes a task message for the task with the given Python path
// onto queue. The message gets a fresh UUID as its ID and carries args as
// positional arguments. It is encoded with the app's configured serializer
// and protocol version, then handed to the broker.
func (a *App) Delay(path, queue string, args ...interface{}) error {
	msg := &protocol.Task{
		ID:   uuid.NewString(),
		Name: path,
		Args: args,
	}

	raw, err := a.conf.registry.Encode(queue, a.conf.mime, a.conf.protocol, msg)
	if err != nil {
		return fmt.Errorf("failed to encode task message: %w", err)
	}

	if err := a.conf.broker.Send(raw, queue); err != nil {
		return fmt.Errorf("failed to send task message to broker: %w", err)
	}
	return nil
}
// Run launches the workers that process the tasks received from the broker.
// The call blocks until ctx is cancelled or the broker fails to receive a
// message; in the latter case the broker error is returned. Before returning
// it waits for the tasks that have already started to finish.
// The caller mustn't register any new tasks at this point.
func (a *App) Run(ctx context.Context) error {
	g, ctx := errgroup.WithContext(ctx)

	// Several tasks may share one queue; hand each queue to the broker only
	// once so it isn't polled (or weighted) multiple times per receive.
	seen := make(map[string]struct{}, len(a.taskQueue))
	qq := make([]string, 0, len(a.taskQueue))
	for _, q := range a.taskQueue {
		if _, dup := seen[q]; dup {
			continue
		}
		seen[q] = struct{}{}
		qq = append(qq, q)
	}
	a.conf.broker.Observe(qq)
	level.Debug(a.conf.logger).Log("msg", "observing queues", "queues", qq)

	msgs := make(chan *protocol.Task, 1)

	// Receiver: fetches and decodes task messages.
	g.Go(func() error {
		defer close(msgs)
		// One goroutine fetching and decoding tasks from queues
		// shouldn't be a bottleneck since the worker goroutines
		// usually take seconds/minutes to complete.
		for {
			select {
			case <-ctx.Done():
				return nil
			default:
			}

			rawMsg, err := a.conf.broker.Receive()
			if err != nil {
				return fmt.Errorf("failed to receive a raw task message: %w", err)
			}
			// No messages in the broker so far.
			if rawMsg == nil {
				continue
			}

			m, err := a.conf.registry.Decode(rawMsg)
			if err != nil {
				level.Error(a.conf.logger).Log("msg", "failed to decode task message", "rawmsg", rawMsg, "err", err)
				continue
			}

			// The dispatcher stops reading msgs once ctx is cancelled, so an
			// unconditional send could block forever on a full buffer and
			// keep g.Wait from ever returning.
			select {
			case msgs <- m:
			case <-ctx.Done():
				return nil
			}
		}
	})

	// Dispatcher: starts a worker per task. It runs inside the group so the
	// group's counter stays positive while it calls g.Go; otherwise g.Wait
	// could return (or race with g.Go) while workers are still being started.
	g.Go(func() error {
		for m := range msgs {
			level.Debug(a.conf.logger).Log("msg", "task received", "name", m.Name)
			if a.task[m.Name] == nil {
				level.Debug(a.conf.logger).Log("msg", "unregistered task", "name", m.Name)
				continue
			}
			if m.IsExpired() {
				level.Debug(a.conf.logger).Log("msg", "task message expired", "name", m.Name)
				continue
			}

			select {
			// Acquire a semaphore by sending a token.
			case a.sem <- struct{}{}:
			// Stop processing tasks.
			case <-ctx.Done():
				return nil
			}

			m := m
			g.Go(func() error {
				// Release a semaphore by discarding a token.
				defer func() { <-a.sem }()

				if err := a.executeTask(ctx, m); err != nil {
					level.Error(a.conf.logger).Log("msg", "task failed", "taskmsg", m, "err", err)
				} else {
					level.Debug(a.conf.logger).Log("msg", "task succeeded", "name", m.Name)
				}
				return nil
			})
		}
		return nil
	})

	return g.Wait()
}
// contextKey is the unexported type of this package's context keys; using a
// distinct type keeps them from colliding with keys from other packages.
type contextKey int

// ContextKeyTaskName is the context key under which executeTask stores the
// name of the task being executed.
const ContextKeyTaskName = contextKey(0)
// executeTask runs the registered task named in m, wrapped in the configured
// middleware chain when one is set. The task receives a context carrying
// m.Name under ContextKeyTaskName and a TaskParam built from m.Args and
// m.Kwargs. A panic inside the task is recovered and reported as an error
// that includes the panic value and the stack trace.
func (a *App) executeTask(ctx context.Context, m *protocol.Task) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = fmt.Errorf("unexpected task error: %v: %s", r, debug.Stack())
	}()

	run := a.task[m.Name]
	if chain := a.conf.chain; chain != nil {
		run = chain(run)
	}

	taskCtx := context.WithValue(ctx, ContextKeyTaskName, m.Name)
	return run(taskCtx, NewTaskParam(m.Args, m.Kwargs))
}