diff --git a/hack/format.sh b/hack/format.sh index 48c64954..df0626d2 100755 --- a/hack/format.sh +++ b/hack/format.sh @@ -1,7 +1,7 @@ #!/bin/sh set -o errexit set -o nounset -set -o pipefail +#set -o pipefail find . -type f | grep .go$ | grep -v vendor | xargs -I {} gofmt -w {} diff --git a/pkg/queue/beanstalk.go b/pkg/queue/beanstalk.go index a1e74495..b8cc9150 100644 --- a/pkg/queue/beanstalk.go +++ b/pkg/queue/beanstalk.go @@ -2,6 +2,7 @@ package queue import ( "errors" + "fmt" "io" "net/url" "path" @@ -60,6 +61,7 @@ type BeanstalkClientInterface interface { put(body []byte, pri uint32, delay, t time.Duration) (id uint64, err error) getStats() (int32, int32, int32, error) longPollReceiveMessage(longPollInterval int64) (int32, int32, error) + reestablishConn() error } type beanstalkClient struct { @@ -92,6 +94,11 @@ func getBeanstalkConn(queueURI string) (*beanstalk.Conn, error) { if err != nil { return nil, errors.New("dial-error: " + err.Error()) } + + if conn == nil { + return nil, fmt.Errorf("Connection nil for: %s\n", queueURI) + } + return conn, nil } @@ -190,8 +197,17 @@ func (c *beanstalkClient) put( func (c *beanstalkClient) doLongPoll( longPollInterval int64) (bool, uint64, error) { + var id uint64 + var err error + tubeSet := beanstalk.NewTubeSet(c.conn, path.Base(c.queueURI)) - id, _, err := tubeSet.Reserve( + if tubeSet == nil { + err = c.reestablishConn() + if err != nil { + return false, id, err + } + } + id, _, err = tubeSet.Reserve( time.Duration(longPollInterval) * time.Second) if err == nil { return true, id, nil @@ -248,6 +264,9 @@ func (b *Beanstalk) getClient( if err != nil { return nil, err } + if client == nil { + return nil, fmt.Errorf("Not able to make client for: %s\n", queueURI) + } b.clientPool.Store(queueURI, client) return client.(BeanstalkClientInterface), nil @@ -303,6 +322,17 @@ func (b *Beanstalk) waitForShortPollInterval() { time.Sleep(b.shortPollInterval) } +func (b *Beanstalk) reestablishConn(queueURI string) { + client, err := b.getClient(queueURI) + if err != nil { + klog.Errorf("Could not reestablish conn, err:%v\n", err) + return + } + + err = client.reestablishConn() + klog.Error(err) +} + func (b *Beanstalk) GetName() string { return b.name } @@ -320,6 +350,7 @@ func (b *Beanstalk) poll(key string, queueSpec QueueSpec) { if err != nil { klog.Errorf("Unable to perform request long polling %q, %v.", queueSpec.name, err) + b.reestablishConn(queueSpec.uri) return } @@ -345,6 +376,7 @@ func (b *Beanstalk) poll(key string, queueSpec QueueSpec) { if err != nil { klog.Errorf("Unable to get approximate messages in queue %q, %v.", queueSpec.name, err) + b.reestablishConn(queueSpec.uri) return } klog.V(3).Infof("%s: approxMessages=%d", queueSpec.name, approxMessages) @@ -372,6 +404,7 @@ func (b *Beanstalk) poll(key string, queueSpec QueueSpec) { if err != nil { klog.Errorf("Unable to fetch idle workers %q, %v.", queueSpec.name, err) + b.reestablishConn(queueSpec.uri) time.Sleep(100 * time.Millisecond) return } diff --git a/pkg/queue/beanstalk_mock.go b/pkg/queue/beanstalk_mock.go index f8b79386..e0284616 100644 --- a/pkg/queue/beanstalk_mock.go +++ b/pkg/queue/beanstalk_mock.go @@ -84,3 +84,17 @@ func (mr *MockBeanstalkClientInterfaceMockRecorder) put(arg0, arg1, arg2, arg3 i mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "put", reflect.TypeOf((*MockBeanstalkClientInterface)(nil).put), arg0, arg1, arg2, arg3) } + +// reestablishConn mocks base method +func (m *MockBeanstalkClientInterface) reestablishConn() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "reestablishConn") + ret0, _ := ret[0].(error) + return ret0 +} + +// put indicates an expected call of put +func (mr *MockBeanstalkClientInterfaceMockRecorder) reestablishConn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "reestablishConn", reflect.TypeOf((*MockBeanstalkClientInterface)(nil).reestablishConn)) +} diff --git a/pkg/queue/sqs.go b/pkg/queue/sqs.go index 6b96472e..1bcd2819 100644 --- a/pkg/queue/sqs.go +++ b/pkg/queue/sqs.go @@ -87,8 +87,13 @@ func (s *SQS) getSQSClient(queueURI string) *sqs.SQS { return s.sqsClientPool[getRegion(queueURI)] } -func (s *SQS) getCWClient(queueURI string) *cloudwatch.CloudWatch { - return s.cwClientPool[getRegion(queueURI)] +func (s *SQS) getCWClient(queueURI string) (*cloudwatch.CloudWatch, error) { + client, ok := s.cwClientPool[getRegion(queueURI)] + if !ok { + return nil, fmt.Errorf("Client not found for queue: %s\n", queueURI) + } + + return client, nil } func (s *SQS) longPollReceiveMessage(queueURI string) (int32, error) { @@ -187,7 +192,12 @@ func (s *SQS) getNumberOfMessagesReceived(queueURI string) (float64, error) { }, } - result, err := s.getCWClient(queueURI).GetMetricData(&cloudwatch.GetMetricDataInput{ + cwClient, err := s.getCWClient(queueURI) + if err != nil { + return 0.0, err + } + + result, err := cwClient.GetMetricData(&cloudwatch.GetMetricDataInput{ EndTime: &endTime, StartTime: &startTime, MetricDataQueries: []*cloudwatch.MetricDataQuery{query}, @@ -315,7 +325,12 @@ func (s *SQS) getAverageNumberOfMessagesSent(queueURI string) (float64, error) { }, } - result, err := s.getCWClient(queueURI).GetMetricData(&cloudwatch.GetMetricDataInput{ + cwClient, err := s.getCWClient(queueURI) + if err != nil { + return 0.0, err + } + + result, err := cwClient.GetMetricData(&cloudwatch.GetMetricDataInput{ EndTime: &endTime, StartTime: &startTime, MetricDataQueries: []*cloudwatch.MetricDataQuery{query},