diff --git a/CHANGELOG.md b/CHANGELOG.md index 2401cc6d2c5f8..de8422d733225 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Moved concurrent-search from sandbox plugin to server module behind feature flag ([#7203](https://github.com/opensearch-project/OpenSearch/pull/7203)) - Allow access to indices cache clear APIs for read only indexes ([#7303](https://github.com/opensearch-project/OpenSearch/pull/7303)) - Default search preference to _primary for searchable snapshot indices ([#7628](https://github.com/opensearch-project/OpenSearch/pull/7628)) +- Changed concurrent-search threadpool type to be resizable and support task resource tracking ([#7502](https://github.com/opensearch-project/OpenSearch/pull/7502)) ### Deprecated diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/AbstractTasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/AbstractTasksIT.java new file mode 100644 index 0000000000000..fcfe9cb0aab00 --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/AbstractTasksIT.java @@ -0,0 +1,190 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.admin.cluster.node.tasks; + +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.admin.cluster.node.tasks.get.GetTaskResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.plugins.Plugin; +import org.opensearch.tasks.TaskId; +import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.ThreadResourceInfo; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.test.transport.MockTransportService; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** + * Base IT test class for Tasks ITs + */ +abstract class AbstractTasksIT extends OpenSearchIntegTestCase { + + protected Map, RecordingTaskManagerListener> listeners = new HashMap<>(); + + @Override + protected Collection> getMockPlugins() { + Collection> mockPlugins = new ArrayList<>(super.getMockPlugins()); + mockPlugins.remove(MockTransportService.TestPlugin.class); + return mockPlugins; + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(MockTransportService.TestPlugin.class, TestTaskPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), true) + .build(); + } + + @Override + public void tearDown() throws Exception { + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + ((MockTaskManager) internalCluster().getInstance(TransportService.class, entry.getKey().v1()).getTaskManager()).removeListener( + entry.getValue() + ); + } + listeners.clear(); + super.tearDown(); + } + + /** + * Registers recording task event listeners with the given action mask on all nodes + */ + protected void registerTaskManagerListeners(String actionMasks) { + for (String nodeName : internalCluster().getNodeNames()) { + DiscoveryNode node = internalCluster().getInstance(ClusterService.class, nodeName).localNode(); + RecordingTaskManagerListener listener = new RecordingTaskManagerListener(node.getId(), actionMasks.split(",")); + ((MockTaskManager) internalCluster().getInstance(TransportService.class, nodeName).getTaskManager()).addListener(listener); + RecordingTaskManagerListener oldListener = listeners.put(new Tuple<>(node.getName(), actionMasks), listener); + assertNull(oldListener); + } + } + + /** + * Resets all recording task event listeners with the given action mask on all nodes + */ + protected void resetTaskManagerListeners(String actionMasks) { + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { + entry.getValue().reset(); + } + } + } + + /** + * Returns the number of events that satisfy the criteria across all nodes + * + * @param actionMasks action masks to match + * @return number of events that satisfy the criteria + */ + protected int numberOfEvents(String actionMasks, Function, Boolean> criteria) { + return findEvents(actionMasks, criteria).size(); + } + + /** + * Returns all events that satisfy the criteria across all nodes + * + * @param actionMasks action masks to match + * @return number of events that satisfy the criteria + */ + protected List findEvents(String actionMasks, Function, Boolean> criteria) { + List events = new ArrayList<>(); + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { + for (Tuple taskEvent : entry.getValue().getEvents()) { + if (criteria.apply(taskEvent)) { + events.add(taskEvent.v2()); + } + } + } + } + return events; + } + + protected Map> getThreadStats(String actionMasks, TaskId taskId) { + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { + for (Tuple>> threadStats : entry.getValue().getThreadStats()) { + if (taskId.equals(threadStats.v1())) { + return threadStats.v2(); + } + } + } + } + return new HashMap<>(); + } + + /** + * Asserts that all tasks in the tasks list have the same parentTask + */ + protected void assertParentTask(List tasks, TaskInfo parentTask) { + for (TaskInfo task : tasks) { + assertParentTask(task, parentTask); + } + } + + protected void assertParentTask(TaskInfo task, TaskInfo parentTask) { + assertTrue(task.getParentTaskId().isSet()); + assertEquals(parentTask.getTaskId().getNodeId(), task.getParentTaskId().getNodeId()); + assertTrue(Strings.hasLength(task.getParentTaskId().getNodeId())); + assertEquals(parentTask.getId(), task.getParentTaskId().getId()); + } + + protected void expectNotFound(ThrowingRunnable r) { + Exception e = expectThrows(Exception.class, r); + ResourceNotFoundException notFound = (ResourceNotFoundException) ExceptionsHelper.unwrap(e, ResourceNotFoundException.class); + if (notFound == null) { + throw new AssertionError("Expected " + ResourceNotFoundException.class.getSimpleName(), e); + } + } + + /** + * Fetch the task status from the list tasks API using it's "fallback to get from the task index" behavior. Asserts some obvious stuff + * about the fetched task and returns a map of it's status. + */ + protected GetTaskResponse expectFinishedTask(TaskId taskId) throws IOException { + GetTaskResponse response = client().admin().cluster().prepareGetTask(taskId).get(); + assertTrue("the task should have been completed before fetching", response.getTask().isCompleted()); + TaskInfo info = response.getTask().getTask(); + assertEquals(taskId, info.getTaskId()); + assertNull(info.getStatus()); // The test task doesn't have any status + return response; + } + + protected void indexDocumentsWithRefresh(String indexName, int numDocs) { + for (int i = 0; i < numDocs; i++) { + client().prepareIndex(indexName) + .setId("test_id_" + String.valueOf(i)) + .setSource("{\"foo_" + String.valueOf(i) + "\": \"bar_" + String.valueOf(i) + "\"}", XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + } + } +} diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/ConcurrentSearchTasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/ConcurrentSearchTasksIT.java new file mode 100644 index 0000000000000..2b2421072e03b --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/ConcurrentSearchTasksIT.java @@ -0,0 +1,118 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.admin.cluster.node.tasks; + +import org.hamcrest.MatcherAssert; +import org.opensearch.action.admin.indices.segments.IndicesSegmentsRequest; +import org.opensearch.action.search.SearchAction; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.FeatureFlagSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.ThreadResourceInfo; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.notNullValue; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; + +/** + * Integration tests for task management API with Concurrent Segment Search + * + * The way the test framework bootstraps the test cluster makes it difficult to parameterize the feature flag. + * Once concurrent search is moved behind a cluster setting we can parameterize these tests behind the setting. + */ +public class ConcurrentSearchTasksIT extends AbstractTasksIT { + + private static final int INDEX_SEARCHER_THREADS = 10; + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put("thread_pool.index_searcher.size", INDEX_SEARCHER_THREADS) + .put("thread_pool.index_searcher.queue_size", INDEX_SEARCHER_THREADS) + .build(); + } + + private int getSegmentCount(String indexName) { + return client().admin() + .indices() + .segments(new IndicesSegmentsRequest(indexName)) + .actionGet() + .getIndices() + .get(indexName) + .getShards() + .get(0) + .getShards()[0].getSegments() + .size(); + } + + @Override + protected Settings featureFlagSettings() { + Settings.Builder featureSettings = Settings.builder(); + for (Setting builtInFlag : FeatureFlagSettings.BUILT_IN_FEATURE_FLAGS) { + featureSettings.put(builtInFlag.getKey(), builtInFlag.getDefaultRaw(Settings.EMPTY)); + } + featureSettings.put(FeatureFlags.CONCURRENT_SEGMENT_SEARCH, true); + return featureSettings.build(); + } + + /** + * Tests the number of threads that worked on a search task. + * + * Currently, we try to control concurrency by creating an index with 7 segments and rely on + * the way concurrent search creates leaf slices from segments. Once more concurrency controls are introduced + * we should improve this test to use those methods. + */ + public void testConcurrentSearchTaskTracking() { + final String INDEX_NAME = "test"; + final int NUM_SHARDS = 1; + final int NUM_DOCS = 7; + + registerTaskManagerListeners(SearchAction.NAME); // coordinator task + registerTaskManagerListeners(SearchAction.NAME + "[*]"); // shard task + createIndex( + INDEX_NAME, + Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, NUM_SHARDS) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .build() + ); + ensureGreen(INDEX_NAME); // Make sure all shards are allocated to catch replication tasks + indexDocumentsWithRefresh(INDEX_NAME, NUM_DOCS); // Concurrent search requires >5 segments or >250,000 docs to have concurrency, so + // we index 7 docs flushing between each to create new segments + assertSearchResponse(client().prepareSearch(INDEX_NAME).setQuery(QueryBuilders.matchAllQuery()).get()); + + // the search operation should produce one coordinator task + List mainTask = findEvents(SearchAction.NAME, Tuple::v1); + assertEquals(1, mainTask.size()); + TaskInfo mainTaskInfo = mainTask.get(0); + + List shardTasks = findEvents(SearchAction.NAME + "[*]", Tuple::v1); + assertEquals(NUM_SHARDS, shardTasks.size()); // We should only have 1 shard search task per shard + for (TaskInfo taskInfo : shardTasks) { + MatcherAssert.assertThat(taskInfo.getParentTaskId(), notNullValue()); + assertEquals(mainTaskInfo.getTaskId(), taskInfo.getParentTaskId()); + + Map> threadStats = getThreadStats(SearchAction.NAME + "[*]", taskInfo.getTaskId()); + // Concurrent search forks each slice of 5 segments to different thread + assertEquals((int) Math.ceil(getSegmentCount(INDEX_NAME) / 5.0), threadStats.size()); + + // assert that all task descriptions have non-zero length + MatcherAssert.assertThat(taskInfo.getDescription().length(), greaterThan(0)); + } + } +} diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java index 533d22a01c4ad..67e52529ae86b 100644 --- a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java @@ -32,10 +32,9 @@ package org.opensearch.action.admin.cluster.node.tasks; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchTimeoutException; -import org.opensearch.ExceptionsHelper; -import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionListener; import org.opensearch.action.TaskOperationFailure; @@ -57,15 +56,11 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.replication.ReplicationResponse; import org.opensearch.action.support.replication.TransportReplicationActionTests; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.Strings; import org.opensearch.common.collect.Tuple; import org.opensearch.common.regex.Regex; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.plugins.Plugin; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskId; @@ -75,14 +70,9 @@ import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.test.tasks.MockTaskManagerListener; -import org.opensearch.test.transport.MockTransportService; import org.opensearch.transport.ReceiveTimeoutTransportException; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -96,12 +86,6 @@ import static java.util.Collections.emptyList; import static java.util.Collections.singleton; -import static org.opensearch.common.unit.TimeValue.timeValueMillis; -import static org.opensearch.common.unit.TimeValue.timeValueSeconds; -import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; -import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertFutureThrows; -import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures; -import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -113,6 +97,12 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; +import static org.opensearch.common.unit.TimeValue.timeValueMillis; +import static org.opensearch.common.unit.TimeValue.timeValueSeconds; +import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertFutureThrows; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; /** * Integration tests for task management API @@ -120,29 +110,7 @@ * We need at least 2 nodes so we have a cluster-manager node a non-cluster-manager node */ @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, minNumDataNodes = 2) -public class TasksIT extends OpenSearchIntegTestCase { - - private Map, RecordingTaskManagerListener> listeners = new HashMap<>(); - - @Override - protected Collection> getMockPlugins() { - Collection> mockPlugins = new ArrayList<>(super.getMockPlugins()); - mockPlugins.remove(MockTransportService.TestPlugin.class); - return mockPlugins; - } - - @Override - protected Collection> nodePlugins() { - return Arrays.asList(MockTransportService.TestPlugin.class, TestTaskPlugin.class); - } - - @Override - protected Settings nodeSettings(int nodeOrdinal) { - return Settings.builder() - .put(super.nodeSettings(nodeOrdinal)) - .put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), true) - .build(); - } +public class TasksIT extends AbstractTasksIT { public void testTaskCounts() { // Run only on data nodes @@ -951,106 +919,4 @@ public void onFailure(Exception e) { assertNotNull(response.getTask().getError()); assertNull(response.getTask().getResponse()); } - - @Override - public void tearDown() throws Exception { - for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { - ((MockTaskManager) internalCluster().getInstance(TransportService.class, entry.getKey().v1()).getTaskManager()).removeListener( - entry.getValue() - ); - } - listeners.clear(); - super.tearDown(); - } - - /** - * Registers recording task event listeners with the given action mask on all nodes - */ - private void registerTaskManagerListeners(String actionMasks) { - for (String nodeName : internalCluster().getNodeNames()) { - DiscoveryNode node = internalCluster().getInstance(ClusterService.class, nodeName).localNode(); - RecordingTaskManagerListener listener = new RecordingTaskManagerListener(node.getId(), actionMasks.split(",")); - ((MockTaskManager) internalCluster().getInstance(TransportService.class, nodeName).getTaskManager()).addListener(listener); - RecordingTaskManagerListener oldListener = listeners.put(new Tuple<>(node.getName(), actionMasks), listener); - assertNull(oldListener); - } - } - - /** - * Resets all recording task event listeners with the given action mask on all nodes - */ - private void resetTaskManagerListeners(String actionMasks) { - for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { - if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { - entry.getValue().reset(); - } - } - } - - /** - * Returns the number of events that satisfy the criteria across all nodes - * - * @param actionMasks action masks to match - * @return number of events that satisfy the criteria - */ - private int numberOfEvents(String actionMasks, Function, Boolean> criteria) { - return findEvents(actionMasks, criteria).size(); - } - - /** - * Returns all events that satisfy the criteria across all nodes - * - * @param actionMasks action masks to match - * @return number of events that satisfy the criteria - */ - private List findEvents(String actionMasks, Function, Boolean> criteria) { - List events = new ArrayList<>(); - for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { - if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { - for (Tuple taskEvent : entry.getValue().getEvents()) { - if (criteria.apply(taskEvent)) { - events.add(taskEvent.v2()); - } - } - } - } - return events; - } - - /** - * Asserts that all tasks in the tasks list have the same parentTask - */ - private void assertParentTask(List tasks, TaskInfo parentTask) { - for (TaskInfo task : tasks) { - assertParentTask(task, parentTask); - } - } - - private void assertParentTask(TaskInfo task, TaskInfo parentTask) { - assertTrue(task.getParentTaskId().isSet()); - assertEquals(parentTask.getTaskId().getNodeId(), task.getParentTaskId().getNodeId()); - assertTrue(Strings.hasLength(task.getParentTaskId().getNodeId())); - assertEquals(parentTask.getId(), task.getParentTaskId().getId()); - } - - private void expectNotFound(ThrowingRunnable r) { - Exception e = expectThrows(Exception.class, r); - ResourceNotFoundException notFound = (ResourceNotFoundException) ExceptionsHelper.unwrap(e, ResourceNotFoundException.class); - if (notFound == null) { - throw new AssertionError("Expected " + ResourceNotFoundException.class.getSimpleName(), e); - } - } - - /** - * Fetch the task status from the list tasks API using it's "fallback to get from the task index" behavior. Asserts some obvious stuff - * about the fetched task and returns a map of it's status. - */ - private GetTaskResponse expectFinishedTask(TaskId taskId) throws IOException { - GetTaskResponse response = client().admin().cluster().prepareGetTask(taskId).get(); - assertTrue("the task should have been completed before fetching", response.getTask().isCompleted()); - TaskInfo info = response.getTask().getTask(); - assertEquals(taskId, info.getTaskId()); - assertNull(info.getStatus()); // The test task doesn't have any status - return response; - } } diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index 987f38e8dd8fd..e3e34378746b9 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -182,7 +182,7 @@ public static ThreadPoolType fromType(String type) { map.put(Names.REMOTE_PURGE, ThreadPoolType.SCALING); map.put(Names.REMOTE_REFRESH, ThreadPoolType.SCALING); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { - map.put(Names.INDEX_SEARCHER, ThreadPoolType.FIXED); + map.put(Names.INDEX_SEARCHER, ThreadPoolType.FIXED_AUTO_QUEUE_SIZE); } THREAD_POOL_TYPES = Collections.unmodifiableMap(map); } @@ -279,7 +279,19 @@ public ThreadPool( new ScalingExecutorBuilder(Names.REMOTE_REFRESH, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) ); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { - builders.put(Names.INDEX_SEARCHER, new FixedExecutorBuilder(settings, Names.INDEX_SEARCHER, allocatedProcessors, 1000, false)); + builders.put( + Names.INDEX_SEARCHER, + new AutoQueueAdjustingExecutorBuilder( + settings, + Names.INDEX_SEARCHER, + allocatedProcessors, + 1000, + 1000, + 1000, + 2000, + runnableTaskListener + ) + ); } for (final ExecutorBuilder builder : customBuilders) { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java index 9bd44185baf24..768a6c73af380 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java @@ -35,12 +35,15 @@ import org.opensearch.common.collect.Tuple; import org.opensearch.common.regex.Regex; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.ThreadResourceInfo; import org.opensearch.test.tasks.MockTaskManagerListener; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; /** @@ -52,6 +55,7 @@ public class RecordingTaskManagerListener implements MockTaskManagerListener { private String localNodeId; private List> events = new ArrayList<>(); + private List>>> threadStats = new ArrayList<>(); public RecordingTaskManagerListener(String localNodeId, String... actionMasks) { this.actionMasks = actionMasks; @@ -68,7 +72,9 @@ public synchronized void onTaskRegistered(Task task) { @Override public synchronized void onTaskUnregistered(Task task) { if (Regex.simpleMatch(actionMasks, task.getAction())) { - events.add(new Tuple<>(false, task.taskInfo(localNodeId, true))); + TaskInfo taskInfo = task.taskInfo(localNodeId, true); + events.add(new Tuple<>(false, taskInfo)); + threadStats.add(new Tuple<>(taskInfo.getTaskId(), task.getResourceStats())); } } @@ -82,6 +88,10 @@ public synchronized List> getEvents() { return Collections.unmodifiableList(new ArrayList<>(events)); } + public synchronized List>>> getThreadStats() { + return List.copyOf(threadStats); + } + public synchronized List getRegistrationEvents() { List events = this.events.stream().filter(Tuple::v1).map(Tuple::v2).collect(Collectors.toList()); return Collections.unmodifiableList(events); diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java index 8ba23c5d3219c..03257ee2a0a84 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -20,8 +20,13 @@ import org.opensearch.threadpool.ThreadPool; import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.tasks.ResourceStats.CPU; import static org.opensearch.tasks.ResourceStats.MEMORY; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; @@ -88,6 +93,49 @@ public void testStopTrackingHandlesCurrentActiveThread() { assertTrue(task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue() > 0); } + /** + * Test if taskResourceTrackingService properly tracks resource usage when multiple threads work on the same task + */ + public void testStartingTrackingHandlesMultipleThreadsPerTask() throws InterruptedException { + ExecutorService executor = threadPool.executor(ThreadPool.Names.GENERIC); + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + taskResourceTrackingService.startTracking(task); + int numTasks = randomIntBetween(2, 100); + for (int i = 0; i < numTasks; i++) { + executor.execute(() -> { + long threadId = Thread.currentThread().getId(); + taskResourceTrackingService.taskExecutionStartedOnThread(task.getId(), threadId); + // The same thread may pick up multiple runnables for the same task id + assertEquals(1, task.getResourceStats().get(threadId).stream().filter(ThreadResourceInfo::isActive).count()); + taskResourceTrackingService.taskExecutionFinishedOnThread(task.getId(), threadId); + }); + } + executor.shutdown(); + while (true) { + try { + if (executor.awaitTermination(1, TimeUnit.MINUTES)) break; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + Map> stats = task.getResourceStats(); + int numExecutions = 0; + for (Long threadId : stats.keySet()) { + for (ThreadResourceInfo info : task.getResourceStats().get(threadId)) { + assertTrue(info.getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue() > 0); + assertTrue(info.getResourceUsageInfo().getStatsInfo().get(CPU).getTotalValue() > 0); + assertFalse(info.isActive()); + numExecutions++; + } + + } + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + assertTrue(task.getTotalResourceStats().getMemoryInBytes() > 0); + // Each execution of a runnable should record an entry in resourceStats even if it's the same thread + assertEquals(numTasks, numExecutions); + } + private void verifyThreadContextFixedHeaders(String key, String value) { assertEquals(threadPool.getThreadContext().getHeader(key), value); assertEquals(threadPool.getThreadContext().getTransient(key), value);