From 0f5a13e7464e79893c53ac38abf2cc8c00843d1e Mon Sep 17 00:00:00 2001 From: Rohini Chandra <61837065+r14chandra@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:30:48 +0530 Subject: [PATCH] feat: support different sasl mechanism --- internal/config/config.go | 4 ++ internal/util/kafka.go | 91 ++++++++++++++++++++++++++++----------- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 78e884fd..93dc1d03 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,6 +43,7 @@ type Config struct { KafkaSystemProfileTopic string KafkaUsername string KafkaCAPath string + KafkaSaslMechanism string LogBatchFrequency time.Duration LogFormat flagvar.Enum LogGroup string @@ -96,6 +97,7 @@ var DefaultConfig Config = Config{ KafkaSystemProfileTopic: "platform.inventory.system-profile", KafkaUsername: "", KafkaCAPath: "", + KafkaSaslMechanism: "", LogBatchFrequency: 10 * time.Second, LogFormat: flagvar.Enum{Choices: []string{"json", "text"}, Value: "json"}, LogGroup: "platform-dev", @@ -138,6 +140,7 @@ func init() { DefaultConfig.KafkaUsername = *broker.Sasl.Username DefaultConfig.KafkaPassword = *broker.Sasl.Password + DefaultConfig.KafkaSaslMechanism = *broker.Sasl.SaslMechanism if broker.Cacert != nil { caPath, err := clowder.LoadedConfig.KafkaCa(broker) @@ -200,6 +203,7 @@ func FlagSet(name string, errorHandling flag.ErrorHandling) *flag.FlagSet { fs.StringVar(&DefaultConfig.KafkaSystemProfileTopic, "kafka-system-profile-topic", DefaultConfig.KafkaSystemProfileTopic, "host-inventory system-profile topic name") fs.StringVar(&DefaultConfig.KafkaUsername, "kafka-username", DefaultConfig.KafkaUsername, "managed kafka auth username") fs.StringVar(&DefaultConfig.KafkaCAPath, "kafka-cacert-path", DefaultConfig.KafkaCAPath, "managed kafka cacert path") + fs.StringVar(&DefaultConfig.KafkaSaslMechanism, "kafka-sasl-mechanism", DefaultConfig.KafkaSaslMechanism, "managed kafka sasl mechanism") fs.DurationVar(&DefaultConfig.LogBatchFrequency, "log-batch-frequency", DefaultConfig.LogBatchFrequency, "CloudWatch batch log frequency") fs.Var(&DefaultConfig.LogFormat, "log-format", fmt.Sprintf("structured logging output format (%v)", DefaultConfig.LogFormat.Help())) fs.StringVar(&DefaultConfig.LogGroup, "log-group", DefaultConfig.LogGroup, "CloudWatch log group") diff --git a/internal/util/kafka.go b/internal/util/kafka.go index 1529e005..cfa31341 100644 --- a/internal/util/kafka.go +++ b/internal/util/kafka.go @@ -6,10 +6,13 @@ import ( "crypto/x509" "fmt" "io/ioutil" + "strings" "time" "github.com/segmentio/kafka-go" + "github.com/segmentio/kafka-go/sasl" "github.com/segmentio/kafka-go/sasl/plain" + "github.com/segmentio/kafka-go/sasl/scram" ) var Kafka kafkautil @@ -21,22 +24,12 @@ func (k kafkautil) NewReader(topic string) *kafka.Reader { var dialer *kafka.Dialer if config.DefaultConfig.KafkaUsername != "" && config.DefaultConfig.KafkaPassword != "" { - tlsConfig, err := createTLSConfig(config.DefaultConfig.KafkaCAPath) - - // Log the error, but proceed with default TLS configuration - if err != nil { - fmt.Println(err) - tlsConfig = &tls.Config{} // Providing default empty TLS configuration - } - + saslMechanism, tlsConfig := getSaslAndTLSConfig() dialer = &kafka.Dialer{ - Timeout: 10 * time.Second, - DualStack: true, - SASLMechanism: plain.Mechanism{ - Username: config.DefaultConfig.KafkaUsername, - Password: config.DefaultConfig.KafkaPassword, - }, - TLS: tlsConfig, + Timeout: 10 * time.Second, + DualStack: true, + SASLMechanism: saslMechanism, + TLS: tlsConfig, } } @@ -54,18 +47,11 @@ func (k kafkautil) NewWriter(topic string) *kafka.Writer { var transport *kafka.Transport = kafka.DefaultTransport.(*kafka.Transport) if config.DefaultConfig.KafkaUsername != "" && config.DefaultConfig.KafkaPassword != "" { - tlsConfig, err := createTLSConfig(config.DefaultConfig.KafkaCAPath) - if err != nil { - fmt.Println("Error creating TLS configuration for Kafka:", err) - tlsConfig = &tls.Config{} // Providing default empty TLS configuration - } + saslMechanism, tlsConfig := getSaslAndTLSConfig() transport = &kafka.Transport{ - SASL: plain.Mechanism{ - Username: config.DefaultConfig.KafkaUsername, - Password: config.DefaultConfig.KafkaPassword, - }, - TLS: tlsConfig, + SASL: saslMechanism, + TLS: tlsConfig, } } @@ -77,6 +63,29 @@ func (k kafkautil) NewWriter(topic string) *kafka.Writer { } } +func getSaslAndTLSConfig() (sasl.Mechanism, *tls.Config) { + username := config.DefaultConfig.KafkaUsername + password := config.DefaultConfig.KafkaPassword + saslmechanismName := config.DefaultConfig.KafkaSaslMechanism + + tlsConfig, err := createTLSConfig(config.DefaultConfig.KafkaCAPath) + if err != nil { + fmt.Println("Error creating TLS configuration for Kafka:", err) + tlsConfig = &tls.Config{} // Providing default empty TLS configuration + } + + saslMechanism, err := createSaslMechanism( + saslmechanismName, + username, + password, + ) + if err != nil { + fmt.Println("Error creating SASL Mechanism for Kafka, using plain mechanism:", err) + } + + return saslMechanism, tlsConfig +} + func createTLSConfig(pathToCert string) (*tls.Config, error) { tlsConfig := tls.Config{ @@ -99,6 +108,38 @@ func createTLSConfig(pathToCert string) (*tls.Config, error) { return &tlsConfig, nil } +func createSaslMechanism(saslMechanism string, username string, password string) (sasl.Mechanism, error) { + + switch strings.ToLower(saslMechanism) { + case "plain": + return plain.Mechanism{ + Username: username, + Password: password, + }, nil + + case "scram-sha-512": + mechanism, err := scram.Mechanism(scram.SHA512, username, password) + if err != nil { + return nil, fmt.Errorf("unable to create scram-sha-512 mechanism: %w", err) + } + return mechanism, nil + + case "scram-sha-256": + mechanism, err := scram.Mechanism(scram.SHA256, username, password) + if err != nil { + return nil, fmt.Errorf("unable to create scram-sha-256 mechanism: %w", err) + } + return mechanism, nil + + default: + // create plain mechanism as default + return plain.Mechanism{ + Username: username, + Password: password, + }, fmt.Errorf("unable to configure sasl mechanism (%s)", saslMechanism) + } +} + // GetHeader loops over the message headers, returning the value of key, if // found. func (k kafkautil) GetHeader(message kafka.Message, key string) (string, error) {