Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move state.State to public package #56

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/discovery-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ import (
"github.com/siderolabs/discovery-service/internal/landing"
"github.com/siderolabs/discovery-service/internal/limiter"
_ "github.com/siderolabs/discovery-service/internal/proto"
"github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/internal/state/storage"
"github.com/siderolabs/discovery-service/pkg/limits"
"github.com/siderolabs/discovery-service/pkg/server"
"github.com/siderolabs/discovery-service/pkg/state"
)

var (
Expand Down
5 changes: 3 additions & 2 deletions internal/landing/landing.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (

"go.uber.org/zap"

"github.com/siderolabs/discovery-service/internal/state"
internalstate "github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/pkg/state"
)

//go:embed "html/index.html"
Expand All @@ -28,7 +29,7 @@ var inspectTemplate []byte
// ClusterInspectData represents all affiliate data asssociated with a cluster.
type ClusterInspectData struct {
ClusterID string
Affiliates []*state.AffiliateExport
Affiliates []*internalstate.AffiliateExport
}

var inspectPage = template.Must(template.New("inspect").Parse(string(inspectTemplate)))
Expand Down
10 changes: 8 additions & 2 deletions internal/state/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@ type Cluster struct {
subscriptionsMu sync.Mutex
}

// NewCluster creates new cluster with specified ID.
// NewCluster creates new cluster with specified id.
func NewCluster(id string) *Cluster {
return &Cluster{
id: id,
affiliates: map[string]*Affiliate{},
}
}

// ID returns the cluster id.
func (cluster *Cluster) ID() string {
return cluster.id
}

// WithAffiliate runs a function against the affiliate.
//
// Cluster state is locked while the function is running.
Expand Down Expand Up @@ -174,7 +179,8 @@ func (cluster *Cluster) notify(notifications ...*Notification) {
}
}

