Skip to content

Commit

Permalink
Merge pull request #142 from mikenairn/add_supported_providers_flag
Browse files Browse the repository at this point in the history
feat: Add providers startup flag
  • Loading branch information
mikenairn authored May 29, 2024
2 parents c174376 + 69c0b37 commit 8b67a9a
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 96 deletions.
48 changes: 43 additions & 5 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"flag"
"fmt"
"os"
"strings"
"time"
Expand Down Expand Up @@ -67,6 +68,8 @@ func main() {
var minRequeueTime time.Duration
var validFor time.Duration
var maxRequeueTime time.Duration
var providers stringSliceFlags

flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
flag.BoolVar(&enableLeaderElection, "leader-elect", false,
Expand All @@ -81,13 +84,14 @@ func main() {
flag.DurationVar(&minRequeueTime, "min-requeue-time", DefaultValidationDuration,
"The minimal timeout between calls to the DNS Provider"+
"Controls if we commit to the full reconcile loop")
flag.Var(&providers, "provider", "DNS Provider(s) to enable. Can be passed multiple times e.g. --provider aws --provider google, or as a comma separated list e.g. --provider aws,gcp")
opts := zap.Options{
Development: true,
}
opts.BindFlags(flag.CommandLine)
flag.Parse()
logger := zap.New(zap.UseFlagOptions(&opts))
ctrl.SetLogger(logger)

ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts)))

var watchNamespaces = "WATCH_NAMESPACES"
defaultOptions := ctrl.Options{
Expand All @@ -101,15 +105,14 @@ func main() {

if watch := os.Getenv(watchNamespaces); watch != "" {
namespaces := strings.Split(watch, ",")
logger.Info("watching namespaces set ", watchNamespaces, namespaces)
setupLog.Info("watching namespaces set ", watchNamespaces, namespaces)
cacheOpts := cache.Options{
DefaultNamespaces: map[string]cache.Config{},
}
for _, ns := range namespaces {
cacheOpts.DefaultNamespaces[ns] = cache.Config{}
}
defaultOptions.Cache = cacheOpts

}

mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), defaultOptions)
Expand All @@ -118,7 +121,21 @@ func main() {
os.Exit(1)
}

providerFactory := provider.NewFactory(mgr.GetClient())
if len(providers) == 0 {
defaultProviders := provider.RegisteredDefaultProviders()
if defaultProviders == nil {
setupLog.Error(fmt.Errorf("no default providers registered"), "unable to set providers")
os.Exit(1)
}
providers = defaultProviders
}

setupLog.Info("init provider factory", "providers", providers)
providerFactory, err := provider.NewFactory(mgr.GetClient(), providers)
if err != nil {
setupLog.Error(err, "unable to create provider factory")
os.Exit(1)
}

