Skip to content

Commit

Permalink
Merge pull request #43 from golift/dn2_round_robin
Browse files Browse the repository at this point in the history
Round Robin Feature
  • Loading branch information
davidnewhall authored Apr 16, 2024
2 parents 14c7ce5 + c09a6ad commit fa31e5f
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 10 deletions.
87 changes: 84 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ type Config struct {
// What to reset the backoff to when max is hit.
// Set this to max to stay at max.
BackoffReset time.Duration
// If RRConfig is non-nil then the servers provided in Targets are
// tried sequentially after they cannot be reached in RetryInterval.
*RoundRobinConfig
// If this is true, then the servers provided in Targets are tried
// sequentially after they cannot be reached in RetryInterval.
// Handler is an optional custom handler for all proxied requests.
// Leaving this nil makes all requests use an empty http.Client.
Handler func(http.ResponseWriter, *http.Request)
Expand All @@ -54,13 +59,24 @@ type Config struct {
mulch.Logger
}

// RoundRobinConfig is the configuration specific to round robin target acquisition.
type RoundRobinConfig struct {
// When RoundRobin is true, this configures how long a server is
// retried unsuccessfully before trying the next server in Targets list.
RetryInterval time.Duration
// Callback is called when the tunnel changes.
Callback func(ctx context.Context, socket string)
}

// Client connects to one or more Server using HTTP websockets.
// The Server can then send HTTP requests to execute.
type Client struct {
*Config
client *http.Client
dialer *websocket.Dialer
pools map[string]*Pool
lastConn time.Time // keeps track of last successful connection to our active target.
target int // keeps track of active target in round robin mode.
client *http.Client
dialer *websocket.Dialer
pools map[string]*Pool
}

// NewConfig creates a new ProxyConfig.
Expand Down Expand Up @@ -94,7 +110,16 @@ func NewClient(config *Config) *Client {
config.BackoffReset = DefaultBackoffReset
}

if config.RoundRobinConfig != nil {
if len(config.Targets) <= 1 {
config.RoundRobinConfig = nil
} else if config.RoundRobinConfig.RetryInterval == 0 {
config.RoundRobinConfig.RetryInterval = time.Minute
}
}

return &Client{
target: -1,
Config: config,
client: &http.Client{},
dialer: &websocket.Dialer{
Expand All @@ -107,11 +132,56 @@ func NewClient(config *Config) *Client {

// Start the Proxy.
func (c *Client) Start(ctx context.Context) {
if c.Config.RoundRobinConfig != nil {
c.startOnePool(ctx)
} else {
c.startAllPools(ctx)
}
}

func (c *Client) startAllPools(ctx context.Context) {
for _, target := range c.Config.Targets {
if c.pools[target] != nil && !c.pools[target].shutdown {
panic("Attempt to overwrite active mulery client pool!")
}

c.pools[target] = StartPool(ctx, c, target, c.Config.SecretKey)
}
}

// startOnePool happens in round robin mode.
func (c *Client) startOnePool(ctx context.Context) {
c.target++
if c.target >= len(c.Config.Targets) {
c.target = 0
}

target := c.Config.Targets[c.target]
c.lastConn = time.Now()

if c.pools[target] != nil && !c.pools[target].shutdown {
panic("Attempt to overwrite active mulery client pool!")
}

if c.Callback != nil {
c.Callback(ctx, target)
}

c.pools[target] = StartPool(ctx, c, target, c.Config.SecretKey)
}

// restart calls shutdown and start inside a go routine.
// Allows a failing pool to restart the client.
// This is only useful in RoundRobin mode, do not call it otherwise.
func (c *Client) restart(ctx context.Context) {
c.Printf("Restarting tunnel to connect to next websocket target.")

go func() {
c.Shutdown()
c.Start(ctx)
}()
}

// Shutdown the Proxy.
func (c *Client) Shutdown() {
for _, pool := range c.pools {
Expand All @@ -123,3 +193,14 @@ func (c *Client) Shutdown() {
func (c *Client) GetID() string {
return mulch.HashKeyID(c.SecretKey, c.ID)
}

// PoolStats returns stats for all pools.
func (c *Client) PoolStats() map[string]*PoolSize {
sizes := map[string]*PoolSize{}

for socket, pool := range c.pools {
sizes[socket] = pool.size()
}

return sizes
}
44 changes: 37 additions & 7 deletions client/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Pool struct {
target string
secretKey string
connections []*Connection
disconnects int
done chan struct{}
getSize chan struct{}
repSize chan *PoolSize
Expand All @@ -24,10 +25,14 @@ type Pool struct {

// PoolSize represent the number of open connections per status.
type PoolSize struct {
Connecting int
Idle int
Running int
Total int
Disconnects int
Connecting int
Idle int
Running int
Total int
LastConn time.Time
LastTry time.Time
Active bool
}

// StartPool creates and starts a pool in one command.
Expand Down Expand Up @@ -122,6 +127,21 @@ func (p *Pool) connector(ctx context.Context, now time.Time) {
toCreate = p.client.Config.PoolMaxSize - poolSize.Total
}

p.fillConnectionPool(ctx, now, toCreate)
}

func (p *Pool) fillConnectionPool(ctx context.Context, now time.Time, toCreate int) {
if p.client.RoundRobinConfig != nil {
if toCreate == 0 {
// Keep this up to date, or the logic will skip to the next server prematurely.
p.client.lastConn = now
} else if now.Sub(p.client.lastConn) > p.client.RetryInterval {
// We need more connections and the last successful connection was too long ago.
// Restart and skip to the next server in the round robin target list.
defer p.client.restart(ctx)
}
}

// Try to reach ideal pool size.
for ; toCreate > 0; toCreate-- {
// This is the only place a connection is added to the pool.
Expand Down Expand Up @@ -153,7 +173,8 @@ func (p *Pool) remove(connection *Connection) {
if connection != conn {
filtered = append(filtered, conn)
} else {
conn.Close()
p.disconnects++
conn.Close() //nolint:wsl
}
}

Expand All @@ -162,8 +183,10 @@ func (p *Pool) remove(connection *Connection) {

// Shutdown and close all connections in the pool.
func (p *Pool) Shutdown() {
p.shutdown = true
close(p.done)
if !p.shutdown {
p.shutdown = true
close(p.done)
}
}

func (ps *PoolSize) String() string {
Expand All @@ -180,6 +203,13 @@ func (p *Pool) Size() *PoolSize {
func (p *Pool) size() *PoolSize {
poolSize := new(PoolSize)
poolSize.Total = len(p.connections)
poolSize.Disconnects = p.disconnects
poolSize.LastTry = p.lastTry
poolSize.Active = !p.shutdown

if poolSize.LastConn = p.lastTry; !p.shutdown && p.client.RoundRobinConfig != nil {
poolSize.LastConn = p.client.lastConn
}

for _, connection := range p.connections {
switch connection.Status() {
Expand Down

0 comments on commit fa31e5f

Please sign in to comment.