Skip to content

Commit

Permalink
chore: Use queueURL instead of queueName for sqs provider (aws#6035)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-innis authored Apr 13, 2024
1 parent 13cd913 commit e4bf5d7
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 41 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.22
require (
github.com/Pallinder/go-randomdata v1.2.0
github.com/PuerkitoBio/goquery v1.9.1
github.com/avast/retry-go v3.0.0+incompatible
github.com/aws/aws-sdk-go v1.51.16
github.com/aws/karpenter-provider-aws/tools/kompat v0.0.0-20240410220356-6b868db24881
github.com/awslabs/amazon-eks-ami/nodeadm v0.0.0-20240229193347-cfab22a10647
Expand Down Expand Up @@ -37,7 +38,6 @@ require (
contrib.go.opencensus.io/exporter/prometheus v0.4.2 // indirect
github.com/Masterminds/semver/v3 v3.2.1 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/avast/retry-go v3.0.0+incompatible // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/blendle/zapdriver v1.3.1 // indirect
Expand Down
4 changes: 3 additions & 1 deletion pkg/controllers/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ func NewControllers(ctx context.Context, sess *session.Session, clk clock.Clock,
controllerspricing.NewController(pricingProvider),
}
if options.FromContext(ctx).InterruptionQueue != "" {
controllers = append(controllers, interruption.NewController(kubeClient, clk, recorder, lo.Must(sqs.NewProvider(ctx, servicesqs.New(sess), options.FromContext(ctx).InterruptionQueue)), unavailableOfferings))
sqsapi := servicesqs.New(sess)
out := lo.Must(sqsapi.GetQueueUrlWithContext(ctx, &servicesqs.GetQueueUrlInput{QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueue)}))
controllers = append(controllers, interruption.NewController(kubeClient, clk, recorder, lo.Must(sqs.NewDefaultProvider(sqsapi, lo.FromPtr(out.QueueUrl))), unavailableOfferings))
}
return controllers
}
42 changes: 21 additions & 21 deletions pkg/controllers/interruption/interruption_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"
"time"

"github.com/avast/retry-go"
"github.com/aws/aws-sdk-go/aws"
awsclient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/endpoints"
Expand Down Expand Up @@ -95,14 +96,15 @@ func benchmarkNotificationController(b *testing.B, messageCount int) {
}
}()

providers := newProviders(env.Context, env.Client)
if err := providers.makeInfrastructure(ctx); err != nil {
providers := newProviders(ctx, env.Client)
queueURL, err := providers.makeInfrastructure(ctx)
if err != nil {
b.Fatalf("standing up infrastructure, %v", err)
}
// Cleanup the infrastructure after the coretest completes
defer func() {
if err := retry.Do(func() error {
return providers.cleanupInfrastructure(ctx)
return providers.cleanupInfrastructure(queueURL)
}); err != nil {
b.Fatalf("deleting infrastructure, %v", err)
}
Expand Down Expand Up @@ -174,31 +176,29 @@ func newProviders(ctx context.Context, kubeClient client.Client) providerSet {
),
))
sqsAPI := servicesqs.New(sess)
out := lo.Must(sqsAPI.GetQueueUrlWithContext(ctx, &servicesqs.GetQueueUrlInput{QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueue)}))
return providerSet{
kubeClient: kubeClient,
sqsAPI: sqsAPI,
sqsProvider: sqs.NewProvider(ctx, sqsAPI, "test-cluster"),
sqsProvider: lo.Must(sqs.NewDefaultProvider(sqsAPI, lo.FromPtr(out.QueueUrl))),
}
}

func (p *providerSet) makeInfrastructure(ctx context.Context) error {
if _, err := p.sqsAPI.CreateQueueWithContext(ctx, &servicesqs.CreateQueueInput{
QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueueName),
func (p *providerSet) makeInfrastructure(ctx context.Context) (string, error) {
out, err := p.sqsAPI.CreateQueueWithContext(ctx, &servicesqs.CreateQueueInput{
QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueue),
Attributes: map[string]*string{
servicesqs.QueueAttributeNameMessageRetentionPeriod: aws.String("1200"), // 20 minutes for this test
},
}); err != nil {
return fmt.Errorf("creating servicesqs queue, %w", err)
})
if err != nil {
return "", fmt.Errorf("creating servicesqs queue, %w", err)
}
return nil
return lo.FromPtr(out.QueueUrl), nil
}

