diff --git a/builder.go b/builder.go index 84df68e2..e3d184b3 100644 --- a/builder.go +++ b/builder.go @@ -28,6 +28,7 @@ type serverBuilder struct { cmdExecProfBuilder profiling.CmdProfilerBuilder storeBuilder store.Builder reporter reporter.Reporter + disableParallelism bool } func newBuilder() (*serverBuilder, error) { @@ -78,5 +79,6 @@ func (builder *serverBuilder) build() (*Server, error) { cmdExecProfBuilder: builder.cmdExecProfBuilder, versionInfo: builder.versionInfo, reporter: builder.reporter, + disableParallelism: builder.disableParallelism, }, nil } diff --git a/internal/contexts/context.go b/internal/contexts/context.go index db369486..983819cc 100644 --- a/internal/contexts/context.go +++ b/internal/contexts/context.go @@ -57,3 +57,19 @@ func IsRemoteUpdateCtx(ctx context.Context) bool { func NewRemoteUpdateCtx(ctx context.Context) context.Context { return context.WithValue(ctx, handleRemoteUpdateCtxKey, struct{}{}) } + +type disableParallelismCtxType struct{} + +var disableParallelismCtxKey disableParallelismCtxType + +func IsParallelismDisabledCtx(ctx context.Context) bool { + return ctx.Value(disableParallelismCtxKey) != nil +} + +func NewDisableParallelismCtx(ctx context.Context, disabled bool) context.Context { + if !disabled { + return ctx + } + + return context.WithValue(ctx, disableParallelismCtxKey, struct{}{}) +} diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index dc427d95..7156c468 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -80,7 +80,7 @@ func (m *Mailbox) Fetch(ctx context.Context, seq *proto.SequenceSet, attributes // Only run in parallel if we have to fetch more than minCountForParallelism messages or if we have more than one // message and we need to access the literal. - if len(snapMessages) > minCountForParallelism || (len(snapMessages) > 1 && needsLiteral) { + if !contexts.IsParallelismDisabledCtx(ctx) && (len(snapMessages) > minCountForParallelism || (len(snapMessages) > 1 && needsLiteral)) { // If multiple fetch request are happening in parallel, reduce the number of goroutines in proportion to that // to avoid overloading the user's machine. parallelism = runtime.NumCPU() / int(activeFetchRequests) diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 6a1476f3..5cd0e1db 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -48,7 +48,13 @@ func (m *Mailbox) Search(ctx context.Context, keys []*proto.SearchKey, decoder * activeSearchRequests := atomic.AddInt32(&totalActiveSearchRequests, 1) defer atomic.AddInt32(&totalActiveSearchRequests, -1) - parallelism := runtime.NumCPU() / int(activeSearchRequests) + var parallelism int + + if contexts.IsParallelismDisabledCtx(ctx) { + parallelism = 1 + } else { + parallelism = runtime.NumCPU() / int(activeSearchRequests) + } if err := parallel.DoContext(ctx, parallelism, msgCount, func(ctx context.Context, i int) error { msg, ok := m.snap.messages.getWithSeqID(imap.SeqID(i + 1)) diff --git a/option.go b/option.go index 9b16da7d..fef706bc 100644 --- a/option.go +++ b/option.go @@ -167,3 +167,13 @@ func (w *withReporter) config(builder *serverBuilder) { func WithReporter(reporter reporter.Reporter) Option { return &withReporter{reporter: reporter} } + +type withDisableParallelism struct{} + +func (withDisableParallelism) config(builder *serverBuilder) { + builder.disableParallelism = true +} + +func WithDisableParallelism() Option { + return &withDisableParallelism{} +} diff --git a/server.go b/server.go index e519c412..26d96ca2 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/ProtonMail/gluon/internal/contexts" "io" "net" "runtime/pprof" @@ -78,6 +79,9 @@ type Server struct { // idleBulkTime to control how often IDLE responses are sent. 0 means // immediate response with no response merging. idleBulkTime time.Duration + + // disableParallelism indicates whether the server is allowed to parallelize certain IMAP commands. + disableParallelism bool } // New creates a new server with the given options. @@ -154,6 +158,7 @@ func (s *Server) AddWatcher(ofType ...events.Event) <-chan events.Event { // It stops serving when the context is canceled, the listener is closed, or the server is closed. func (s *Server) Serve(ctx context.Context, l net.Listener) error { ctx = reporter.NewContextWithReporter(ctx, s.reporter) + ctx = contexts.NewDisableParallelismCtx(ctx, s.disableParallelism) s.publish(events.ListenerAdded{ Addr: l.Addr(), diff --git a/tests/server_test.go b/tests/server_test.go index c04e74a2..2dd60dc5 100644 --- a/tests/server_test.go +++ b/tests/server_test.go @@ -62,13 +62,14 @@ func (*dummyConnectorBuilder) New(usernames []string, password []byte, period ti } type serverOptions struct { - credentials []credentials - delimiter string - loginJailTime time.Duration - dataDir string - idleBulkTime time.Duration - storeBuilder store.Builder - connectorBuilder connectorBuilder + credentials []credentials + delimiter string + loginJailTime time.Duration + dataDir string + idleBulkTime time.Duration + storeBuilder store.Builder + connectorBuilder connectorBuilder + disableParallelism bool } func (s *serverOptions) defaultUsername() string { @@ -131,6 +132,12 @@ func (c *connectorBuilderOption) apply(options *serverOptions) { options.connectorBuilder = c.builder } +type disableParallelism struct{} + +func (disableParallelism) apply(options *serverOptions) { + options.disableParallelism = true +} + func withIdleBulkTime(idleBulkTime time.Duration) serverOption { return &idleBulkTimeOption{idleBulkTime: idleBulkTime} } @@ -155,6 +162,10 @@ func withConnectorBuilder(builder connectorBuilder) serverOption { return &connectorBuilderOption{builder: builder} } +func withDisableParallelism() serverOption { + return &disableParallelism{} +} + func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptions { options := &serverOptions{ credentials: []credentials{{ @@ -187,8 +198,7 @@ func runServer(tb testing.TB, options *serverOptions, tests func(session *testSe // Log the (temporary?) directory to store gluon data. logrus.Tracef("Gluon Data Dir: %v", options.dataDir) - // Create a new gluon server. - server, err := gluon.New( + gluonOptions := []gluon.Option{ gluon.WithDataDir(options.dataDir), gluon.WithDelimiter(options.delimiter), gluon.WithLoginJailTime(options.loginJailTime), @@ -210,6 +220,15 @@ func runServer(tb testing.TB, options *serverOptions, tests func(session *testSe ), gluon.WithIdleBulkTime(options.idleBulkTime), gluon.WithStoreBuilder(options.storeBuilder), + } + + if options.disableParallelism { + gluonOptions = append(gluonOptions, gluon.WithDisableParallelism()) + } + + // Create a new gluon server. + server, err := gluon.New( + gluonOptions..., ) require.NoError(tb, err)