diff --git a/cmd/discovery-service/main.go b/cmd/discovery-service/main.go index b9dce91..b9e8073 100644 --- a/cmd/discovery-service/main.go +++ b/cmd/discovery-service/main.go @@ -37,10 +37,10 @@ import ( "github.com/siderolabs/discovery-service/internal/landing" "github.com/siderolabs/discovery-service/internal/limiter" _ "github.com/siderolabs/discovery-service/internal/proto" - "github.com/siderolabs/discovery-service/internal/state" "github.com/siderolabs/discovery-service/internal/state/storage" "github.com/siderolabs/discovery-service/pkg/limits" "github.com/siderolabs/discovery-service/pkg/server" + "github.com/siderolabs/discovery-service/pkg/state" ) var ( diff --git a/internal/landing/landing.go b/internal/landing/landing.go index ac8d838..c7ee8ad 100644 --- a/internal/landing/landing.go +++ b/internal/landing/landing.go @@ -16,7 +16,8 @@ import ( "go.uber.org/zap" - "github.com/siderolabs/discovery-service/internal/state" + internalstate "github.com/siderolabs/discovery-service/internal/state" + "github.com/siderolabs/discovery-service/pkg/state" ) //go:embed "html/index.html" @@ -28,7 +29,7 @@ var inspectTemplate []byte // ClusterInspectData represents all affiliate data asssociated with a cluster. type ClusterInspectData struct { ClusterID string - Affiliates []*state.AffiliateExport + Affiliates []*internalstate.AffiliateExport } var inspectPage = template.Must(template.New("inspect").Parse(string(inspectTemplate))) diff --git a/internal/state/cluster.go b/internal/state/cluster.go index 0ea6728..7440fb2 100644 --- a/internal/state/cluster.go +++ b/internal/state/cluster.go @@ -27,7 +27,7 @@ type Cluster struct { subscriptionsMu sync.Mutex } -// NewCluster creates new cluster with specified ID. +// NewCluster creates new cluster with specified id. func NewCluster(id string) *Cluster { return &Cluster{ id: id, @@ -35,6 +35,11 @@ func NewCluster(id string) *Cluster { } } +// ID returns the cluster id. 
+func (cluster *Cluster) ID() string { + return cluster.id +} + // WithAffiliate runs a function against the affiliate. // // Cluster state is locked while the function is running. @@ -174,7 +179,8 @@ func (cluster *Cluster) notify(notifications ...*Notification) { } } -func (cluster *Cluster) stats() (affiliates, endpoints, subscriptions int) { +// Stats returns the number of affiliates, endpoints and subscriptions. +func (cluster *Cluster) Stats() (affiliates, endpoints, subscriptions int) { cluster.affiliatesMu.Lock() affiliates = len(cluster.affiliates) diff --git a/internal/state/snapshot.go b/internal/state/snapshot.go index 99ed5c3..1406b5e 100644 --- a/internal/state/snapshot.go +++ b/internal/state/snapshot.go @@ -6,7 +6,6 @@ package state import ( - "fmt" "slices" "github.com/siderolabs/gen/xslices" @@ -15,52 +14,8 @@ import ( storagepb "github.com/siderolabs/discovery-service/api/storage" ) -// ExportClusterSnapshots exports all cluster snapshots and calls the provided function for each one. -// -// Implements storage.Snapshotter interface. -func (state *State) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error { - var err error - - // reuse the same snapshotin each iteration - clusterSnapshot := &storagepb.ClusterSnapshot{} - - state.clusters.Enumerate(func(_ string, cluster *Cluster) bool { - snapshotCluster(cluster, clusterSnapshot) - - err = f(clusterSnapshot) - - return err == nil - }) - - return err -} - -// ImportClusterSnapshots imports cluster snapshots by calling the provided function until it returns false. -// -// Implements storage.Snapshotter interface. 
-func (state *State) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error { - for { - clusterSnapshot, ok, err := f() - if err != nil { - return err - } - - if !ok { - break - } - - cluster := clusterFromSnapshot(clusterSnapshot) - - _, loaded := state.clusters.LoadOrStore(cluster.id, cluster) - if loaded { - return fmt.Errorf("cluster %q already exists", cluster.id) - } - } - - return nil -} - -func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) { +// Snapshot takes a snapshot of the cluster into the given snapshot reference. +func (cluster *Cluster) Snapshot(snapshot *storagepb.ClusterSnapshot) { cluster.affiliatesMu.Lock() defer cluster.affiliatesMu.Unlock() @@ -110,7 +65,8 @@ func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) { } } -func clusterFromSnapshot(snapshot *storagepb.ClusterSnapshot) *Cluster { +// ClusterFromSnapshot creates a new cluster from the provided snapshot. +func ClusterFromSnapshot(snapshot *storagepb.ClusterSnapshot) *Cluster { return &Cluster{ id: snapshot.Id, affiliates: xslices.ToMap(snapshot.Affiliates, affiliateFromSnapshot), diff --git a/internal/state/state.go b/internal/state/state.go index 5509a40..1600d72 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -3,175 +3,5 @@ // Use of this software is governed by the Business Source License // included in the LICENSE file. -// Package state implements server state with clusters, affiliates, subscriptions, etc. +// Package state contains internal state-related components such as affiliates, subscriptions, etc. package state - -import ( - "context" - "time" - - prom "github.com/prometheus/client_golang/prometheus" - "github.com/siderolabs/gen/concurrent" - "go.uber.org/zap" -) - -// State keeps the discovery service state. 
-type State struct { - clusters *concurrent.HashTrieMap[string, *Cluster] - logger *zap.Logger - - mClustersDesc *prom.Desc - mAffiliatesDesc *prom.Desc - mEndpointsDesc *prom.Desc - mSubscriptionsDesc *prom.Desc - mGCRuns prom.Counter - mGCClusters prom.Counter - mGCAffiliates prom.Counter -} - -// NewState create new instance of State. -func NewState(logger *zap.Logger) *State { - return &State{ - clusters: concurrent.NewHashTrieMap[string, *Cluster](), - logger: logger, - mClustersDesc: prom.NewDesc( - "discovery_state_clusters", - "The current number of clusters in the state.", - nil, nil, - ), - mAffiliatesDesc: prom.NewDesc( - "discovery_state_affiliates", - "The current number of affiliates in the state.", - nil, nil, - ), - mEndpointsDesc: prom.NewDesc( - "discovery_state_endpoints", - "The current number of endpoints in the state.", - nil, nil, - ), - mSubscriptionsDesc: prom.NewDesc( - "discovery_state_subscriptions", - "The current number of subscriptions in the state.", - nil, nil, - ), - mGCRuns: prom.NewCounter(prom.CounterOpts{ - Name: "discovery_state_gc_runs_total", - Help: "The number of GC runs.", - }), - mGCClusters: prom.NewCounter(prom.CounterOpts{ - Name: "discovery_state_gc_clusters_total", - Help: "The total number of GC'ed clusters.", - }), - mGCAffiliates: prom.NewCounter(prom.CounterOpts{ - Name: "discovery_state_gc_affiliates_total", - Help: "The total number of GC'ed affiliates.", - }), - } -} - -// GetCluster returns cluster by ID, creating it if needed. -func (state *State) GetCluster(id string) *Cluster { - if cluster, ok := state.clusters.Load(id); ok { - return cluster - } - - cluster, loaded := state.clusters.LoadOrStore(id, NewCluster(id)) - if !loaded { - state.logger.Debug("cluster created", zap.String("cluster_id", id)) - } - - return cluster -} - -// GarbageCollect recursively each cluster, and remove empty clusters. 
-func (state *State) GarbageCollect(now time.Time) (removedClusters, removedAffiliates int) { - state.clusters.Enumerate(func(key string, cluster *Cluster) bool { - ra, empty := cluster.GarbageCollect(now) - removedAffiliates += ra - - if empty { - state.clusters.CompareAndDelete(key, cluster) - state.logger.Debug("cluster removed", zap.String("cluster_id", key)) - - removedClusters++ - } - - return true - }) - - state.mGCRuns.Inc() - state.mGCClusters.Add(float64(removedClusters)) - state.mGCAffiliates.Add(float64(removedAffiliates)) - - return -} - -// RunGC runs the garbage collection on interval. -func (state *State) RunGC(ctx context.Context, logger *zap.Logger, interval time.Duration) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for ctx.Err() == nil { - removedClusters, removedAffiliates := state.GarbageCollect(time.Now()) - clusters, affiliates, endpoints, subscriptions := state.stats() - - logFunc := logger.Debug - if removedClusters > 0 || removedAffiliates > 0 { - logFunc = logger.Info - } - - logFunc( - "garbage collection run", - zap.Int("removed_clusters", removedClusters), - zap.Int("removed_affiliates", removedAffiliates), - zap.Int("current_clusters", clusters), - zap.Int("current_affiliates", affiliates), - zap.Int("current_endpoints", endpoints), - zap.Int("current_subscriptions", subscriptions), - ) - - select { - case <-ctx.Done(): - case <-ticker.C: - } - } -} - -func (state *State) stats() (clusters, affiliates, endpoints, subscriptions int) { - state.clusters.Enumerate(func(_ string, cluster *Cluster) bool { - clusters++ - - a, e, s := cluster.stats() - affiliates += a - endpoints += e - subscriptions += s - - return true - }) - - return -} - -// Describe implements prom.Collector interface. -func (state *State) Describe(ch chan<- *prom.Desc) { - prom.DescribeByCollect(state, ch) -} - -// Collect implements prom.Collector interface. 
-func (state *State) Collect(ch chan<- prom.Metric) { - clusters, affiliates, endpoints, subscriptions := state.stats() - - ch <- prom.MustNewConstMetric(state.mClustersDesc, prom.GaugeValue, float64(clusters)) - ch <- prom.MustNewConstMetric(state.mAffiliatesDesc, prom.GaugeValue, float64(affiliates)) - ch <- prom.MustNewConstMetric(state.mEndpointsDesc, prom.GaugeValue, float64(endpoints)) - ch <- prom.MustNewConstMetric(state.mSubscriptionsDesc, prom.GaugeValue, float64(subscriptions)) - - ch <- state.mGCRuns - ch <- state.mGCClusters - ch <- state.mGCAffiliates -} - -// Check interfaces. -var ( - _ prom.Collector = (*State)(nil) -) diff --git a/internal/state/storage/storage_bench_test.go b/internal/state/storage/storage_bench_test.go index 445bd68..81d3bee 100644 --- a/internal/state/storage/storage_bench_test.go +++ b/internal/state/storage/storage_bench_test.go @@ -13,8 +13,8 @@ import ( "go.uber.org/zap" storagepb "github.com/siderolabs/discovery-service/api/storage" - "github.com/siderolabs/discovery-service/internal/state" "github.com/siderolabs/discovery-service/internal/state/storage" + "github.com/siderolabs/discovery-service/pkg/state" ) func BenchmarkExport(b *testing.B) { diff --git a/pkg/server/landing_test.go b/pkg/server/landing_test.go index 8869d81..c5925a7 100644 --- a/pkg/server/landing_test.go +++ b/pkg/server/landing_test.go @@ -18,7 +18,8 @@ import ( "go.uber.org/zap/zaptest" "github.com/siderolabs/discovery-service/internal/landing" - "github.com/siderolabs/discovery-service/internal/state" + internalstate "github.com/siderolabs/discovery-service/internal/state" + "github.com/siderolabs/discovery-service/pkg/state" ) // TestInspectHandler tests the /inspect endpoint. 
@@ -33,14 +34,14 @@ func TestInspectHanlder(t *testing.T) { t.Cleanup(cancel) // add affiliates to the cluster "fake1" - err := testCluster.WithAffiliate("af1", func(affiliate *state.Affiliate) error { + err := testCluster.WithAffiliate("af1", func(affiliate *internalstate.Affiliate) error { affiliate.Update([]byte("data1"), now.Add(time.Minute)) return nil }) require.NoError(t, err) - err = testCluster.WithAffiliate("af2", func(affiliate *state.Affiliate) error { + err = testCluster.WithAffiliate("af2", func(affiliate *internalstate.Affiliate) error { affiliate.Update([]byte("data2"), now.Add(time.Minute)) return nil diff --git a/pkg/server/server.go b/pkg/server/server.go index 5bee6c5..a50b2ce 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -17,7 +17,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/siderolabs/discovery-service/internal/state" + internalstate "github.com/siderolabs/discovery-service/internal/state" + "github.com/siderolabs/discovery-service/pkg/state" ) const updateBuffer = 32 @@ -104,7 +105,7 @@ func (srv *ClusterServer) AffiliateUpdate(_ context.Context, req *pb.AffiliateUp return nil, err } - if err := srv.state.GetCluster(req.ClusterId).WithAffiliate(req.AffiliateId, func(affiliate *state.Affiliate) error { + if err := srv.state.GetCluster(req.ClusterId).WithAffiliate(req.AffiliateId, func(affiliate *internalstate.Affiliate) error { expiration := time.Now().Add(req.Ttl.AsDuration()) if len(req.AffiliateData) > 0 { @@ -114,9 +115,9 @@ func (srv *ClusterServer) AffiliateUpdate(_ context.Context, req *pb.AffiliateUp return affiliate.MergeEndpoints(req.AffiliateEndpoints, expiration) }); err != nil { switch { - case errors.Is(err, state.ErrTooManyEndpoints): + case errors.Is(err, internalstate.ErrTooManyEndpoints): return nil, status.Error(codes.ResourceExhausted, err.Error()) - case errors.Is(err, state.ErrTooManyAffiliates): + case errors.Is(err, internalstate.ErrTooManyAffiliates): 
return nil, status.Error(codes.ResourceExhausted, err.Error()) default: return nil, err @@ -170,7 +171,7 @@ func (srv *ClusterServer) Watch(req *pb.WatchRequest, server pb.Cluster_WatchSer } // make enough room to handle connection issues - updates := make(chan *state.Notification, updateBuffer) + updates := make(chan *internalstate.Notification, updateBuffer) snapshot, subscription := srv.state.GetCluster(req.ClusterId).Subscribe(updates) defer subscription.Close() diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 89c24f4..362e50d 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -32,9 +32,9 @@ import ( "github.com/siderolabs/discovery-service/internal/limiter" _ "github.com/siderolabs/discovery-service/internal/proto" - "github.com/siderolabs/discovery-service/internal/state" "github.com/siderolabs/discovery-service/pkg/limits" "github.com/siderolabs/discovery-service/pkg/server" + "github.com/siderolabs/discovery-service/pkg/state" ) func checkMetrics(t *testing.T, c prom.Collector) { diff --git a/pkg/server/validate.go b/pkg/server/validate.go index d7e1922..1ec4b82 100644 --- a/pkg/server/validate.go +++ b/pkg/server/validate.go @@ -16,11 +16,11 @@ import ( func validateClusterID(id string) error { if len(id) < 1 { - return status.Errorf(codes.InvalidArgument, "cluster ID can't be empty") + return status.Errorf(codes.InvalidArgument, "cluster id can't be empty") } if len(id) > limits.ClusterIDMax { - return status.Errorf(codes.InvalidArgument, "cluster ID is too long") + return status.Errorf(codes.InvalidArgument, "cluster id is too long") } return nil @@ -28,11 +28,11 @@ func validateClusterID(id string) error { func validateAffiliateID(id string) error { if len(id) < 1 { - return status.Errorf(codes.InvalidArgument, "affiliate ID can't be empty") + return status.Errorf(codes.InvalidArgument, "affiliate id can't be empty") } if len(id) > limits.AffiliateIDMax { - return status.Errorf(codes.InvalidArgument, 
"affiliate ID is too long") + return status.Errorf(codes.InvalidArgument, "affiliate id is too long") } return nil diff --git a/pkg/state/snapshot.go b/pkg/state/snapshot.go new file mode 100644 index 0000000..33ee5dd --- /dev/null +++ b/pkg/state/snapshot.go @@ -0,0 +1,58 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package state + +import ( + "fmt" + + storagepb "github.com/siderolabs/discovery-service/api/storage" + internalstate "github.com/siderolabs/discovery-service/internal/state" +) + +// ExportClusterSnapshots exports all cluster snapshots and calls the provided function for each one. +// +// Implements storage.Snapshotter interface. +func (state *State) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error { + var err error + + // reuse the same snapshotin each iteration + clusterSnapshot := &storagepb.ClusterSnapshot{} + + state.clusters.Enumerate(func(_ string, cluster *internalstate.Cluster) bool { + cluster.Snapshot(clusterSnapshot) + + err = f(clusterSnapshot) + + return err == nil + }) + + return err +} + +// ImportClusterSnapshots imports cluster snapshots by calling the provided function until it returns false. +// +// Implements storage.Snapshotter interface. +func (state *State) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error { + for { + clusterSnapshot, ok, err := f() + if err != nil { + return err + } + + if !ok { + break + } + + cluster := internalstate.ClusterFromSnapshot(clusterSnapshot) + + _, loaded := state.clusters.LoadOrStore(cluster.ID(), cluster) + if loaded { + return fmt.Errorf("cluster %q already exists", cluster.ID()) + } + } + + return nil +} diff --git a/pkg/state/state.go b/pkg/state/state.go new file mode 100644 index 0000000..876e7a9 --- /dev/null +++ b/pkg/state/state.go @@ -0,0 +1,179 @@ +// Copyright (c) 2024 Sidero Labs, Inc. 
+// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +// Package state implements server state. +package state + +import ( + "context" + "time" + + prom "github.com/prometheus/client_golang/prometheus" + "github.com/siderolabs/gen/concurrent" + "go.uber.org/zap" + + internalstate "github.com/siderolabs/discovery-service/internal/state" +) + +// State keeps the discovery service state. +type State struct { + clusters *concurrent.HashTrieMap[string, *internalstate.Cluster] + logger *zap.Logger + + mClustersDesc *prom.Desc + mAffiliatesDesc *prom.Desc + mEndpointsDesc *prom.Desc + mSubscriptionsDesc *prom.Desc + mGCRuns prom.Counter + mGCClusters prom.Counter + mGCAffiliates prom.Counter +} + +// NewState creates a new instance of State. +func NewState(logger *zap.Logger) *State { + return &State{ + clusters: concurrent.NewHashTrieMap[string, *internalstate.Cluster](), + logger: logger, + mClustersDesc: prom.NewDesc( + "discovery_state_clusters", + "The current number of clusters in the state.", + nil, nil, + ), + mAffiliatesDesc: prom.NewDesc( + "discovery_state_affiliates", + "The current number of affiliates in the state.", + nil, nil, + ), + mEndpointsDesc: prom.NewDesc( + "discovery_state_endpoints", + "The current number of endpoints in the state.", + nil, nil, + ), + mSubscriptionsDesc: prom.NewDesc( + "discovery_state_subscriptions", + "The current number of subscriptions in the state.", + nil, nil, + ), + mGCRuns: prom.NewCounter(prom.CounterOpts{ + Name: "discovery_state_gc_runs_total", + Help: "The number of GC runs.", + }), + mGCClusters: prom.NewCounter(prom.CounterOpts{ + Name: "discovery_state_gc_clusters_total", + Help: "The total number of GC'ed clusters.", + }), + mGCAffiliates: prom.NewCounter(prom.CounterOpts{ + Name: "discovery_state_gc_affiliates_total", + Help: "The total number of GC'ed affiliates.", + }), + } +} + +// GetCluster returns cluster by id, creating it if needed. 
+func (state *State) GetCluster(id string) *internalstate.Cluster { + if cluster, ok := state.clusters.Load(id); ok { + return cluster + } + + cluster, loaded := state.clusters.LoadOrStore(id, internalstate.NewCluster(id)) + if !loaded { + state.logger.Debug("cluster created", zap.String("cluster_id", id)) + } + + return cluster +} + +// GarbageCollect garbage-collects each cluster recursively and removes empty clusters. +func (state *State) GarbageCollect(now time.Time) (removedClusters, removedAffiliates int) { + state.clusters.Enumerate(func(key string, cluster *internalstate.Cluster) bool { + ra, empty := cluster.GarbageCollect(now) + removedAffiliates += ra + + if empty { + state.clusters.CompareAndDelete(key, cluster) + state.logger.Debug("cluster removed", zap.String("cluster_id", key)) + + removedClusters++ + } + + return true + }) + + state.mGCRuns.Inc() + state.mGCClusters.Add(float64(removedClusters)) + state.mGCAffiliates.Add(float64(removedAffiliates)) + + return +} + +// RunGC runs the garbage collection on interval. 
+func (state *State) RunGC(ctx context.Context, logger *zap.Logger, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for ctx.Err() == nil { + removedClusters, removedAffiliates := state.GarbageCollect(time.Now()) + clusters, affiliates, endpoints, subscriptions := state.stats() + + logFunc := logger.Debug + if removedClusters > 0 || removedAffiliates > 0 { + logFunc = logger.Info + } + + logFunc( + "garbage collection run", + zap.Int("removed_clusters", removedClusters), + zap.Int("removed_affiliates", removedAffiliates), + zap.Int("current_clusters", clusters), + zap.Int("current_affiliates", affiliates), + zap.Int("current_endpoints", endpoints), + zap.Int("current_subscriptions", subscriptions), + ) + + select { + case <-ctx.Done(): + case <-ticker.C: + } + } +} + +func (state *State) stats() (clusters, affiliates, endpoints, subscriptions int) { + state.clusters.Enumerate(func(_ string, cluster *internalstate.Cluster) bool { + clusters++ + + a, e, s := cluster.Stats() + affiliates += a + endpoints += e + subscriptions += s + + return true + }) + + return +} + +// Describe implements prom.Collector interface. +func (state *State) Describe(ch chan<- *prom.Desc) { + prom.DescribeByCollect(state, ch) +} + +// Collect implements prom.Collector interface. +func (state *State) Collect(ch chan<- prom.Metric) { + clusters, affiliates, endpoints, subscriptions := state.stats() + + ch <- prom.MustNewConstMetric(state.mClustersDesc, prom.GaugeValue, float64(clusters)) + ch <- prom.MustNewConstMetric(state.mAffiliatesDesc, prom.GaugeValue, float64(affiliates)) + ch <- prom.MustNewConstMetric(state.mEndpointsDesc, prom.GaugeValue, float64(endpoints)) + ch <- prom.MustNewConstMetric(state.mSubscriptionsDesc, prom.GaugeValue, float64(subscriptions)) + + ch <- state.mGCRuns + ch <- state.mGCClusters + ch <- state.mGCAffiliates +} + +// Check interfaces. 
+var ( + _ prom.Collector = (*State)(nil) +) diff --git a/internal/state/state_test.go b/pkg/state/state_test.go similarity index 90% rename from internal/state/state_test.go rename to pkg/state/state_test.go index d1d2e4c..05aa0ff 100644 --- a/internal/state/state_test.go +++ b/pkg/state/state_test.go @@ -15,7 +15,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" - "github.com/siderolabs/discovery-service/internal/state" + internalstate "github.com/siderolabs/discovery-service/internal/state" + "github.com/siderolabs/discovery-service/pkg/state" ) func checkMetrics(t *testing.T, c prom.Collector) { @@ -41,7 +42,7 @@ func TestState(t *testing.T) { assert.Equal(t, 0, removedAffiliates) st.GetCluster("id1") - assert.NoError(t, st.GetCluster("id2").WithAffiliate("af1", func(affiliate *state.Affiliate) error { + assert.NoError(t, st.GetCluster("id2").WithAffiliate("af1", func(affiliate *internalstate.Affiliate) error { affiliate.Update([]byte("data1"), now.Add(time.Minute)) return nil