Skip to content

Commit

Permalink
fix after merging main
Browse files Browse the repository at this point in the history
Signed-off-by: Heng Qian <qianheng@amazon.com>
  • Loading branch information
qianheng-aws committed Jul 25, 2024
1 parent d300161 commit 631cf10
Showing 6 changed files with 97 additions and 110 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAnomalyDetectorsTool.Factory.getInstance(),
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance(),
CreateAlertTool.Factory.getInstance()
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
);
}
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@
import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -180,8 +179,7 @@ private String getIndexMappingInfo(Map<String, String> parameters) throws Interr
}
if (indexList.isEmpty()) {
throw new IllegalArgumentException(
"The input indices is empty. Ask user to "
+ "provide index as your final answer directly without using any other tools"
"The input indices is empty. Ask user to " + "provide index as your final answer directly without using any other tools"
);
} else if (indexList.stream().anyMatch(index -> index.startsWith("."))) {
throw new IllegalArgumentException(
93 changes: 44 additions & 49 deletions src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java
Original file line number Diff line number Diff line change
@@ -5,11 +5,6 @@

package org.opensearch.agent.tools.utils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import com.google.gson.Gson;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
@@ -22,55 +17,55 @@

@Log4j2
public class ToolHelper {
private static final Gson gson = new Gson();
private static final Gson gson = new Gson();

/**
* Load prompt from resource file of the invoking class
* @param source class which calls this function
* @param fileName the resource file name of prompt
* @return the LLM request prompt template.
*/
public static Map<String, String> loadDefaultPromptDictFromFile(Class<?> source, String fileName) {
try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) {
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
return gson.fromJson(defaultPromptContent, Map.class);
}
} catch (IOException e) {
log.error("Failed to load default prompt dict from file: {}", fileName, e);
/**
* Load prompt from the resource file of the invoking class
* @param source class which calls this function
* @param fileName the resource file name of prompt
* @return the LLM request prompt template.
*/
public static Map<String, String> loadDefaultPromptDictFromFile(Class<?> source, String fileName) {
try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) {
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
return gson.fromJson(defaultPromptContent, Map.class);
}
} catch (IOException e) {
log.error("Failed to load default prompt dict from file: {}", fileName, e);
}
return new HashMap<>();
}
return new HashMap<>();
}

/**
* Flatten all the fields in the mappings, insert the field to fieldType mapping to a map
* @param mappingSource the mappings of an index
* @param fieldsToType the result containing the field to fieldType mapping
* @param prefix the parent field path
*/
public static void extractFieldNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
if (prefix.length() > 0) {
prefix += ".";
}
/**
* Flatten all the fields in the mappings, insert the field to fieldType mapping to a map
* @param mappingSource the mappings of an index
* @param fieldsToType the result containing the field to fieldType mapping
* @param prefix the parent field path
*/
public static void extractFieldNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
if (prefix.length() > 0) {
prefix += ".";
}

for (Map.Entry<String, Object> entry : mappingSource.entrySet()) {
String n = entry.getKey();
Object v = entry.getValue();
for (Map.Entry<String, Object> entry : mappingSource.entrySet()) {
String n = entry.getKey();
Object v = entry.getValue();

if (v instanceof Map) {
Map<String, Object> vMap = (Map<String, Object>) v;
if (vMap.containsKey("type")) {
if (!((vMap.getOrDefault("type", "")).equals("alias"))) {
fieldsToType.put(prefix + n, (String) vMap.get("type"));
}
}
if (vMap.containsKey("properties")) {
extractFieldNamesTypes((Map<String, Object>) vMap.get("properties"), fieldsToType, prefix + n);
}
if (vMap.containsKey("fields")) {
extractFieldNamesTypes((Map<String, Object>) vMap.get("fields"), fieldsToType, prefix + n);
if (v instanceof Map) {
Map<String, Object> vMap = (Map<String, Object>) v;
if (vMap.containsKey("type")) {
if (!((vMap.getOrDefault("type", "")).equals("alias"))) {
fieldsToType.put(prefix + n, (String) vMap.get("type"));
}
}
if (vMap.containsKey("properties")) {
extractFieldNamesTypes((Map<String, Object>) vMap.get("properties"), fieldsToType, prefix + n);
}
if (vMap.containsKey("fields")) {
extractFieldNamesTypes((Map<String, Object>) vMap.get("fields"), fieldsToType, prefix + n);
}
}
}
}
}
}
}
37 changes: 28 additions & 9 deletions src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java
Original file line number Diff line number Diff line change
@@ -156,9 +156,14 @@ public void testTool_WithBlankModelId() {
public void testTool_WithNonSupportedModelType() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> CreateAlertTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType"))
() -> CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType"))
);
assertEquals(
"Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]",
exception.getMessage()
);
assertEquals("Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]", exception.getMessage());
}

