diff --git a/fleetspeak/src/server/https/https.go b/fleetspeak/src/server/https/https.go
index 1959a61e..faca7055 100644
--- a/fleetspeak/src/server/https/https.go
+++ b/fleetspeak/src/server/https/https.go
@@ -84,13 +84,14 @@ func (l listener) Accept() (net.Conn, error) {
 
 // Params wraps the parameters required to create an https communicator.
 type Params struct {
-	Listener           net.Listener        // Where to listen for connections, required.
-	Cert, Key          []byte              // x509 encoded certificate and matching private key, required.
-	Streaming          bool                // Whether to enable streaming communications.
-	FrontendConfig     *cpb.FrontendConfig // Configure how the frontend identifies and communicates with clients
-	StreamingLifespan  time.Duration       // Maximum time to keep a streaming connection open, defaults to 10 min.
-	StreamingCloseTime time.Duration       // How much of StreamingLifespan to allocate to an orderly stream close, defaults to 30 sec.
-	StreamingJitter    time.Duration       // Maximum amount of jitter to add to StreamingLifespan.
+	Listener                    net.Listener        // Where to listen for connections, required.
+	Cert, Key                   []byte              // x509 encoded certificate and matching private key, required.
+	Streaming                   bool                // Whether to enable streaming communications.
+	FrontendConfig              *cpb.FrontendConfig // Configure how the frontend identifies and communicates with clients
+	StreamingLifespan           time.Duration       // Maximum time to keep a streaming connection open, defaults to 10 min.
+	StreamingCloseTime          time.Duration       // How much of StreamingLifespan to allocate to an orderly stream close, defaults to 30 sec.
+	StreamingJitter             time.Duration       // Maximum amount of jitter to add to StreamingLifespan.
+	MaxPerClientBatchProcessors uint32              // Maximum number of concurrent processors for messages coming from a single client.
} // NewCommunicator creates a Communicator, which listens through l and identifies @@ -102,6 +103,10 @@ func NewCommunicator(p Params) (*Communicator, error) { if p.StreamingCloseTime == 0 { p.StreamingCloseTime = 30 * time.Second } + if p.MaxPerClientBatchProcessors == 0 { + p.MaxPerClientBatchProcessors = 10 + } + mux := http.NewServeMux() c, err := tls.X509KeyPair(p.Cert, p.Key) if err != nil { @@ -138,7 +143,7 @@ func NewCommunicator(p Params) (*Communicator, error) { } mux.Handle("/message", messageServer{&h}) if p.Streaming { - mux.Handle("/streaming-message", streamingMessageServer{&h}) + mux.Handle("/streaming-message", newStreamingMessageServer(&h, p.MaxPerClientBatchProcessors)) } mux.Handle("/files/", fileServer{&h}) diff --git a/fleetspeak/src/server/https/streaming_message_server.go b/fleetspeak/src/server/https/streaming_message_server.go index 9f8d1c5d..a8af8297 100644 --- a/fleetspeak/src/server/https/streaming_message_server.go +++ b/fleetspeak/src/server/https/streaming_message_server.go @@ -60,12 +60,17 @@ func writeUint32(res fullResponseWriter, i uint32) error { return binary.Write(res, binary.LittleEndian, i) } +func newStreamingMessageServer(c *Communicator, maxPerClientBatchProcessors uint32) *streamingMessageServer { + return &streamingMessageServer{c, maxPerClientBatchProcessors} +} + // messageServer wraps a Communicator in order to handle clients polls. 
type streamingMessageServer struct { *Communicator + maxPerClientBatchProcessors uint32 } -func (s streamingMessageServer) ServeHTTP(res http.ResponseWriter, req *http.Request) { +func (s *streamingMessageServer) ServeHTTP(res http.ResponseWriter, req *http.Request) { earlyError := func(msg string, status int) { log.ErrorDepth(1, fmt.Sprintf("%s: %s", http.StatusText(status), msg)) s.fs.StatsCollector().ClientPoll(stats.PollInfo{ @@ -167,7 +172,7 @@ func (s streamingMessageServer) ServeHTTP(res http.ResponseWriter, req *http.Req m.cancel() } -func (s streamingMessageServer) initialPoll(ctx context.Context, addr net.Addr, key crypto.PublicKey, res fullResponseWriter, body *bufio.Reader) (*comms.ConnectionInfo, bool, error) { +func (s *streamingMessageServer) initialPoll(ctx context.Context, addr net.Addr, key crypto.PublicKey, res fullResponseWriter, body *bufio.Reader) (*comms.ConnectionInfo, bool, error) { ctx, fin := context.WithTimeout(ctx, 3*time.Minute) pi := stats.PollInfo{ @@ -266,7 +271,7 @@ func (s streamingMessageServer) initialPoll(ctx context.Context, addr net.Addr, type streamManager struct { ctx context.Context - s streamingMessageServer + s *streamingMessageServer info *comms.ConnectionInfo res fullResponseWriter @@ -293,7 +298,7 @@ func (m *streamManager) readLoop() { // Number of batches from the same client that will be processed concurrently. const maxBatchProcessors = 10 - batchCh := make(chan *fspb.WrappedContactData, maxBatchProcessors) + batchCh := make(chan *fspb.WrappedContactData, m.s.maxPerClientBatchProcessors) for { pi, wcd, err := m.readOne() @@ -309,19 +314,27 @@ func (m *streamManager) readLoop() { return } + // Increment the counter with every processed message. + cnt++ + // This will block if number of concurrent processors is greater than maxBatchProcessors. batchCh <- wcd - go func() { + // Ensure the m.out stays open while the message processing is not done. 
+		m.reading.Add(1)
+		// Given that the processing is done concurrently, capture the current counter value in
+		// the function argument.
+		go func(curCnt uint64, wcd *fspb.WrappedContactData) {
+			defer m.reading.Done()
+			<-batchCh // Release the slot; this goroutine processes the wcd it was launched with.
 			if err := m.processOne(wcd); err != nil {
 				log.Errorf("Error processing message from %v: %v", m.info.Client.ID, err)
+				return
 			}
-		}()
+			m.out <- &fspb.ContactData{AckIndex: curCnt}
+		}(cnt, wcd)
 		m.s.fs.StatsCollector().ClientPoll(*pi)
-		cnt++
-
-		m.out <- &fspb.ContactData{AckIndex: cnt}
 	}
 }