Fix: Use heartbeats to correct for request token drift (#777)
- Uses key events + request heartbeats to correct for token drift after
a particular gateway crashed unexpectedly

The consequence of this is that if a gateway that was actively handling
a request crashes or is forcibly terminated, there will be a 30-second delay
before any active containers that were handling requests have their token
count incremented. This should fix "bricked" containers that had an
inaccurately low token count.
luke-lombardi authored Dec 11, 2024
1 parent cae3de1 commit ee7a259
Showing 3 changed files with 50 additions and 20 deletions.
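
As context for the diff below: the listener half of this mechanism turns Redis keyspace notifications for expired heartbeat keys into token releases. The repository routes this through its own common.KeyEventManager; the following is only a minimal, self-contained sketch of the underlying pattern, not the project's code. The go-redis client, the releaseToken helper, and the literal key layout are illustrative assumptions (the layout mirrors the endpointRequestHeartbeat format introduced here), and the Redis server is assumed to have notify-keyspace-events configured to emit expired events (e.g. "Ex").

package main

import (
	"context"
	"log"
	"strings"

	"github.com/redis/go-redis/v9"
)

// releaseToken is a hypothetical helper: it returns a request slot to the
// container whose heartbeat key expired, i.e. whose gateway died before it
// could release the token itself.
func releaseToken(ctx context.Context, rdb *redis.Client, workspace, stubId, containerId string) error {
	tokenKey := "endpoint:" + workspace + ":" + stubId + ":request_tokens:" + containerId
	return rdb.Incr(ctx, tokenKey).Err()
}

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	// Subscribe to expired-key events on database 0.
	pubsub := rdb.PSubscribe(ctx, "__keyevent@0__:expired")
	defer pubsub.Close()

	for msg := range pubsub.Channel() {
		key := msg.Payload // for keyevent notifications, the payload is the key name
		if !strings.Contains(key, ":request_heartbeat:") {
			continue
		}

		// Assumed layout, matching the diff:
		// endpoint:<workspace>:<stubId>:request_heartbeat:<taskId>:<containerId>
		parts := strings.Split(key, ":")
		if len(parts) < 6 {
			continue
		}
		workspace, stubId := parts[1], parts[2]
		taskId, containerId := parts[len(parts)-2], parts[len(parts)-1]

		if err := releaseToken(ctx, rdb, workspace, stubId, containerId); err != nil {
			log.Printf("failed to release token for task %s on container %s: %v", taskId, containerId, err)
		}
	}
}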
55 changes: 40 additions & 15 deletions pkg/abstractions/endpoint/buffer.go
@@ -58,6 +58,8 @@ type RequestBuffer struct {
availableContainersLock sync.RWMutex
maxTokens int
isASGI bool
+ keyEventManager *common.KeyEventManager
+ keyEventChan chan common.KeyEvent
}

func NewRequestBuffer(
@@ -67,12 +69,13 @@ func NewRequestBuffer(
stubId string,
size int,
containerRepo repository.ContainerRepository,
+ keyEventManager *common.KeyEventManager,
stubConfig *types.StubConfigV1,
tailscale *network.Tailscale,
tsConfig types.TailscaleConfig,
isASGI bool,
) *RequestBuffer {
- b := &RequestBuffer{
+ rb := &RequestBuffer{
ctx: ctx,
rdb: rdb,
workspace: workspace,
@@ -82,6 +85,8 @@ availableContainers: []container{},
availableContainers: []container{},
availableContainersLock: sync.RWMutex{},
containerRepo: containerRepo,
+ keyEventManager: keyEventManager,
+ keyEventChan: make(chan common.KeyEvent),
httpClient: &http.Client{},
tailscale: tailscale,
tsConfig: tsConfig,
@@ -91,13 +96,38 @@

if stubConfig.ConcurrentRequests > 1 && isASGI {
// Floor is set to the number of workers
- b.maxTokens = max(int(stubConfig.ConcurrentRequests), b.maxTokens)
+ rb.maxTokens = max(int(stubConfig.ConcurrentRequests), rb.maxTokens)
}

- go b.discoverContainers()
- go b.processRequests()
+ go rb.discoverContainers()
+ go rb.processRequests()

- return b
+ // Listen for heartbeat key events
+ go rb.keyEventManager.ListenForPattern(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, "*", "*"), rb.keyEventChan)
+ go rb.handleHeartbeatEvents()
+
+ return rb
}

+ func (rb *RequestBuffer) handleHeartbeatEvents() {
+ for {
+ select {
+ case event := <-rb.keyEventChan:
+ operation := event.Operation
+
+ switch operation {
+ case common.KeyOperationSet, common.KeyOperationHSet, common.KeyOperationDel, common.KeyOperationExpire:
+ // Do nothing
+ case common.KeyOperationExpired:
+ if parts := strings.Split(event.Key, ":"); len(parts) >= 2 {
+ taskId, containerId := parts[len(parts)-2], parts[len(parts)-1]
+ rb.releaseRequestToken(containerId, taskId)
+ }
+ }
+ case <-rb.ctx.Done():
+ return
+ }
+ }
+ }

func (rb *RequestBuffer) ForwardRequest(ctx echo.Context, task *EndpointTask) error {
@@ -294,12 +324,7 @@ func (rb *RequestBuffer) acquireRequestToken(containerId string) error {
return nil
}

- func (rb *RequestBuffer) releaseRequestToken(containerId string) error {
- // TODO: if a gateway crashes before releasing the token, it could lead to a drift
- // in the count of available request tokens for a particular container. To handle this
- // we could move the release logic to the task implementation (e.g. task.Complete), so that
- // it handles the release of the token and is not tied to a specific gateway
-
+ func (rb *RequestBuffer) releaseRequestToken(containerId, taskId string) error {
tokenKey := Keys.endpointRequestTokens(rb.workspace.Name, rb.stubId, containerId)

err := rb.rdb.Incr(rb.ctx, tokenKey).Err()
@@ -312,7 +337,7 @@ func (rb *RequestBuffer) releaseRequestToken(containerId string) error {
return err
}

- return nil
+ return rb.rdb.Del(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, taskId, containerId)).Err()
}

func (rb *RequestBuffer) getHttpClient(address string) (*http.Client, error) {
@@ -492,15 +517,15 @@ func (rb *RequestBuffer) heartBeat(req *request, containerId string) {
ticker := time.NewTicker(endpointRequestHeartbeatInterval)
defer ticker.Stop()

- rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId), containerId, endpointRequestHeartbeatInterval)
+ rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId, containerId), 1, endpointRequestHeartbeatInterval)
for {
select {
case <-ctx.Done():
return
case <-rb.ctx.Done():
return
case <-ticker.C:
- rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId), containerId, endpointRequestHeartbeatInterval)
+ rb.rdb.Set(rb.ctx, Keys.endpointRequestHeartbeat(rb.workspace.Name, rb.stubId, req.task.msg.TaskId, containerId), 1, endpointRequestHeartbeatInterval)
}
}
}
@@ -510,7 +535,7 @@ func (rb *RequestBuffer) afterRequest(req *request, containerId string) {
req.done <- true
}()

- defer rb.releaseRequestToken(containerId)
+ defer rb.releaseRequestToken(containerId, req.task.msg.TaskId)

// Set keep warm lock
if rb.stubConfig.KeepWarmSeconds == 0 {
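
The heartBeat goroutine above is the writer half of the mechanism: each in-flight request keeps refreshing a short-TTL heartbeat key, so the key only expires if the gateway stops refreshing it. Below is a minimal standalone sketch of that TTL-refresh pattern; the key name, the 30-second interval, and the go-redis client are illustrative assumptions rather than the repository's actual values.

package main

import (
	"context"
	"time"

	"github.com/redis/go-redis/v9"
)

// heartbeatLoop keeps a short-TTL key alive while a request is in flight.
// If the gateway dies, refreshes stop and the key expires within roughly one
// interval, which the expired-event listener turns into a token release.
func heartbeatLoop(ctx context.Context, rdb *redis.Client, key string, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	// Seed the key immediately; the TTL matches the refresh interval, as in the diff above.
	rdb.Set(ctx, key, 1, interval)

	for {
		select {
		case <-ctx.Done():
			// On normal completion the caller deletes the key explicitly,
			// so no expired event is emitted for it.
			return
		case <-ticker.C:
			rdb.Set(ctx, key, 1, interval)
		}
	}
}

func main() {
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	// Illustrative key following the layout introduced by this commit.
	key := "endpoint:my-workspace:stub-123:request_heartbeat:task-abc:container-xyz"
	heartbeatLoop(ctx, rdb, key, 30*time.Second)
}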
8 changes: 4 additions & 4 deletions pkg/abstractions/endpoint/endpoint.go
@@ -272,7 +272,7 @@ func (es *HttpEndpointService) getOrCreateEndpointInstance(ctx context.Context,
instance.isASGI = true
}

- instance.buffer = NewRequestBuffer(autoscaledInstance.Ctx, es.rdb, &stub.Workspace, stubId, requestBufferSize, es.containerRepo, stubConfig, es.tailscale, es.config.Tailscale, instance.isASGI)
+ instance.buffer = NewRequestBuffer(autoscaledInstance.Ctx, es.rdb, &stub.Workspace, stubId, requestBufferSize, es.containerRepo, es.keyEventManager, stubConfig, es.tailscale, es.config.Tailscale, instance.isASGI)

// Embed autoscaled instance struct
instance.AutoscaledInstance = autoscaledInstance
@@ -314,7 +314,7 @@ var (
endpointKeepWarmLock string = "endpoint:%s:%s:keep_warm_lock:%s"
endpointInstanceLock string = "endpoint:%s:%s:instance_lock"
endpointRequestTokens string = "endpoint:%s:%s:request_tokens:%s"
- endpointRequestHeartbeat string = "endpoint:%s:%s:request_heartbeat:%s"
+ endpointRequestHeartbeat string = "endpoint:%s:%s:request_heartbeat:%s:%s"
endpointServeLock string = "endpoint:%s:%s:serve_lock"
)

@@ -330,8 +330,8 @@ func (k *keys) endpointRequestTokens(workspaceName, stubId, containerId string)
return fmt.Sprintf(endpointRequestTokens, workspaceName, stubId, containerId)
}

- func (k *keys) endpointRequestHeartbeat(workspaceName, stubId, taskId string) string {
- return fmt.Sprintf(endpointRequestHeartbeat, workspaceName, stubId, taskId)
+ func (k *keys) endpointRequestHeartbeat(workspaceName, stubId, taskId, containerId string) string {
+ return fmt.Sprintf(endpointRequestHeartbeat, workspaceName, stubId, taskId, containerId)
}

func (k *keys) endpointServeLock(workspaceName, stubId string) string {
7 changes: 6 additions & 1 deletion pkg/abstractions/endpoint/task.go
@@ -74,7 +74,12 @@ func (t *EndpointTask) Cancel(ctx context.Context, reason types.TaskCancellation
}

func (t *EndpointTask) HeartBeat(ctx context.Context) (bool, error) {
- heartbeatKey := Keys.endpointRequestHeartbeat(t.msg.WorkspaceName, t.msg.StubId, t.msg.TaskId)
+ task, err := t.es.backendRepo.GetTask(ctx, t.msg.TaskId)
+ if err != nil {
+ return false, err
+ }
+
+ heartbeatKey := Keys.endpointRequestHeartbeat(t.msg.WorkspaceName, t.msg.StubId, t.msg.TaskId, task.ContainerId)
exists, err := t.es.rdb.Exists(ctx, heartbeatKey).Result()
if err != nil {
return false, fmt.Errorf("failed to retrieve endpoint heartbeat key <%v>: %w", heartbeatKey, err)
