Skip to content

Commit

Permalink
feat: support different sasl mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
r14chandra committed Mar 7, 2024
1 parent 0b78ddc commit 0f5a13e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 25 deletions.
4 changes: 4 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Config struct {
KafkaSystemProfileTopic string
KafkaUsername string
KafkaCAPath string
KafkaSaslMechanism string
LogBatchFrequency time.Duration
LogFormat flagvar.Enum
LogGroup string
Expand Down Expand Up @@ -96,6 +97,7 @@ var DefaultConfig Config = Config{
KafkaSystemProfileTopic: "platform.inventory.system-profile",
KafkaUsername: "",
KafkaCAPath: "",
KafkaSaslMechanism: "",
LogBatchFrequency: 10 * time.Second,
LogFormat: flagvar.Enum{Choices: []string{"json", "text"}, Value: "json"},
LogGroup: "platform-dev",
Expand Down Expand Up @@ -138,6 +140,7 @@ func init() {

DefaultConfig.KafkaUsername = *broker.Sasl.Username
DefaultConfig.KafkaPassword = *broker.Sasl.Password
DefaultConfig.KafkaSaslMechanism = *broker.Sasl.SaslMechanism

if broker.Cacert != nil {
caPath, err := clowder.LoadedConfig.KafkaCa(broker)
Expand Down Expand Up @@ -200,6 +203,7 @@ func FlagSet(name string, errorHandling flag.ErrorHandling) *flag.FlagSet {
fs.StringVar(&DefaultConfig.KafkaSystemProfileTopic, "kafka-system-profile-topic", DefaultConfig.KafkaSystemProfileTopic, "host-inventory system-profile topic name")
fs.StringVar(&DefaultConfig.KafkaUsername, "kafka-username", DefaultConfig.KafkaUsername, "managed kafka auth username")
fs.StringVar(&DefaultConfig.KafkaCAPath, "kafka-cacert-path", DefaultConfig.KafkaCAPath, "managed kafka cacert path")
fs.StringVar(&DefaultConfig.KafkaSaslMechanism, "kafka-sasl-mechanism", DefaultConfig.KafkaSaslMechanism, "managed kafka sasl mechanism")
fs.DurationVar(&DefaultConfig.LogBatchFrequency, "log-batch-frequency", DefaultConfig.LogBatchFrequency, "CloudWatch batch log frequency")
fs.Var(&DefaultConfig.LogFormat, "log-format", fmt.Sprintf("structured logging output format (%v)", DefaultConfig.LogFormat.Help()))
fs.StringVar(&DefaultConfig.LogGroup, "log-group", DefaultConfig.LogGroup, "CloudWatch log group")
Expand Down
91 changes: 66 additions & 25 deletions internal/util/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import (
"crypto/x509"
"fmt"
"io/ioutil"
"strings"
"time"

"github.com/segmentio/kafka-go"
"github.com/segmentio/kafka-go/sasl"
"github.com/segmentio/kafka-go/sasl/plain"
"github.com/segmentio/kafka-go/sasl/scram"
)

var Kafka kafkautil
Expand All @@ -21,22 +24,12 @@ func (k kafkautil) NewReader(topic string) *kafka.Reader {
var dialer *kafka.Dialer

if config.DefaultConfig.KafkaUsername != "" && config.DefaultConfig.KafkaPassword != "" {
tlsConfig, err := createTLSConfig(config.DefaultConfig.KafkaCAPath)

// Log the error, but proceed with default TLS configuration
if err != nil {
fmt.Println(err)
tlsConfig = &tls.Config{} // Providing default empty TLS configuration
}

saslMechanism, tlsConfig := getSaslAndTLSConfig()
dialer = &kafka.Dialer{
Timeout: 10 * time.Second,
DualStack: true,
SASLMechanism: plain.Mechanism{
Username: config.DefaultConfig.KafkaUsername,
Password: config.DefaultConfig.KafkaPassword,
},
TLS: tlsConfig,
Timeout: 10 * time.Second,
DualStack: true,
SASLMechanism: saslMechanism,
TLS: tlsConfig,
}
}

Expand All @@ -54,18 +47,11 @@ func (k kafkautil) NewWriter(topic string) *kafka.Writer {
var transport *kafka.Transport = kafka.DefaultTransport.(*kafka.Transport)

if config.DefaultConfig.KafkaUsername != "" && config.DefaultConfig.KafkaPassword != "" {
tlsConfig, err := createTLSConfig(config.DefaultConfig.KafkaCAPath)
if err != nil {
fmt.Println("Error creating TLS configuration for Kafka:", err)
tlsConfig = &tls.Config{} // Providing default empty TLS configuration
}
saslMechanism, tlsConfig := getSaslAndTLSConfig()

transport = &kafka.Transport{
SASL: plain.Mechanism{
Username: config.DefaultConfig.KafkaUsername,
Password: config.DefaultConfig.KafkaPassword,
},
TLS: tlsConfig,
SASL: saslMechanism,
TLS: tlsConfig,
}

}
Expand All @@ -77,6 +63,29 @@ func (k kafkautil) NewWriter(topic string) *kafka.Writer {
}
}

func getSaslAndTLSConfig() (sasl.Mechanism, *tls.Config) {
username := config.DefaultConfig.KafkaUsername
password := config.DefaultConfig.KafkaPassword
saslmechanismName := config.DefaultConfig.KafkaSaslMechanism

tlsConfig, err := createTLSConfig(config.DefaultConfig.KafkaCAPath)
if err != nil {
fmt.Println("Error creating TLS configuration for Kafka:", err)
tlsConfig = &tls.Config{} // Providing default empty TLS configuration
}

saslMechanism, err := createSaslMechanism(
saslmechanismName,
username,
password,
)
if err != nil {
fmt.Println("Error creating SASL Mechanism for Kafka, using plain mechanism:", err)
}

return saslMechanism, tlsConfig
}

func createTLSConfig(pathToCert string) (*tls.Config, error) {

tlsConfig := tls.Config{
Expand All @@ -99,6 +108,38 @@ func createTLSConfig(pathToCert string) (*tls.Config, error) {
return &tlsConfig, nil
}

func createSaslMechanism(saslMechanism string, username string, password string) (sasl.Mechanism, error) {

switch strings.ToLower(saslMechanism) {
case "plain":
return plain.Mechanism{
Username: username,
Password: password,
}, nil

case "scram-sha-512":
mechanism, err := scram.Mechanism(scram.SHA512, username, password)
if err != nil {
return nil, fmt.Errorf("unable to create scram-sha-512 mechanism: %w", err)
}
return mechanism, nil

case "scram-sha-256":
mechanism, err := scram.Mechanism(scram.SHA256, username, password)
if err != nil {
return nil, fmt.Errorf("unable to create scram-sha-256 mechanism: %w", err)
}
return mechanism, nil

default:
// create plain mechanism as default
return plain.Mechanism{
Username: username,
Password: password,
}, fmt.Errorf("unable to configure sasl mechanism (%s)", saslMechanism)
}
}

// GetHeader loops over the message headers, returning the value of key, if
// found.
func (k kafkautil) GetHeader(message kafka.Message, key string) (string, error) {
Expand Down

0 comments on commit 0f5a13e

Please sign in to comment.