@Test
@@ -267,7 +272,10 @@ public void testToolWithIllegalIndices() {
})
)
);
assertEquals("No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", exception.getMessage());
assertEquals(
"No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools",
exception.getMessage()
);

// empty string as indices
exception = assertThrows(
@@ -280,8 +288,10 @@ public void testToolWithIllegalIndices() {
})
)
);
assertEquals("No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", exception.getMessage());

assertEquals(
"No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools",
exception.getMessage()
);

// indices is an empty list
exception = assertThrows(
@@ -294,7 +304,10 @@ public void testToolWithIllegalIndices() {
})
)
);
assertEquals("The input indices is empty. Ask user to provide index as your final answer directly without using any other tools", exception.getMessage());
assertEquals(
"The input indices is empty. Ask user to provide index as your final answer directly without using any other tools",
exception.getMessage()
);

// indices contain system index
exception = assertThrows(
@@ -307,10 +320,13 @@ public void testToolWithIllegalIndices() {
})
)
);
assertEquals("The provided indices [[.kibana]] contains system index, which is not allowed. Ask user to check the provided indices as your final answer without using any other.", exception.getMessage());
assertEquals(
"The provided indices [[.kibana]] contains system index, which is not allowed. Ask user to check the provided indices as your final answer without using any other.",
exception.getMessage()
);

// Cannot find provided indices in opensearch
when(getIndexResponse.indices()).thenReturn(new String[]{});
when(getIndexResponse.indices()).thenReturn(new String[] {});
exception = assertThrows(
RuntimeException.class,
() -> tool
@@ -321,6 +337,9 @@ public void testToolWithIllegalIndices() {
})
)
);
assertEquals("Cannot find provided indices [non_existed_index]. Ask user to check the provided indices as your final answer without using any other tools", exception.getMessage());
assertEquals(
"Cannot find provided indices [non_existed_index]. Ask user to check the provided indices as your final answer without using any other tools",
exception.getMessage()
);
}
}
11 changes: 1 addition & 10 deletions src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java
Original file line number Diff line number Diff line change
@@ -397,16 +397,7 @@ public static Response makeRequest(
@SneakyThrows
protected String registerAgent(String modelId, String requestBodyResourceFile) {
String registerAgentRequestBody = Files
.readString(
Path
.of(
this
.getClass()
.getClassLoader()
.getResource(requestBodyResourceFile)
.toURI()
)
);
.readString(Path.of(this.getClass().getClassLoader().getResource(requestBodyResourceFile).toURI()));
registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", modelId);
return createAgent(registerAgentRequestBody);
}
60 changes: 22 additions & 38 deletions src/test/java/org/opensearch/integTest/CreateAlertToolIT.java
Original file line number Diff line number Diff line change
@@ -9,21 +9,24 @@

import java.io.IOException;
import java.util.List;
import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

import org.hamcrest.MatcherAssert;
import org.junit.Before;
import org.opensearch.agent.tools.CreateAlertTool;
import org.opensearch.client.ResponseException;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

@Log4j2
public class CreateAlertToolIT extends ToolIntegrationTest {
private final String requestBodyResourceFile = "org/opensearch/agent/tools/register_flow_agent_of_create_alert_tool_request_body.json";
private final String NORMAL_INDEX = "normal_index";
private final String NON_EXISTENT_INDEX = "non-existent";
private final String SYSTEM_INDEX = ".kibana";

private final String alertJson = "{\"name\": \"Error 500 Response Alert\",\"search\": {\"indices\": [\"opensearch_dashboards_sample_data_logs\"],\"timeField\": \"timestamp\",\"bucketValue\": 60,\"bucketUnitOfTime\": \"m\",\"filters\": [{\"fieldName\": [{\"label\": \"response\",\"type\": \"text\"}],\"fieldValue\": \"500\",\"operator\": \"is\"}],\"aggregations\": [{\"aggregationType\": \"count\",\"fieldName\": \"bytes\"}]},\"triggers\": [{\"name\": \"Error 500 Response Count Above 1\",\"severity\": 1,\"thresholdValue\": 1,\"thresholdEnum\": \"ABOVE\"}]}";
private final String alertJson =
"{\"name\": \"Error 500 Response Alert\",\"search\": {\"indices\": [\"opensearch_dashboards_sample_data_logs\"],\"timeField\": \"timestamp\",\"bucketValue\": 60,\"bucketUnitOfTime\": \"m\",\"filters\": [{\"fieldName\": [{\"label\": \"response\",\"type\": \"text\"}],\"fieldValue\": \"500\",\"operator\": \"is\"}],\"aggregations\": [{\"aggregationType\": \"count\",\"fieldName\": \"bytes\"}]},\"triggers\": [{\"name\": \"Error 500 Response Count Above 1\",\"severity\": 1,\"thresholdValue\": 1,\"thresholdEnum\": \"ABOVE\"}]}";
private final String question = "Create alert on the index when count of peoples whose age greater than 50 exceeds 10";
private final String pureJsonResponseIndicator = "$PURE_JSON";
private final String noJsonResponseIndicator = "$NO_JSON";
@@ -64,74 +67,55 @@ String toolType() {
@SneakyThrows
public void testCreateAlertTool() {
prepareIndex();
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX);
String result = executeAgent(agentId, requestBody);
assertEquals(
alertJson,
result
);
assertEquals(alertJson, result);
}

public void testCreateAlertToolWithPureJsonResponse() {
prepareIndex();
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + pureJsonResponseIndicator, NORMAL_INDEX);
String requestBody = String
.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + pureJsonResponseIndicator, NORMAL_INDEX);
String result = executeAgent(agentId, requestBody);
assertEquals(
alertJson,
result
);
assertEquals(alertJson, result);
}

