diff --git a/operator/rpc_client.go b/operator/rpc_client.go index 9f2a63ae..2bef5337 100644 --- a/operator/rpc_client.go +++ b/operator/rpc_client.go @@ -1,6 +1,7 @@ package operator import ( + "errors" "fmt" "net/rpc" "sync" @@ -29,7 +30,7 @@ const ( type RpcMessage = interface{} type AggregatorRpcClient struct { - rpcClientLock sync.Mutex + rpcClientLock sync.RWMutex rpcClient *rpc.Client aggregatorIpPortAddr string @@ -71,45 +72,77 @@ func (c *AggregatorRpcClient) WithMetrics(registry *prometheus.Registry) error { } func (c *AggregatorRpcClient) dialAggregatorRpcClient() error { + c.rpcClientLock.Lock() + defer c.rpcClientLock.Unlock() + + if c.rpcClient != nil { + return nil + } + c.logger.Info("rpc client is nil. Dialing aggregator rpc client") client, err := rpc.DialHTTP("tcp", c.aggregatorIpPortAddr) if err != nil { + c.logger.Error("Error dialing aggregator rpc client", "err", err) return err } + c.rpcClient = client + return nil } func (c *AggregatorRpcClient) InitializeClientIfNotExist() error { - c.rpcClientLock.Lock() - defer c.rpcClientLock.Unlock() - + c.rpcClientLock.RLock() if c.rpcClient != nil { + c.rpcClientLock.RUnlock() return nil } + c.rpcClientLock.RUnlock() return c.dialAggregatorRpcClient() } + +func (c *AggregatorRpcClient) handleRpcError(err error) error { + if err == rpc.ErrShutdown { + go c.handleRpcShutdown() + } + + return nil +} + +func (c *AggregatorRpcClient) handleRpcShutdown() { + c.rpcClientLock.Lock() + defer c.rpcClientLock.Unlock() + + if c.rpcClient != nil { + c.logger.Info("Closing RPC client due to shutdown") + + err := c.rpcClient.Close() + if err != nil { + c.logger.Error("Error closing RPC client", "err", err) + } + + c.rpcClient = nil + } +} + func (c *AggregatorRpcClient) onTick() { - tickerC := c.resendTicker.C for { - // TODO(edwin): handle closed chan - <-tickerC - - { - c.unsentMessagesLock.Lock() - if len(c.unsentMessages) == 0 { - c.unsentMessagesLock.Unlock() - continue - } - c.unsentMessagesLock.Unlock() - } + <-c.resendTicker.C err := c.InitializeClientIfNotExist() if err != nil { + c.logger.Error("Error initializing client", "err", err) + continue + } + + c.unsentMessagesLock.Lock() + if len(c.unsentMessages) == 0 { + c.unsentMessagesLock.Unlock() continue } + c.unsentMessagesLock.Unlock() c.tryResendFromDeque() } @@ -164,39 +197,42 @@ func (c *AggregatorRpcClient) tryResendFromDeque() { } func (c *AggregatorRpcClient) sendOperatorMessage(sendCb func() error, message RpcMessage) { + c.rpcClientLock.RLock() + defer c.rpcClientLock.RUnlock() + appendProtected := func() { c.unsentMessagesLock.Lock() c.unsentMessages = append(c.unsentMessages, message) c.unsentMessagesLock.Unlock() } - err := c.InitializeClientIfNotExist() - if err != nil { + if c.rpcClient == nil { appendProtected() return } c.logger.Info("Sending request to aggregator") - err = sendCb() + err := sendCb() if err != nil { + c.handleRpcError(err) appendProtected() return } - - c.tryResendFromDeque() } func (c *AggregatorRpcClient) sendRequest(sendCb func() error) error { - err := c.InitializeClientIfNotExist() - if err != nil { - c.logger.Error("Could not reinitialize RPC client") - return err + c.rpcClientLock.RLock() + defer c.rpcClientLock.RUnlock() + + if c.rpcClient == nil { + return errors.New("rpc client is nil") } c.logger.Info("Sending request to aggregator") - err = sendCb() + err := sendCb() if err != nil { + c.handleRpcError(err) return err }