From 69c0b371717e1d5574d2054e08c754cc210b7996 Mon Sep 17 00:00:00 2001 From: Michael Nairn Date: Tue, 28 May 2024 17:23:50 +0100 Subject: [PATCH] feat: Add providers startup flag Add a new flag to the controller startup that determines what providers are enabled, if none are specified the list of registered default providers is used. The provider factory is updated to take a list of providers that are enabled and ensures that only providers in this list can be loaded via the factory methods. Providers themselves optionally register themselves as "default" providers, in the case where no providers are specified these providers are enabled in the factory. --- cmd/main.go | 48 ++++++++++++++++++--- config/deploy/local/kustomization.yaml | 8 ++++ go.mod | 2 +- internal/controller/suite_test.go | 3 +- internal/provider/aws/aws.go | 2 +- internal/provider/factory.go | 41 ++++++++++++++---- internal/provider/fake/factory.go | 18 -------- internal/provider/fake/provider.go | 59 -------------------------- internal/provider/google/google.go | 2 +- internal/provider/inmemory/inmemory.go | 2 +- test/e2e/suite_test.go | 6 ++- 11 files changed, 95 insertions(+), 96 deletions(-) delete mode 100644 internal/provider/fake/factory.go delete mode 100644 internal/provider/fake/provider.go diff --git a/cmd/main.go b/cmd/main.go index 0c9940b..ce86855 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -18,6 +18,7 @@ package main import ( "flag" + "fmt" "os" "strings" "time" @@ -67,6 +68,8 @@ func main() { var minRequeueTime time.Duration var validFor time.Duration var maxRequeueTime time.Duration + var providers stringSliceFlags + flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.BoolVar(&enableLeaderElection, "leader-elect", false, @@ -81,13 +84,14 @@ func main() { flag.DurationVar(&minRequeueTime, "min-requeue-time", DefaultValidationDuration, "The minimal timeout between calls to the DNS Provider"+ "Controls if we commit to the full reconcile loop") + flag.Var(&providers, "provider", "DNS Provider(s) to enable. Can be passed multiple times e.g. --provider aws --provider google, or as a comma separated list e.g. --provider aws,gcp") opts := zap.Options{ Development: true, } opts.BindFlags(flag.CommandLine) flag.Parse() - logger := zap.New(zap.UseFlagOptions(&opts)) - ctrl.SetLogger(logger) + + ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts))) var watchNamespaces = "WATCH_NAMESPACES" defaultOptions := ctrl.Options{ @@ -101,7 +105,7 @@ func main() { if watch := os.Getenv(watchNamespaces); watch != "" { namespaces := strings.Split(watch, ",") - logger.Info("watching namespaces set ", watchNamespaces, namespaces) + setupLog.Info("watching namespaces set ", watchNamespaces, namespaces) cacheOpts := cache.Options{ DefaultNamespaces: map[string]cache.Config{}, } @@ -109,7 +113,6 @@ func main() { cacheOpts.DefaultNamespaces[ns] = cache.Config{} } defaultOptions.Cache = cacheOpts - } mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), defaultOptions) @@ -118,7 +121,21 @@ func main() { os.Exit(1) } - providerFactory := provider.NewFactory(mgr.GetClient()) + if len(providers) == 0 { + defaultProviders := provider.RegisteredDefaultProviders() + if defaultProviders == nil { + setupLog.Error(fmt.Errorf("no default providers registered"), "unable to set providers") + os.Exit(1) + } + providers = defaultProviders + } + + setupLog.Info("init provider factory", "providers", providers) + providerFactory, err := provider.NewFactory(mgr.GetClient(), providers) + if err != nil { + setupLog.Error(err, "unable to create provider factory") + os.Exit(1) + } if err = (&controller.ManagedZoneReconciler{ Client: mgr.GetClient(), @@ -154,3 +171,24 @@ func main() { os.Exit(1) } } + +type stringSliceFlags []string + +func (n *stringSliceFlags) String() string { + return strings.Join(*n, ",") +} + +func (n *stringSliceFlags) Set(s string) error { + if len(s) == 0 { + return fmt.Errorf("cannot be empty") + } + for _, strVal := range strings.Split(s, ",") { + for _, v := range *n { + if v == strVal { + return nil + } + } + *n = append(*n, strVal) + } + return nil +} diff --git a/config/deploy/local/kustomization.yaml b/config/deploy/local/kustomization.yaml index ed0c7bc..eaa5b8b 100644 --- a/config/deploy/local/kustomization.yaml +++ b/config/deploy/local/kustomization.yaml @@ -8,3 +8,11 @@ resources: patchesStrategicMerge: - manager_config_patch.yaml + +patches: + - patch: |- + - op: add + path: /spec/template/spec/containers/0/args/- + value: --provider=aws,google,inmemory + target: + kind: Deployment diff --git a/go.mod b/go.mod index b674950..fabd12b 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/rs/xid v1.5.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d google.golang.org/api v0.134.0 k8s.io/api v0.28.3 k8s.io/apimachinery v0.28.3 @@ -78,7 +79,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect golang.org/x/crypto v0.21.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.23.0 // indirect golang.org/x/oauth2 v0.13.0 // indirect golang.org/x/sys v0.18.0 // indirect diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go index d24c943..1348e38 100644 --- a/internal/controller/suite_test.go +++ b/internal/controller/suite_test.go @@ -99,7 +99,8 @@ var _ = BeforeSuite(func() { }) Expect(err).ToNot(HaveOccurred()) - providerFactory := provider.NewFactory(mgr.GetClient()) + providerFactory, err := provider.NewFactory(mgr.GetClient(), []string{"inmemory"}) + Expect(err).ToNot(HaveOccurred()) err = (&ManagedZoneReconciler{ Client: mgr.GetClient(), diff --git a/internal/provider/aws/aws.go b/internal/provider/aws/aws.go index 554dd68..f0a6107 100644 --- a/internal/provider/aws/aws.go +++ b/internal/provider/aws/aws.go @@ -264,5 +264,5 @@ func (*Route53DNSProvider) ProviderSpecific() provider.ProviderSpecificLabels { // Register this Provider with the provider factory func init() { - provider.RegisterProvider("aws", NewProviderFromSecret) + provider.RegisterProvider("aws", NewProviderFromSecret, true) } diff --git a/internal/provider/factory.go b/internal/provider/factory.go index acda562..ac86f42 100644 --- a/internal/provider/factory.go +++ b/internal/provider/factory.go @@ -2,9 +2,13 @@ package provider import ( "context" + "errors" "fmt" + "slices" "sync" + "golang.org/x/exp/maps" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" @@ -23,14 +27,23 @@ type ProviderConstructor func(context.Context, *v1.Secret, Config) (Provider, er var ( constructors = make(map[string]ProviderConstructor) constructorsLock sync.RWMutex + defaultProviders []string ) // RegisterProvider will register a provider constructor, so it can be used within the application. -// 'name' should be unique, and should be used to identify this provider. -func RegisterProvider(name string, c ProviderConstructor) { +// 'name' should be unique, and should be used to identify this provider +// `asDefault` indicates if the provider should be added as a default provider and included in the default providers list. +func RegisterProvider(name string, c ProviderConstructor, asDefault bool) { constructorsLock.Lock() defer constructorsLock.Unlock() constructors[name] = c + if asDefault { + defaultProviders = append(defaultProviders, name) + } +} + +func RegisteredDefaultProviders() []string { + return defaultProviders } // Factory is an interface that can be used to obtain Provider implementations. @@ -42,11 +55,20 @@ type Factory interface { // factory is the default Factory implementation type factory struct { client.Client + providers []string } -// NewFactory returns a new provider factory with the given client. -func NewFactory(c client.Client) Factory { - return &factory{Client: c} +// NewFactory returns a new provider factory with the given client and given providers enabled. +// Will return an error if any given provider has no registered provider implementation. +func NewFactory(c client.Client, p []string) (Factory, error) { + var err error + registeredProviders := maps.Keys(constructors) + for _, provider := range p { + if !slices.Contains(registeredProviders, provider) { + err = errors.Join(err, fmt.Errorf("provider '%s' not registered", provider)) + } + } + return &factory{Client: c, providers: p}, err } // ProviderFor will return a Provider interface for the given ProviderAccessor secret. @@ -62,18 +84,21 @@ func (f *factory) ProviderFor(ctx context.Context, pa v1alpha1.ProviderAccessor, return nil, err } - providerType, err := nameForProviderSecret(providerSecret) + provider, err := nameForProviderSecret(providerSecret) if err != nil { return nil, err } constructorsLock.RLock() defer constructorsLock.RUnlock() - if constructor, ok := constructors[providerType]; ok { + if constructor, ok := constructors[provider]; ok { + if !slices.Contains(f.providers, provider) { + return nil, fmt.Errorf("provider '%s' not enabled", provider) + } return constructor(ctx, providerSecret, c) } - return nil, fmt.Errorf("provider '%s' not registered", providerType) + return nil, fmt.Errorf("provider '%s' not registered", provider) } func nameForProviderSecret(secret *v1.Secret) (string, error) { diff --git a/internal/provider/fake/factory.go b/internal/provider/fake/factory.go deleted file mode 100644 index 6ff65a9..0000000 --- a/internal/provider/fake/factory.go +++ /dev/null @@ -1,18 +0,0 @@ -package fake - -import ( - "context" - - "github.com/kuadrant/dns-operator/api/v1alpha1" - "github.com/kuadrant/dns-operator/internal/provider" -) - -type Factory struct { - ProviderForFunc func(ctx context.Context, pa v1alpha1.ProviderAccessor, c provider.Config) (provider.Provider, error) -} - -var _ provider.Factory = &Factory{} - -func (f *Factory) ProviderFor(ctx context.Context, pa v1alpha1.ProviderAccessor, c provider.Config) (provider.Provider, error) { - return f.ProviderForFunc(ctx, pa, c) -} diff --git a/internal/provider/fake/provider.go b/internal/provider/fake/provider.go deleted file mode 100644 index b663c0b..0000000 --- a/internal/provider/fake/provider.go +++ /dev/null @@ -1,59 +0,0 @@ -package fake - -import ( - "context" - - externaldnsendpoint "sigs.k8s.io/external-dns/endpoint" - externaldnsplan "sigs.k8s.io/external-dns/plan" - - "github.com/kuadrant/dns-operator/api/v1alpha1" - "github.com/kuadrant/dns-operator/internal/provider" -) - -type Provider struct { - RecordsFunc func(context.Context) ([]*externaldnsendpoint.Endpoint, error) - ApplyChangesFunc func(context.Context, *externaldnsplan.Changes) error - AdjustEndpointsFunc func([]*externaldnsendpoint.Endpoint) ([]*externaldnsendpoint.Endpoint, error) - GetDomainFilterFunc func() externaldnsendpoint.DomainFilter - EnsureManagedZoneFunc func(*v1alpha1.ManagedZone) (provider.ManagedZoneOutput, error) - DeleteManagedZoneFunc func(*v1alpha1.ManagedZone) error -} - -var _ provider.Provider = &Provider{} - -// #### External DNS Provider #### - -func (p Provider) Records(ctx context.Context) ([]*externaldnsendpoint.Endpoint, error) { - return p.RecordsFunc(ctx) -} - -func (p Provider) ApplyChanges(ctx context.Context, changes *externaldnsplan.Changes) error { - return p.ApplyChangesFunc(ctx, changes) -} - -func (p Provider) AdjustEndpoints(endpoints []*externaldnsendpoint.Endpoint) ([]*externaldnsendpoint.Endpoint, error) { - return p.AdjustEndpointsFunc(endpoints) -} - -func (p Provider) GetDomainFilter() externaldnsendpoint.DomainFilter { - return p.GetDomainFilterFunc() -} - -// #### DNS Operator Provider #### - -func (p Provider) EnsureManagedZone(managedZone *v1alpha1.ManagedZone) (provider.ManagedZoneOutput, error) { - return p.EnsureManagedZoneFunc(managedZone) -} - -func (p Provider) DeleteManagedZone(managedZone *v1alpha1.ManagedZone) error { - return p.DeleteManagedZoneFunc(managedZone) -} -func (p Provider) HealthCheckReconciler() provider.HealthCheckReconciler { - return &provider.FakeHealthCheckReconciler{} -} -func (p Provider) ProviderSpecific() provider.ProviderSpecificLabels { - return provider.ProviderSpecificLabels{ - Weight: "fake/weight", - HealthCheckID: "fake/health-check-id", - } -} diff --git a/internal/provider/google/google.go b/internal/provider/google/google.go index 1ea99fe..85273b0 100644 --- a/internal/provider/google/google.go +++ b/internal/provider/google/google.go @@ -719,5 +719,5 @@ func (p *GoogleDNSProvider) getResourceRecordSets(ctx context.Context, zoneID st // Register this Provider with the provider factory func init() { - provider.RegisterProvider("google", NewProviderFromSecret) + provider.RegisterProvider("google", NewProviderFromSecret, true) } diff --git a/internal/provider/inmemory/inmemory.go b/internal/provider/inmemory/inmemory.go index 159e4c0..91631b1 100644 --- a/internal/provider/inmemory/inmemory.go +++ b/internal/provider/inmemory/inmemory.go @@ -91,5 +91,5 @@ func (i InMemoryDNSProvider) ProviderSpecific() provider.ProviderSpecificLabels // Register this Provider with the provider factory func init() { client = inmemory.NewInMemoryClient() - provider.RegisterProvider("inmemory", NewProviderFromSecret) + provider.RegisterProvider("inmemory", NewProviderFromSecret, false) } diff --git a/test/e2e/suite_test.go b/test/e2e/suite_test.go index 67cb2c5..15aaa2e 100644 --- a/test/e2e/suite_test.go +++ b/test/e2e/suite_test.go @@ -168,7 +168,11 @@ func EndpointsForHost(ctx context.Context, provider provider.Provider, host stri } func providerForManagedZone(ctx context.Context, mz *v1alpha1.ManagedZone) (provider.Provider, error) { - providerFactory := provider.NewFactory(k8sClient) + //ToDo mnairn: We have a mismatch in naming GCP vs Google, we need to make this consistent one way or the other + providerFactory, err := provider.NewFactory(k8sClient, []string{"aws", "google"}) + if err != nil { + return nil, err + } providerConfig := provider.Config{ DomainFilter: externaldnsendpoint.NewDomainFilter([]string{mz.Spec.DomainName}), ZoneTypeFilter: externaldnsprovider.NewZoneTypeFilter(""),