forked from qustavo/sqlhooks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
compose.go
76 lines (67 loc) · 1.78 KB
/
compose.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package sqlhooks
import (
"context"
"fmt"
)
// Compose allows for composing multiple Hooks into one.
// It runs every callback on every hook in argument order,
// even if previous hooks return an error. OnError is only forwarded to
// hooks that also implement OnErrorer.
// If multiple hooks return errors, the error return value will be
// MultipleErrors, which allows for introspecting the errors if necessary.
//
// The hooks are copied, so modifying the caller's slice after Compose returns
// (for example when called as Compose(hs...)) does not affect the result.
func Compose(hooks ...Hooks) Hooks {
	// When called as Compose(hs...), hooks aliases the caller's backing array;
	// copy it so later writes to hs cannot change the composed hook.
	return append(composed(nil), hooks...)
}

// composed is the Hooks value returned by Compose. It forwards each callback
// to the wrapped hooks in slice order.
type composed []Hooks
// Before calls Before on every composed hook, in slice order.
//
// Each hook receives the context returned by the previous hook; a nil context
// return is ignored and the current context is kept. A non-nil context is
// adopted even when that hook also returns an error. Every hook is called
// regardless of earlier failures, and the collected errors are combined by
// wrapErrors: nil when none failed, the error itself when exactly one failed,
// or MultipleErrors otherwise.
func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	var errs []error
	for _, hook := range c {
		// Use a distinct name so the receiver c is not shadowed.
		next, err := hook.Before(ctx, query, args...)
		if err != nil {
			errs = append(errs, err)
		}
		if next != nil {
			ctx = next
		}
	}
	return ctx, wrapErrors(nil, errs)
}
// After calls After on every composed hook, in slice order.
//
// Each hook receives the context returned by the previous hook; a nil context
// return is ignored and the current context is kept. A non-nil context is
// adopted even when that hook also returns an error. Every hook is called
// regardless of earlier failures, and the collected errors are combined by
// wrapErrors: nil when none failed, the error itself when exactly one failed,
// or MultipleErrors otherwise.
func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	var errs []error
	for _, hook := range c {
		// Use a distinct name so the receiver c is not shadowed; err is
		// declared here directly (no separate, redundant var declaration).
		next, err := hook.After(ctx, query, args...)
		if err != nil {
			errs = append(errs, err)
		}
		if next != nil {
			ctx = next
		}
	}
	return ctx, wrapErrors(nil, errs)
}
// OnError forwards cause to every composed hook that implements OnErrorer,
// in slice order. Hooks without an OnError method are skipped.
//
// A hook's return value is collected only when it is a new error: nil and
// cause itself are ignored. If nothing new was collected, cause is returned
// unchanged; one new error is returned as-is; several are returned as
// MultipleErrors.
func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error {
	var errs []error
	for _, h := range c {
		handler, ok := h.(OnErrorer)
		if !ok {
			continue
		}
		err := handler.OnError(ctx, cause, query, args...)
		if err == nil || err == cause {
			continue
		}
		errs = append(errs, err)
	}
	return wrapErrors(cause, errs)
}
// wrapErrors collapses the errors gathered from the composed hooks into one
// error value. It returns def when errs is empty, the lone element when there
// is exactly one, and a MultipleErrors wrapping the whole slice otherwise.
func wrapErrors(def error, errs []error) error {
	if len(errs) == 0 {
		return def
	}
	if len(errs) == 1 {
		return errs[0]
	}
	return MultipleErrors(errs)
}
// MultipleErrors is an error that contains multiple errors.
//
// The individual errors can be inspected by ranging over the slice, or with
// errors.Is / errors.As, which consult Unwrap (multi-error unwrapping, Go 1.20+).
type MultipleErrors []error

// Error formats the contained errors after a "multiple errors:" prefix.
func (m MultipleErrors) Error() string {
	// Convert to []error so fmt formats the elements instead of calling
	// m.Error again, which would recurse forever since MultipleErrors is
	// itself an error.
	return fmt.Sprint("multiple errors:", []error(m))
}

// Unwrap returns the contained errors so errors.Is and errors.As can match any
// of them. On Go versions before 1.20 the errors package does not call it.
func (m MultipleErrors) Unwrap() []error {
	return []error(m)
}