diff --git a/README.md b/README.md index 428580e..7b0fb95 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ GLOBAL OPTIONS: --insecure Skips certificate verification. (default: false) --noprompt Disables password prompt. (default: false) --silent Disables terminal print. (default: false) + --persistent Creates a persistent interceptor queue. (default: false) --help, -h show help --version, -v print the version ``` diff --git a/coyote.go b/coyote.go index 9bcf388..01b011c 100644 --- a/coyote.go +++ b/coyote.go @@ -4,12 +4,14 @@ import ( "context" "crypto/tls" "database/sql" + "errors" "fmt" "log" "net/url" "os" "os/signal" "strings" + "sync/atomic" "github.com/fatih/color" "github.com/google/uuid" @@ -45,6 +47,25 @@ type combination struct { routingKey string } +type rabbitMQConnection struct { + conn *amqp.Connection + channel *amqp.Channel + rabbitUrl *url.URL + insecure bool + queue string + persistent bool + deliverables *listen +} + +type connectionError struct { + msg string + fatal bool +} + +func (e *connectionError) Error() string { + return e.msg +} + func (l *listen) Set(value string) (err error) { for _, comb := range strings.Split(value, ",") { pair := strings.Split(comb, "=") @@ -73,6 +94,133 @@ func (l *listen) String() string { return "" } +func (r *rabbitMQConnection) Connect() *connectionError { + var err error + r.conn, err = amqp.DialTLS(r.rabbitUrl.String(), &tls.Config{InsecureSkipVerify: r.insecure}) + if err != nil { + var e *amqp.Error + switch { + case errors.As(err, &e): + if e.Code == amqp.AccessRefused { + return &connectionError{ + msg: fmt.Sprintf("%s %v", color.RedString("access denied"), err), fatal: true, + } + } else { + return &connectionError{ + msg: fmt.Sprintf("%s %v", color.RedString("failed to connect to RabbitMQ:"), err), + } + } + default: + return &connectionError{ + msg: fmt.Sprintf("%s %v", color.RedString("failed to connect to RabbitMQ:"), err), + } + } + } + r.channel, err = r.conn.Channel() + if err != nil { + return &connectionError{ + msg: fmt.Sprintf("%s %v", color.RedString("failed to open a channel:"), err), + } + } + return nil +} + +func (r *rabbitMQConnection) Consume() (<-chan amqp.Delivery, error) { + q, err := r.channel.QueueDeclare( + r.queue, + false, // is durable + !r.persistent, // is auto delete + !r.persistent, // is exclusive + false, // is no wait + nil, // args + ) + if err != nil { + return nil, fmt.Errorf("%s %w", color.RedString("failed to declare a queue:"), err) + } + + for _, c := range r.deliverables.c { + err = r.channel.ExchangeDeclarePassive( + c.exchange, // exchange name + "topic", // exchange kind + true, // is durable + false, // is auto delete + false, // is internal + false, // is no wait + nil, // args + ) + if err != nil { + return nil, fmt.Errorf("%s %w", color.RedString("failed to connect to exchange:"), err) + } + + err = r.channel.QueueBind( + q.Name, // interceptor queue name + c.routingKey, // routing key to bind + c.exchange, // exchange to listen + false, // is no wait + nil, // args + ) + if err != nil { + return nil, fmt.Errorf("%s %w", color.RedString("failed to bind to queue:"), err) + } else { + log.Printf("👂 Listening from exchange %s with routing key %s", color.YellowString(c.exchange), color.YellowString(c.routingKey)) + } + } + + deliveries, err := r.channel.Consume( + q.Name, // queue name to consume from + "", // consumer tag + true, // is auto ack + false, // is exclusive + false, // is no local + false, // is no wait + nil, // args + ) + if err != nil { + return nil, fmt.Errorf("%s %w", color.RedString("failed to register a consumer:"), err) + } + + return deliveries, nil +} + +func (r *rabbitMQConnection) Close() error { + var err error + + if r.persistent { + if r.channel != nil && !r.channel.IsClosed() { + log.Printf("🗑️ Deleting queue %s", r.queue) + _, err := r.channel.QueueDelete( + r.queue, + false, // IfUnused: delete only if the queue is unused + false, // IfEmpty: delete only if the queue is empty + false, // NoWait: wait for server confirmation + ) + if err != nil { + log.Printf("Failed to delete the queue: %v", err) + } + } else { + log.Printf("Failed to delete the persistent queue %s because RabbitMQ connection is closed. Please delete it manually!", r.queue) + } + } + + if r.conn != nil && !r.conn.IsClosed() { + log.Printf("💔 Terminating AMQP connection") + err = r.conn.Close() + if err != nil { + return err + } + } + + if r.channel != nil && !r.channel.IsClosed() { + log.Printf("💔 Terminating AMQP channel") + err = r.channel.Close() + if err != nil { + return err + } + } + + return nil +} + func main() { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -131,9 +279,13 @@ func main() { Name: "silent", Usage: "Disables terminal print.", }, + &cli.BoolFlag{ + Name: "persistent", + Usage: "Creates a persistent interceptor queue.", + }, }, Action: func(ctx *cli.Context) error { - u, err := url.Parse(ctx.String("url")) + rabbitUrl, err := url.Parse(ctx.String("url")) if err != nil { return fmt.Errorf("%s %w", color.RedString("failed to parse provided url:"), err) } @@ -147,109 +299,26 @@ func main() { if err != nil { return fmt.Errorf("%s %w", color.RedString("failed to provide password:"), err) } - u.User = url.UserPassword(u.User.String(), password) - } - - conn, err := amqp.DialTLS(u.String(), &tls.Config{InsecureSkipVerify: ctx.Bool("insecure")}) - if err != nil { - return fmt.Errorf("%s %w", color.RedString("failed to connect to RabbitMQ:"), err) + rabbitUrl.User = url.UserPassword(rabbitUrl.User.String(), password) } - defer func() { - err := conn.Close() - if err != nil { - log.Fatal(err) - } - log.Printf("💔 Terminating AMQP connection") - }() - ch, err := conn.Channel() - if err != nil { - return fmt.Errorf("%s %w", color.RedString("failed to open a channel:"), err) - } - defer func() { - err := ch.Close() + var db *sql.DB + var insert *sql.Stmt + if ctx.IsSet("store") { + filename := ctx.String("store") + db, err = sql.Open("sqlite", filename+"?_txlock=exclusive&mode=rwc") if err != nil { log.Fatal(err) } - log.Printf("💔 Terminating AMQP channel") - }() - - q, err := ch.QueueDeclare( - fmt.Sprintf("%s.%s", ctx.String("queue"), uuid.NewString()), // queue name - false, // is durable - true, // is auto delete - true, // is exclusive - false, // is no wait - nil, // args - ) - if err != nil { - return fmt.Errorf("%s %w", color.RedString("failed to declare a queue:"), err) - } - - for _, c := range ctx.Generic("exchange").(*listen).c { - err = ch.ExchangeDeclarePassive( - c.exchange, // exchange name - "topic", // exchange kind - true, // is durable - false, // is auto delete - false, // is internal - false, // is no wait - nil, // args - ) - if err != nil { - return fmt.Errorf("%s %w", color.RedString("failed to connect to exchange:"), err) - } - - err = ch.QueueBind( - q.Name, // interceptor queue name - c.routingKey, // routing key to bind - c.exchange, // exchange to listen - false, // is no wait - nil, // args - ) - if err != nil { - return fmt.Errorf("%s %w", color.RedString("failed to bind to queue:"), err) - } else { - log.Printf("👂 Listening from exchange %s with routing key %s", color.YellowString(c.exchange), color.YellowString(c.routingKey)) - } - } - - deliveries, err := ch.Consume( - q.Name, // queue name to consume from - "", // consumer tag - true, // is auto ack - false, // is exclusive - false, // is no local - false, // is no wait - nil, // args - ) - if err != nil { - return fmt.Errorf("%s %w", color.RedString("failed to register a consumer:"), err) - } - - go func() { - var db *sql.DB - var insert *sql.Stmt - if ctx.IsSet("store") { - filename := ctx.String("store") - file, err := os.Create(filename) - if err != nil { - log.Fatal(err) - } - file.Close() - db, err = sql.Open("sqlite", filename+"?_txlock=exclusive") + defer func() { + log.Printf("💔 Closing database connection") + err := db.Close() if err != nil { log.Fatal(err) } - defer func() { - err := db.Close() - if err != nil { - log.Fatal(err) - } - log.Printf("💔 Closing database connection") - }() + }() - create, err := db.Prepare(`CREATE TABLE event + create, err := db.Prepare(`CREATE TABLE IF NOT EXISTS event ( "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, "timestamp" TIMESTAMP DEFAULT (DATETIME(CURRENT_TIMESTAMP, 'localtime')), @@ -260,49 +329,102 @@ func main() { "headers" TEXT, "body" TEXT );`) - if err != nil { - log.Fatal(err) - } - if _, err := create.Exec(); err != nil { - log.Fatal(err) - } - insert, err = db.Prepare(`INSERT INTO event(exchange, routing_key, correlation_id, reply_to, headers, body) + if err != nil { + log.Fatal(err) + } + if _, err := create.Exec(); err != nil { + log.Fatal(err) + } + insert, err = db.Prepare(`INSERT INTO event(exchange, routing_key, correlation_id, reply_to, headers, body) VALUES (?, ?, ?, ?, ?, ?)`) + if err != nil { + log.Fatal(err) + } + } + + rmq := &rabbitMQConnection{ + rabbitUrl: rabbitUrl, + insecure: ctx.Bool("insecure"), + queue: fmt.Sprintf("%s.%s", ctx.String("queue"), uuid.NewString()), + persistent: ctx.Bool("persistent"), + deliverables: ctx.Generic("exchange").(*listen), + } + defer func() { + log.Printf("💔 Closing RabbitMQ connection") + err := rmq.Close() + if err != nil { + log.Printf("Error closing RabbitMQ connection: %v", err) + } + }() + + var consumedCount int32 = 0 + + go func() { + for { + if err := rmq.Connect(); err != nil { + log.Printf("Error connecting to RabbitMQ: %v", err) + if err.fatal { + cancel() + return + } + continue + } + deliveries, err := rmq.Consume() if err != nil { - log.Fatal(err) + log.Printf("Error starting consumer: %v", err) + closeErr := rmq.conn.Close() + log.Printf("Error closing connection: %v", closeErr) + continue } - } - count := 0 - for d := range deliveries { - if insert != nil { - if _, err := insert.Exec(d.Exchange, d.RoutingKey, d.CorrelationId, d.ReplyTo, fmt.Sprint(d.Headers), string(d.Body)); err != nil { - log.Fatal(err) + + log.Printf("⏳ Waiting for messages. To exit press %s", color.YellowString("CTRL+C")) + + for d := range deliveries { + if insert != nil { + if _, err := insert.Exec(d.Exchange, d.RoutingKey, d.CorrelationId, d.ReplyTo, fmt.Sprint(d.Headers), string(d.Body)); err != nil { + log.Fatal(err) + } + } + if !ctx.Bool("silent") { + log.Printf("📧 %s\n%s%s\n%s%s\n%s%s\n%s%s\n%s%s\n%s%s", + color.YellowString("Received a message"), + color.GreenString("# Exchange : "), + d.Exchange, + color.GreenString("# Routing-key : "), + d.RoutingKey, + color.GreenString("# Correlation-id : "), + d.CorrelationId, + color.GreenString("# Reply-to : "), + d.ReplyTo, + color.GreenString("# Headers : "), + d.Headers, + color.GreenString("# Body : "), + d.Body) + } else { + atomic.AddInt32(&consumedCount, 1) + fmt.Printf("\033[1A\033[K") + log.Printf("💾 Consumed %s messages. To exit press %s", color.GreenString("%d", consumedCount), color.YellowString("CTRL+C")) } } - if !ctx.Bool("silent") { - log.Printf("📧 %s\n%s%s\n%s%s\n%s%s\n%s%s\n%s%s\n%s%s", - color.YellowString("Received a message"), - color.GreenString("# Exchange : "), - d.Exchange, - color.GreenString("# Routing-key : "), - d.RoutingKey, - color.GreenString("# Correlation-id : "), - d.CorrelationId, - color.GreenString("# Reply-to : "), - d.ReplyTo, - color.GreenString("# Headers : "), - d.Headers, - color.GreenString("# Body : "), - d.Body) - } else { - count++ - fmt.Printf("\033[1A\033[K") - log.Printf("💾 Consumed %s messages. To exit press %s", color.GreenString("%d", count), color.YellowString("CTRL+C")) + + select { + case <-rmq.conn.NotifyClose(make(chan *amqp.Error)): + if ctx.Err() != nil { + return + } + log.Printf("💥 Connection was closed enexpectedly, reconnecting ...") + if insert != nil { + if _, err := insert.Exec("", "", "", "", "", "CONNECTION_INTERRUPTED"); err != nil { + log.Fatal(err) + } + } + continue + case <-ctx.Done(): + return } } }() - log.Printf("⏳ Waiting for messages. To exit press %s", color.YellowString("CTRL+C")) <-ctx.Done() return nil }, diff --git a/coyote_test.go b/coyote_test.go index 9345afd..e4d6a30 100644 --- a/coyote_test.go +++ b/coyote_test.go @@ -148,6 +148,7 @@ GLOBAL OPTIONS: --insecure Skips certificate verification. (default: false) --noprompt Disables password prompt. (default: false) --silent Disables terminal print. (default: false) + --persistent Creates a persistent interceptor queue. (default: false) --help, -h show help --version, -v print the version` )