func (cluster *Cluster) stats() (affiliates, endpoints, subscriptions int) {
// Stats returns the number of affiliates, endpoints and subscriptions.
func (cluster *Cluster) Stats() (affiliates, endpoints, subscriptions int) {
cluster.affiliatesMu.Lock()

affiliates = len(cluster.affiliates)
Expand Down
52 changes: 4 additions & 48 deletions internal/state/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package state

import (
"fmt"
"slices"

"github.com/siderolabs/gen/xslices"
Expand All @@ -15,52 +14,8 @@ import (
storagepb "github.com/siderolabs/discovery-service/api/storage"
)

// ExportClusterSnapshots exports all cluster snapshots and calls the provided function for each one.
//
// Implements storage.Snapshotter interface.
func (state *State) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error {
var err error

// reuse the same snapshotin each iteration
clusterSnapshot := &storagepb.ClusterSnapshot{}

state.clusters.Enumerate(func(_ string, cluster *Cluster) bool {
snapshotCluster(cluster, clusterSnapshot)

err = f(clusterSnapshot)

return err == nil
})

return err
}

// ImportClusterSnapshots imports cluster snapshots by calling the provided function until it returns false.
//
// Implements storage.Snapshotter interface.
func (state *State) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error {
for {
clusterSnapshot, ok, err := f()
if err != nil {
return err
}

if !ok {
break
}

cluster := clusterFromSnapshot(clusterSnapshot)

_, loaded := state.clusters.LoadOrStore(cluster.id, cluster)
if loaded {
return fmt.Errorf("cluster %q already exists", cluster.id)
}
}

return nil
}

func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) {
// Snapshot takes a snapshot of the cluster into the given snapshot reference.
func (cluster *Cluster) Snapshot(snapshot *storagepb.ClusterSnapshot) {
cluster.affiliatesMu.Lock()
defer cluster.affiliatesMu.Unlock()

Expand Down Expand Up @@ -110,7 +65,8 @@ func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) {
}
}

func clusterFromSnapshot(snapshot *storagepb.ClusterSnapshot) *Cluster {
// ClusterFromSnapshot creates a new cluster from the provided snapshot.
func ClusterFromSnapshot(snapshot *storagepb.ClusterSnapshot) *Cluster {
return &Cluster{
id: snapshot.Id,
affiliates: xslices.ToMap(snapshot.Affiliates, affiliateFromSnapshot),
Expand Down
172 changes: 1 addition & 171 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,175 +3,5 @@
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package state implements server state with clusters, affiliates, subscriptions, etc.
// Package state contains internal state-related components such as affiliates, subscriptions, etc.
package state

import (
"context"
"time"

prom "github.com/prometheus/client_golang/prometheus"
"github.com/siderolabs/gen/concurrent"
"go.uber.org/zap"
)

// State keeps the discovery service state.
type State struct {
clusters *concurrent.HashTrieMap[string, *Cluster]
logger *zap.Logger

mClustersDesc *prom.Desc
mAffiliatesDesc *prom.Desc
mEndpointsDesc *prom.Desc
mSubscriptionsDesc *prom.Desc
mGCRuns prom.Counter
mGCClusters prom.Counter
mGCAffiliates prom.Counter
}

// NewState create new instance of State.
func NewState(logger *zap.Logger) *State {
return &State{
clusters: concurrent.NewHashTrieMap[string, *Cluster](),
logger: logger,
mClustersDesc: prom.NewDesc(
"discovery_state_clusters",
"The current number of clusters in the state.",
nil, nil,
),
mAffiliatesDesc: prom.NewDesc(
"discovery_state_affiliates",
"The current number of affiliates in the state.",
nil, nil,
),
mEndpointsDesc: prom.NewDesc(
"discovery_state_endpoints",
"The current number of endpoints in the state.",
nil, nil,
),
mSubscriptionsDesc: prom.NewDesc(
"discovery_state_subscriptions",
"The current number of subscriptions in the state.",
nil, nil,
),
mGCRuns: prom.NewCounter(prom.CounterOpts{
Name: "discovery_state_gc_runs_total",
Help: "The number of GC runs.",
}),
mGCClusters: prom.NewCounter(prom.CounterOpts{
Name: "discovery_state_gc_clusters_total",
Help: "The total number of GC'ed clusters.",
}),
mGCAffiliates: prom.NewCounter(prom.CounterOpts{
Name: "discovery_state_gc_affiliates_total",
Help: "The total number of GC'ed affiliates.",
}),
}
}

// GetCluster returns cluster by ID, creating it if needed.
func (state *State) GetCluster(id string) *Cluster {
if cluster, ok := state.clusters.Load(id); ok {
return cluster
}

cluster, loaded := state.clusters.LoadOrStore(id, NewCluster(id))
if !loaded {
state.logger.Debug("cluster created", zap.String("cluster_id", id))
}

return cluster
}

// GarbageCollect recursively each cluster, and remove empty clusters.
func (state *State) GarbageCollect(now time.Time) (removedClusters, removedAffiliates int) {
state.clusters.Enumerate(func(key string, cluster *Cluster) bool {
ra, empty := cluster.GarbageCollect(now)
removedAffiliates += ra

if empty {
state.clusters.CompareAndDelete(key, cluster)
state.logger.Debug("cluster removed", zap.String("cluster_id", key))

removedClusters++
}

return true
})

state.mGCRuns.Inc()
state.mGCClusters.Add(float64(removedClusters))
state.mGCAffiliates.Add(float64(removedAffiliates))

return
}

// RunGC runs the garbage collection on interval.
func (state *State) RunGC(ctx context.Context, logger *zap.Logger, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()

for ctx.Err() == nil {
removedClusters, removedAffiliates := state.GarbageCollect(time.Now())
clusters, affiliates, endpoints, subscriptions := state.stats()

logFunc := logger.Debug
if removedClusters > 0 || removedAffiliates > 0 {
logFunc = logger.Info
}

logFunc(
"garbage collection run",
zap.Int("removed_clusters", removedClusters),
zap.Int("removed_affiliates", removedAffiliates),
zap.Int("current_clusters", clusters),
zap.Int("current_affiliates", affiliates),
zap.Int("current_endpoints", endpoints),
zap.Int("current_subscriptions", subscriptions),
)

select {
case <-ctx.Done():
case <-ticker.C:
}
}
}

func (state *State) stats() (clusters, affiliates, endpoints, subscriptions int) {
state.clusters.Enumerate(func(_ string, cluster *Cluster) bool {
clusters++

a, e, s := cluster.stats()
affiliates += a
endpoints += e
subscriptions += s

return true
})

return
}

// Describe implements prom.Collector interface.
func (state *State) Describe(ch chan<- *prom.Desc) {
prom.DescribeByCollect(state, ch)
}

// Collect implements prom.Collector interface.
func (state *State) Collect(ch chan<- prom.Metric) {
clusters, affiliates, endpoints, subscriptions := state.stats()

ch <- prom.MustNewConstMetric(state.mClustersDesc, prom.GaugeValue, float64(clusters))
ch <- prom.MustNewConstMetric(state.mAffiliatesDesc, prom.GaugeValue, float64(affiliates))
ch <- prom.MustNewConstMetric(state.mEndpointsDesc, prom.GaugeValue, float64(endpoints))
ch <- prom.MustNewConstMetric(state.mSubscriptionsDesc, prom.GaugeValue, float64(subscriptions))

ch <- state.mGCRuns
ch <- state.mGCClusters
ch <- state.mGCAffiliates
}

// Check interfaces.
var (
_ prom.Collector = (*State)(nil)
)
2 changes: 1 addition & 1 deletion internal/state/storage/storage_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
"go.uber.org/zap"

storagepb "github.com/siderolabs/discovery-service/api/storage"
"github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/internal/state/storage"
"github.com/siderolabs/discovery-service/pkg/state"
)

func BenchmarkExport(b *testing.B) {
Expand Down
7 changes: 4 additions & 3 deletions pkg/server/landing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
"go.uber.org/zap/zaptest"

"github.com/siderolabs/discovery-service/internal/landing"
"github.com/siderolabs/discovery-service/internal/state"
internalstate "github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/pkg/state"
)

// TestInspectHandler tests the /inspect endpoint.
Expand All @@ -33,14 +34,14 @@ func TestInspectHanlder(t *testing.T) {
t.Cleanup(cancel)

// add affiliates to the cluster "fake1"
err := testCluster.WithAffiliate("af1", func(affiliate *state.Affiliate) error {
err := testCluster.WithAffiliate("af1", func(affiliate *internalstate.Affiliate) error {
affiliate.Update([]byte("data1"), now.Add(time.Minute))

return nil
})
require.NoError(t, err)

err = testCluster.WithAffiliate("af2", func(affiliate *state.Affiliate) error {
err = testCluster.WithAffiliate("af2", func(affiliate *internalstate.Affiliate) error {
affiliate.Update([]byte("data2"), now.Add(time.Minute))

return nil
Expand Down
11 changes: 6 additions & 5 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/siderolabs/discovery-service/internal/state"
internalstate "github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/pkg/state"
)

const updateBuffer = 32
Expand Down Expand Up @@ -104,7 +105,7 @@ func (srv *ClusterServer) AffiliateUpdate(_ context.Context, req *pb.AffiliateUp
return nil, err
}

if err := srv.state.GetCluster(req.ClusterId).WithAffiliate(req.AffiliateId, func(affiliate *state.Affiliate) error {
if err := srv.state.GetCluster(req.ClusterId).WithAffiliate(req.AffiliateId, func(affiliate *internalstate.Affiliate) error {
expiration := time.Now().Add(req.Ttl.AsDuration())

if len(req.AffiliateData) > 0 {
Expand All @@ -114,9 +115,9 @@ func (srv *ClusterServer) AffiliateUpdate(_ context.Context, req *pb.AffiliateUp
return affiliate.MergeEndpoints(req.AffiliateEndpoints, expiration)
}); err != nil {
switch {
case errors.Is(err, state.ErrTooManyEndpoints):
case errors.Is(err, internalstate.ErrTooManyEndpoints):
return nil, status.Error(codes.ResourceExhausted, err.Error())
case errors.Is(err, state.ErrTooManyAffiliates):
case errors.Is(err, internalstate.ErrTooManyAffiliates):
return nil, status.Error(codes.ResourceExhausted, err.Error())
default:
return nil, err
Expand Down Expand Up @@ -170,7 +171,7 @@ func (srv *ClusterServer) Watch(req *pb.WatchRequest, server pb.Cluster_WatchSer
}

// make enough room to handle connection issues
updates := make(chan *state.Notification, updateBuffer)
updates := make(chan *internalstate.Notification, updateBuffer)

snapshot, subscription := srv.state.GetCluster(req.ClusterId).Subscribe(updates)
defer subscription.Close()
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ import (

"github.com/siderolabs/discovery-service/internal/limiter"
_ "github.com/siderolabs/discovery-service/internal/proto"
"github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/pkg/limits"
"github.com/siderolabs/discovery-service/pkg/server"
"github.com/siderolabs/discovery-service/pkg/state"
)

func checkMetrics(t *testing.T, c prom.Collector) {
Expand Down
Loading