public void testCreateAlertToolWithNoJsonResponse() {
prepareIndex();
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + noJsonResponseIndicator, NORMAL_INDEX);
Exception exception = assertThrows(
ResponseException.class,
() -> executeAgent(agentId, requestBody)
);
String requestBody = String
.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + noJsonResponseIndicator, NORMAL_INDEX);
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody));
MatcherAssert.assertThat(exception.getMessage(), containsString("The response from LLM is not a json"));
}

public void testCreateAlertToolWithNonExistentModelId() {
prepareIndex();
String abnormalAgentId = registerAgent("NON_EXISTENT_MODEL_ID", requestBodyResourceFile);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX);
Exception exception = assertThrows(
ResponseException.class,
() -> executeAgent(abnormalAgentId, requestBody)
);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX);
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(abnormalAgentId, requestBody));
MatcherAssert.assertThat(exception.getMessage(), containsString("Failed to find model"));
}

public void testCreateAlertToolWithNonExistentIndex() {
prepareIndex();
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NON_EXISTENT_INDEX);
Exception exception = assertThrows(
ResponseException.class,
() -> executeAgent(agentId, requestBody)
);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NON_EXISTENT_INDEX);
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody));
MatcherAssert.assertThat(exception.getMessage(), containsString("no such index"));
}

public void testCreateAlertToolWithSystemIndex() {
prepareIndex();
String agentId = registerAgent(modelId, requestBodyResourceFile);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, SYSTEM_INDEX);
Exception exception = assertThrows(
ResponseException.class,
() -> executeAgent(agentId, requestBody)
);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, SYSTEM_INDEX);
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody));
MatcherAssert.assertThat(exception.getMessage(), containsString("contains system index, which is not allowed"));
}

public void testCreateAlertToolWithEmptyIndex() {
prepareIndex();
String agentId = registerAgent(modelId, requestBodyResourceFile);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"\"}}", question);
Exception exception = assertThrows(
ResponseException.class,
() -> executeAgent(agentId, requestBody)
);
String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"\"}}", question);
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody));
MatcherAssert.assertThat(exception.getMessage(), containsString("No indices in the input parameter"));
}

0 comments on commit 631cf10

Please sign in to comment.