diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 458fec61..823320f0 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -92,7 +92,7 @@ public List> getToolFactories() { SearchAnomalyDetectorsTool.Factory.getInstance(), SearchAnomalyResultsTool.Factory.getInstance(), SearchMonitorsTool.Factory.getInstance(), - CreateAlertTool.Factory.getInstance() + CreateAlertTool.Factory.getInstance(), CreateAnomalyDetectorTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 58424c08..af6fa94e 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -8,7 +8,6 @@ import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; import static org.opensearch.ml.common.utils.StringUtils.isJson; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -180,8 +179,7 @@ private String getIndexMappingInfo(Map parameters) throws Interr } if (indexList.isEmpty()) { throw new IllegalArgumentException( - "The input indices is empty. Ask user to " - + "provide index as your final answer directly without using any other tools" + "The input indices is empty. Ask user to " + "provide index as your final answer directly without using any other tools" ); } else if (indexList.stream().anyMatch(index -> index.startsWith("."))) { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java index 3877b9d8..35cf39b8 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java @@ -5,11 +5,6 @@ package org.opensearch.agent.tools.utils; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import com.google.gson.Gson; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; @@ -22,55 +17,55 @@ @Log4j2 public class ToolHelper { - private static final Gson gson = new Gson(); + private static final Gson gson = new Gson(); - /** - * Load prompt from resource file of the invoking class - * @param source class which calls this function - * @param fileName the resource file name of prompt - * @return the LLM request prompt template. - */ - public static Map loadDefaultPromptDictFromFile(Class source, String fileName) { - try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) { - if (searchResponseIns != null) { - String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); - return gson.fromJson(defaultPromptContent, Map.class); - } - } catch (IOException e) { - log.error("Failed to load default prompt dict from file: {}", fileName, e); + /** + * Load prompt from the resource file of the invoking class + * @param source class which calls this function + * @param fileName the resource file name of prompt + * @return the LLM request prompt template. + */ + public static Map loadDefaultPromptDictFromFile(Class source, String fileName) { + try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) { + if (searchResponseIns != null) { + String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + return gson.fromJson(defaultPromptContent, Map.class); + } + } catch (IOException e) { + log.error("Failed to load default prompt dict from file: {}", fileName, e); + } + return new HashMap<>(); } - return new HashMap<>(); - } - /** - * Flatten all the fields in the mappings, insert the field to fieldType mapping to a map - * @param mappingSource the mappings of an index - * @param fieldsToType the result containing the field to fieldType mapping - * @param prefix the parent field path - */ - public static void extractFieldNamesTypes(Map mappingSource, Map fieldsToType, String prefix) { - if (prefix.length() > 0) { - prefix += "."; - } + /** + * Flatten all the fields in the mappings, insert the field to fieldType mapping to a map + * @param mappingSource the mappings of an index + * @param fieldsToType the result containing the field to fieldType mapping + * @param prefix the parent field path + */ + public static void extractFieldNamesTypes(Map mappingSource, Map fieldsToType, String prefix) { + if (prefix.length() > 0) { + prefix += "."; + } - for (Map.Entry entry : mappingSource.entrySet()) { - String n = entry.getKey(); - Object v = entry.getValue(); + for (Map.Entry entry : mappingSource.entrySet()) { + String n = entry.getKey(); + Object v = entry.getValue(); - if (v instanceof Map) { - Map vMap = (Map) v; - if (vMap.containsKey("type")) { - if (!((vMap.getOrDefault("type", "")).equals("alias"))) { - fieldsToType.put(prefix + n, (String) vMap.get("type")); - } - } - if (vMap.containsKey("properties")) { - extractFieldNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n); - } - if (vMap.containsKey("fields")) { - extractFieldNamesTypes((Map) vMap.get("fields"), fieldsToType, prefix + n); + if (v instanceof Map) { + Map vMap = (Map) v; + if (vMap.containsKey("type")) { + if (!((vMap.getOrDefault("type", "")).equals("alias"))) { + fieldsToType.put(prefix + n, (String) vMap.get("type")); + } + } + if (vMap.containsKey("properties")) { + extractFieldNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n); + } + if (vMap.containsKey("fields")) { + extractFieldNamesTypes((Map) vMap.get("fields"), fieldsToType, prefix + n); + } + } } - } } - } } diff --git a/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java index b4a36563..0df05467 100644 --- a/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java @@ -156,9 +156,14 @@ public void testTool_WithBlankModelId() { public void testTool_WithNonSupportedModelType() { Exception exception = assertThrows( IllegalArgumentException.class, - () -> CreateAlertTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType")) + () -> CreateAlertTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType")) + ); + assertEquals( + "Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]", + exception.getMessage() ); - assertEquals("Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]", exception.getMessage()); } @Test @@ -267,7 +272,10 @@ public void testToolWithIllegalIndices() { }) ) ); - assertEquals("No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", exception.getMessage()); + assertEquals( + "No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", + exception.getMessage() + ); // empty string as indices exception = assertThrows( @@ -280,8 +288,10 @@ public void testToolWithIllegalIndices() { }) ) ); - assertEquals("No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", exception.getMessage()); - + assertEquals( + "No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", + exception.getMessage() + ); // indices is an empty list exception = assertThrows( @@ -294,7 +304,10 @@ public void testToolWithIllegalIndices() { }) ) ); - assertEquals("The input indices is empty. Ask user to provide index as your final answer directly without using any other tools", exception.getMessage()); + assertEquals( + "The input indices is empty. Ask user to provide index as your final answer directly without using any other tools", + exception.getMessage() + ); // indices contain system index exception = assertThrows( @@ -307,10 +320,13 @@ public void testToolWithIllegalIndices() { }) ) ); - assertEquals("The provided indices [[.kibana]] contains system index, which is not allowed. Ask user to check the provided indices as your final answer without using any other.", exception.getMessage()); + assertEquals( + "The provided indices [[.kibana]] contains system index, which is not allowed. Ask user to check the provided indices as your final answer without using any other.", + exception.getMessage() + ); // Cannot find provided indices in opensearch - when(getIndexResponse.indices()).thenReturn(new String[]{}); + when(getIndexResponse.indices()).thenReturn(new String[] {}); exception = assertThrows( RuntimeException.class, () -> tool @@ -321,6 +337,9 @@ public void testToolWithIllegalIndices() { }) ) ); - assertEquals("Cannot find provided indices [non_existed_index]. Ask user to check the provided indices as your final answer without using any other tools", exception.getMessage()); + assertEquals( + "Cannot find provided indices [non_existed_index]. Ask user to check the provided indices as your final answer without using any other tools", + exception.getMessage() + ); } } diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 482989ae..dcbf194d 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -397,16 +397,7 @@ public static Response makeRequest( @SneakyThrows protected String registerAgent(String modelId, String requestBodyResourceFile) { String registerAgentRequestBody = Files - .readString( - Path - .of( - this - .getClass() - .getClassLoader() - .getResource(requestBodyResourceFile) - .toURI() - ) - ); + .readString(Path.of(this.getClass().getClassLoader().getResource(requestBodyResourceFile).toURI())); registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); return createAgent(registerAgentRequestBody); } diff --git a/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java b/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java index e4a70d4b..ef214d06 100644 --- a/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java +++ b/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java @@ -9,13 +9,15 @@ import java.io.IOException; import java.util.List; -import lombok.SneakyThrows; -import lombok.extern.log4j.Log4j2; + import org.hamcrest.MatcherAssert; import org.junit.Before; import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.client.ResponseException; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class CreateAlertToolIT extends ToolIntegrationTest { private final String requestBodyResourceFile = "org/opensearch/agent/tools/register_flow_agent_of_create_alert_tool_request_body.json"; @@ -23,7 +25,8 @@ public class CreateAlertToolIT extends ToolIntegrationTest { private final String NON_EXISTENT_INDEX = "non-existent"; private final String SYSTEM_INDEX = ".kibana"; - private final String alertJson = "{\"name\": \"Error 500 Response Alert\",\"search\": {\"indices\": [\"opensearch_dashboards_sample_data_logs\"],\"timeField\": \"timestamp\",\"bucketValue\": 60,\"bucketUnitOfTime\": \"m\",\"filters\": [{\"fieldName\": [{\"label\": \"response\",\"type\": \"text\"}],\"fieldValue\": \"500\",\"operator\": \"is\"}],\"aggregations\": [{\"aggregationType\": \"count\",\"fieldName\": \"bytes\"}]},\"triggers\": [{\"name\": \"Error 500 Response Count Above 1\",\"severity\": 1,\"thresholdValue\": 1,\"thresholdEnum\": \"ABOVE\"}]}"; + private final String alertJson = + "{\"name\": \"Error 500 Response Alert\",\"search\": {\"indices\": [\"opensearch_dashboards_sample_data_logs\"],\"timeField\": \"timestamp\",\"bucketValue\": 60,\"bucketUnitOfTime\": \"m\",\"filters\": [{\"fieldName\": [{\"label\": \"response\",\"type\": \"text\"}],\"fieldValue\": \"500\",\"operator\": \"is\"}],\"aggregations\": [{\"aggregationType\": \"count\",\"fieldName\": \"bytes\"}]},\"triggers\": [{\"name\": \"Error 500 Response Count Above 1\",\"severity\": 1,\"thresholdValue\": 1,\"thresholdEnum\": \"ABOVE\"}]}"; private final String question = "Create alert on the index when count of peoples whose age greater than 50 exceeds 10"; private final String pureJsonResponseIndicator = "$PURE_JSON"; private final String noJsonResponseIndicator = "$NO_JSON"; @@ -64,74 +67,55 @@ String toolType() { @SneakyThrows public void testCreateAlertTool() { prepareIndex(); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX); String result = executeAgent(agentId, requestBody); - assertEquals( - alertJson, - result - ); + assertEquals(alertJson, result); } public void testCreateAlertToolWithPureJsonResponse() { prepareIndex(); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + pureJsonResponseIndicator, NORMAL_INDEX); + String requestBody = String + .format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + pureJsonResponseIndicator, NORMAL_INDEX); String result = executeAgent(agentId, requestBody); - assertEquals( - alertJson, - result - ); + assertEquals(alertJson, result); } public void testCreateAlertToolWithNoJsonResponse() { prepareIndex(); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + noJsonResponseIndicator, NORMAL_INDEX); - Exception exception = assertThrows( - ResponseException.class, - () -> executeAgent(agentId, requestBody) - ); + String requestBody = String + .format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + noJsonResponseIndicator, NORMAL_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); MatcherAssert.assertThat(exception.getMessage(), containsString("The response from LLM is not a json")); } public void testCreateAlertToolWithNonExistentModelId() { prepareIndex(); String abnormalAgentId = registerAgent("NON_EXISTENT_MODEL_ID", requestBodyResourceFile); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX); - Exception exception = assertThrows( - ResponseException.class, - () -> executeAgent(abnormalAgentId, requestBody) - ); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(abnormalAgentId, requestBody)); MatcherAssert.assertThat(exception.getMessage(), containsString("Failed to find model")); } public void testCreateAlertToolWithNonExistentIndex() { prepareIndex(); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NON_EXISTENT_INDEX); - Exception exception = assertThrows( - ResponseException.class, - () -> executeAgent(agentId, requestBody) - ); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NON_EXISTENT_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); MatcherAssert.assertThat(exception.getMessage(), containsString("no such index")); } public void testCreateAlertToolWithSystemIndex() { prepareIndex(); String agentId = registerAgent(modelId, requestBodyResourceFile); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, SYSTEM_INDEX); - Exception exception = assertThrows( - ResponseException.class, - () -> executeAgent(agentId, requestBody) - ); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, SYSTEM_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); MatcherAssert.assertThat(exception.getMessage(), containsString("contains system index, which is not allowed")); } public void testCreateAlertToolWithEmptyIndex() { prepareIndex(); String agentId = registerAgent(modelId, requestBodyResourceFile); - String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"\"}}", question); - Exception exception = assertThrows( - ResponseException.class, - () -> executeAgent(agentId, requestBody) - ); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"\"}}", question); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); MatcherAssert.assertThat(exception.getMessage(), containsString("No indices in the input parameter")); }