From 99f3629787b90c2e7241742c2b8c138d392780d4 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 31 Dec 2024 17:23:03 +0100 Subject: [PATCH] feat(transport): add mTLS for Kafka (#367) Adds Kafka transport CLI options to use a specific CA and client certificates. --- transport/kafka/kafka.go | 58 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/transport/kafka/kafka.go b/transport/kafka/kafka.go index a96af035..0121272e 100644 --- a/transport/kafka/kafka.go +++ b/transport/kafka/kafka.go @@ -3,9 +3,11 @@ package kafka import ( "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "flag" "fmt" + "io" "net" "os" "strconv" @@ -18,7 +20,12 @@ import ( ) type KafkaDriver struct { - kafkaTLS bool + kafkaTLS bool + kafkaClientCert string + kafkaClientKey string + kafkaServerCA string + kafkaTlsInsecure bool + kafkaSASL string kafkaTopic string kafkaSrv string @@ -83,6 +90,12 @@ var ( func (d *KafkaDriver) Prepare() error { flag.BoolVar(&d.kafkaTLS, "transport.kafka.tls", false, "Use TLS to connect to Kafka") + + flag.StringVar(&d.kafkaClientCert, "transport.kafka.tls.client", "", "Kafka client certificate") + flag.StringVar(&d.kafkaClientKey, "transport.kafka.tls.key", "", "Kafka client key") + flag.StringVar(&d.kafkaServerCA, "transport.kafka.tls.ca", "", "Kafka certificate authority") + flag.BoolVar(&d.kafkaTlsInsecure, "transport.kafka.tls.insecure", false, "Skips TLS verification") + flag.StringVar(&d.kafkaSASL, "transport.kafka.sasl", "none", fmt.Sprintf( "Use SASL to connect to Kafka, available settings: %s (TLS is recommended and the environment variables KAFKA_SASL_USER and KAFKA_SASL_PASS need to be set)", @@ -151,6 +164,49 @@ func (d *KafkaDriver) Init() error { RootCAs: rootCAs, MinVersion: tls.VersionTLS12, } + + kafkaConfig.Net.TLS.Config.InsecureSkipVerify = d.kafkaTlsInsecure + + if d.kafkaServerCA != "" { + serverCaFile, err := os.Open(d.kafkaServerCA) + if err != nil { + return fmt.Errorf("error initializing server CA: %v", err) + } + + serverCaBytes, err := io.ReadAll(serverCaFile) + serverCaFile.Close() + if err != nil { + return fmt.Errorf("error reading server CA: %v", err) + } + + block, _ := pem.Decode(serverCaBytes) + + serverCa, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("error parsing server CA: %v", err) + } + + certPool := x509.NewCertPool() + certPool.AddCert(serverCa) + + kafkaConfig.Net.TLS.Config.RootCAs = certPool + } + + if d.kafkaClientCert != "" && d.kafkaClientKey != "" { + _, err := tls.LoadX509KeyPair(d.kafkaClientCert, d.kafkaClientKey) + if err != nil { + return fmt.Errorf("error initializing mTLS: %v", err) + } + + kafkaConfig.Net.TLS.Config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(d.kafkaClientCert, d.kafkaClientKey) + if err != nil { + return nil, err + } + return &cert, nil + } + } + } if d.kafkaHashing {