if err = (&controller.ManagedZoneReconciler{
Client: mgr.GetClient(),
Expand Down Expand Up @@ -154,3 +171,24 @@ func main() {
os.Exit(1)
}
}

type stringSliceFlags []string

func (n *stringSliceFlags) String() string {
return strings.Join(*n, ",")
}

func (n *stringSliceFlags) Set(s string) error {
if len(s) == 0 {
return fmt.Errorf("cannot be empty")
}
for _, strVal := range strings.Split(s, ",") {
for _, v := range *n {
if v == strVal {
return nil
}
}
*n = append(*n, strVal)
}
return nil
}
8 changes: 8 additions & 0 deletions config/deploy/local/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ resources:

patchesStrategicMerge:
- manager_config_patch.yaml

patches:
- patch: |-
- op: add
path: /spec/template/spec/containers/0/args/-
value: --provider=aws,google,inmemory
target:
kind: Deployment
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/rs/xid v1.5.0
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
google.golang.org/api v0.134.0
k8s.io/api v0.28.3
k8s.io/apimachinery v0.28.3
Expand Down Expand Up @@ -78,7 +79,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.26.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/oauth2 v0.13.0 // indirect
golang.org/x/sys v0.18.0 // indirect
Expand Down
3 changes: 2 additions & 1 deletion internal/controller/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ var _ = BeforeSuite(func() {
})
Expect(err).ToNot(HaveOccurred())

providerFactory := provider.NewFactory(mgr.GetClient())
providerFactory, err := provider.NewFactory(mgr.GetClient(), []string{"inmemory"})
Expect(err).ToNot(HaveOccurred())

err = (&ManagedZoneReconciler{
Client: mgr.GetClient(),
Expand Down
2 changes: 1 addition & 1 deletion internal/provider/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,5 +264,5 @@ func (*Route53DNSProvider) ProviderSpecific() provider.ProviderSpecificLabels {

// Register this Provider with the provider factory
func init() {
provider.RegisterProvider("aws", NewProviderFromSecret)
provider.RegisterProvider("aws", NewProviderFromSecret, true)
}
41 changes: 33 additions & 8 deletions internal/provider/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package provider

import (
"context"
"errors"
"fmt"
"slices"
"sync"

"golang.org/x/exp/maps"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand All @@ -23,14 +27,23 @@ type ProviderConstructor func(context.Context, *v1.Secret, Config) (Provider, er
var (
constructors = make(map[string]ProviderConstructor)
constructorsLock sync.RWMutex
defaultProviders []string
)

// RegisterProvider will register a provider constructor, so it can be used within the application.
// 'name' should be unique, and should be used to identify this provider.
func RegisterProvider(name string, c ProviderConstructor) {
// 'name' should be unique, and should be used to identify this provider
// `asDefault` indicates if the provider should be added as a default provider and included in the default providers list.
func RegisterProvider(name string, c ProviderConstructor, asDefault bool) {
constructorsLock.Lock()
defer constructorsLock.Unlock()
constructors[name] = c
if asDefault {
defaultProviders = append(defaultProviders, name)
}
}

func RegisteredDefaultProviders() []string {
return defaultProviders
}

// Factory is an interface that can be used to obtain Provider implementations.
Expand All @@ -42,11 +55,20 @@ type Factory interface {
// factory is the default Factory implementation
type factory struct {
client.Client
providers []string
}

// NewFactory returns a new provider factory with the given client.
func NewFactory(c client.Client) Factory {
return &factory{Client: c}
// NewFactory returns a new provider factory with the given client and given providers enabled.
// Will return an error if any given provider has no registered provider implementation.
func NewFactory(c client.Client, p []string) (Factory, error) {
var err error
registeredProviders := maps.Keys(constructors)
for _, provider := range p {
if !slices.Contains(registeredProviders, provider) {
err = errors.Join(err, fmt.Errorf("provider '%s' not registered", provider))
}
}
return &factory{Client: c, providers: p}, err
}

// ProviderFor will return a Provider interface for the given ProviderAccessor secret.
Expand All @@ -62,18 +84,21 @@ func (f *factory) ProviderFor(ctx context.Context, pa v1alpha1.ProviderAccessor,
return nil, err
}

providerType, err := nameForProviderSecret(providerSecret)
provider, err := nameForProviderSecret(providerSecret)
if err != nil {
return nil, err
}

constructorsLock.RLock()
defer constructorsLock.RUnlock()
if constructor, ok := constructors[providerType]; ok {
if constructor, ok := constructors[provider]; ok {
if !slices.Contains(f.providers, provider) {
return nil, fmt.Errorf("provider '%s' not enabled", provider)
}
return constructor(ctx, providerSecret, c)
}

return nil, fmt.Errorf("provider '%s' not registered", providerType)
return nil, fmt.Errorf("provider '%s' not registered", provider)
}

func nameForProviderSecret(secret *v1.Secret) (string, error) {
Expand Down
18 changes: 0 additions & 18 deletions internal/provider/fake/factory.go

This file was deleted.

59 changes: 0 additions & 59 deletions internal/provider/fake/provider.go

This file was deleted.

2 changes: 1 addition & 1 deletion internal/provider/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,5 +719,5 @@ func (p *GoogleDNSProvider) getResourceRecordSets(ctx context.Context, zoneID st

// Register this Provider with the provider factory
func init() {
provider.RegisterProvider("google", NewProviderFromSecret)
provider.RegisterProvider("google", NewProviderFromSecret, true)
}
2 changes: 1 addition & 1 deletion internal/provider/inmemory/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,5 @@ func (i InMemoryDNSProvider) ProviderSpecific() provider.ProviderSpecificLabels
// Register this Provider with the provider factory
func init() {
client = inmemory.NewInMemoryClient()
provider.RegisterProvider("inmemory", NewProviderFromSecret)
provider.RegisterProvider("inmemory", NewProviderFromSecret, false)
}
6 changes: 5 additions & 1 deletion test/e2e/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,11 @@ func EndpointsForHost(ctx context.Context, provider provider.Provider, host stri
}

func providerForManagedZone(ctx context.Context, mz *v1alpha1.ManagedZone) (provider.Provider, error) {
providerFactory := provider.NewFactory(k8sClient)
//ToDo mnairn: We have a mismatch in naming GCP vs Google, we need to make this consistent one way or the other
providerFactory, err := provider.NewFactory(k8sClient, []string{"aws", "google"})
if err != nil {
return nil, err
}
providerConfig := provider.Config{
DomainFilter: externaldnsendpoint.NewDomainFilter([]string{mz.Spec.DomainName}),
ZoneTypeFilter: externaldnsprovider.NewZoneTypeFilter(""),
Expand Down

0 comments on commit 8b67a9a

Please sign in to comment.