func (p *providerSet) cleanupInfrastructure(ctx context.Context) error {
queueURL, err := p.sqsProvider.DiscoverQueueURL(ctx)
if err != nil {
return fmt.Errorf("discovering queue url for deletion, %w", err)
}
if _, err = p.sqsAPI.DeleteQueueWithContext(ctx, &servicesqs.DeleteQueueInput{
func (p *providerSet) cleanupInfrastructure(queueURL string) error {
if _, err := p.sqsAPI.DeleteQueueWithContext(ctx, &servicesqs.DeleteQueueInput{
QueueUrl: lo.ToPtr(queueURL),
}); err != nil {
return fmt.Errorf("deleting servicesqs queue, %w", err)
Expand All @@ -220,11 +220,11 @@ func (p *providerSet) monitorMessagesProcessed(ctx context.Context, eventRecorde
totalProcessed := 0
go func() {
for totalProcessed < expectedProcessed {
totalProcessed = eventRecorder.Calls(events.InstanceStopping(coretest.Node()).Reason) +
eventRecorder.Calls(events.InstanceTerminating(coretest.Node()).Reason) +
eventRecorder.Calls(events.InstanceUnhealthy(coretest.Node()).Reason) +
eventRecorder.Calls(events.InstanceRebalanceRecommendation(coretest.Node()).Reason) +
eventRecorder.Calls(events.InstanceSpotInterrupted(coretest.Node()).Reason)
totalProcessed = eventRecorder.Calls(events.Stopping(coretest.Node(), coretest.NodeClaim())[0].Reason) +
eventRecorder.Calls(events.Stopping(coretest.Node(), coretest.NodeClaim())[0].Reason) +
eventRecorder.Calls(events.Unhealthy(coretest.Node(), coretest.NodeClaim())[0].Reason) +
eventRecorder.Calls(events.RebalanceRecommendation(coretest.Node(), coretest.NodeClaim())[0].Reason) +
eventRecorder.Calls(events.SpotInterrupted(coretest.Node(), coretest.NodeClaim())[0].Reason)
logging.FromContext(ctx).With("processed-message-count", totalProcessed).Infof("processed messages from the queue")
time.Sleep(time.Second)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/interruption/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ var _ = BeforeSuite(func() {
fakeClock = &clock.FakeClock{}
unavailableOfferingsCache = awscache.NewUnavailableOfferings()
sqsapi = &fake.SQSAPI{}
sqsProvider = lo.Must(sqs.NewProvider(ctx, sqsapi, "test-cluster"))
sqsProvider = lo.Must(sqs.NewDefaultProvider(sqsapi, fmt.Sprintf("https://sqs.%s.amazonaws.com/%s/test-cluster", fake.DefaultRegion, fake.DefaultAccount)))
controller = interruption.NewController(env.Client, fakeClock, events.NewRecorder(&record.FakeRecorder{}), sqsProvider, unavailableOfferingsCache)
})

Expand Down
26 changes: 10 additions & 16 deletions pkg/providers/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sqs"
Expand All @@ -34,26 +35,19 @@ type Provider interface {
type DefaultProvider struct {
client sqsiface.SQSAPI

name string
url string
queueURL string
}

func NewProvider(ctx context.Context, client sqsiface.SQSAPI, queueName string) (*DefaultProvider, error) {
ret, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
QueueName: aws.String(queueName),
})
if err != nil {
return nil, fmt.Errorf("fetching queue url, %w", err)
}
func NewDefaultProvider(client sqsiface.SQSAPI, queueURL string) (*DefaultProvider, error) {
return &DefaultProvider{
client: client,
name: queueName,
url: aws.StringValue(ret.QueueUrl),
client: client,
queueURL: queueURL,
}, nil
}

func (p *DefaultProvider) Name() string {
return p.name
ss := strings.Split(p.queueURL, "/")
return ss[len(ss)-1]
}

func (p *DefaultProvider) GetSQSMessages(ctx context.Context) ([]*sqs.Message, error) {
Expand All @@ -67,7 +61,7 @@ func (p *DefaultProvider) GetSQSMessages(ctx context.Context) ([]*sqs.Message, e
MessageAttributeNames: []*string{
aws.String(sqs.QueueAttributeNameAll),
},
QueueUrl: aws.String(p.url),
QueueUrl: aws.String(p.queueURL),
}

result, err := p.client.ReceiveMessageWithContext(ctx, input)
Expand All @@ -85,7 +79,7 @@ func (p *DefaultProvider) SendMessage(ctx context.Context, body interface{}) (st
}
input := &sqs.SendMessageInput{
MessageBody: aws.String(string(raw)),
QueueUrl: aws.String(p.url),
QueueUrl: aws.String(p.queueURL),
}
result, err := p.client.SendMessageWithContext(ctx, input)
if err != nil {
Expand All @@ -96,7 +90,7 @@ func (p *DefaultProvider) SendMessage(ctx context.Context, body interface{}) (st

func (p *DefaultProvider) DeleteSQSMessage(ctx context.Context, msg *sqs.Message) error {
input := &sqs.DeleteMessageInput{
QueueUrl: aws.String(p.url),
QueueUrl: aws.String(p.queueURL),
ReceiptHandle: msg.ReceiptHandle,
}

Expand Down
4 changes: 3 additions & 1 deletion test/pkg/environment/aws/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ func NewEnvironment(t *testing.T) *Environment {
}
// Initialize the provider only if the INTERRUPTION_QUEUE environment variable is defined
if v, ok := os.LookupEnv("INTERRUPTION_QUEUE"); ok {
awsEnv.SQSProvider = lo.Must(sqs.NewProvider(env.Context, servicesqs.New(session), v))
sqsapi := servicesqs.New(session)
out := lo.Must(sqsapi.GetQueueUrlWithContext(env.Context, &servicesqs.GetQueueUrlInput{QueueName: aws.String(v)}))
awsEnv.SQSProvider = lo.Must(sqs.NewDefaultProvider(sqsapi, lo.FromPtr(out.QueueUrl)))
}
return awsEnv
}
Expand Down

0 comments on commit e4bf5d7

Please sign in to comment.