Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add CreateAlertTool #349

Merged
merged 12 commits into from
Aug 1, 2024
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.function.Supplier;

import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.CreateAlertTool;
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
Expand Down Expand Up @@ -74,6 +75,7 @@ public Collection<Object> createComponents(
SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchMonitorsTool.Factory.getInstance().init(client);
CreateAlertTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
return Collections.emptyList();
}
Expand All @@ -90,6 +92,7 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAnomalyDetectorsTool.Factory.getInstance(),
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance(),
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
);
}
Expand Down
282 changes: 282 additions & 0 deletions src/main/java/org/opensearch/agent/tools/CreateAlertTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.get.GetIndexRequest;
import org.opensearch.action.admin.indices.get.GetIndexResponse;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.agent.tools.utils.ToolConstants.ModelType;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.logging.LoggerMessageFormat;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
@ToolAnnotation(CreateAlertTool.TYPE)
public class CreateAlertTool implements Tool {
public static final String TYPE = "CreateAlertTool";

private static final String DEFAULT_DESCRIPTION =
"This is a tool that helps to create an alert(i.e. monitor with triggers), some parameters should be parsed based on user's question and context. The parameters should include: \n"
+ "1. indices: The input indices of the monitor, should be a list of string in json format.\n";

@Setter
@Getter
private String name = TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

private final Client client;
private final String modelId;
private final String TOOL_PROMPT_TEMPLATE;

private static final Gson gson = new Gson();
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
private static final String MODEL_ID = "model_id";
private static final String promptFilePath = "CreateAlertDefaultPrompt.json";
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
private static final String defaultQuestion = "Create an alert as your recommendation based on the context";

public CreateAlertTool(Client client, String modelId, String modelType) {
this.client = client;
this.modelId = modelId;
Map<String, String> promptDict = ToolHelper.loadDefaultPromptDictFromFile(this.getClass(), promptFilePath);
if (!promptDict.containsKey(modelType)) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]",
modelType,
String.join(",", promptDict.keySet())
)
);
}
TOOL_PROMPT_TEMPLATE = promptDict.get(modelType);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
Map<String, String> tmpParams = new HashMap<>(parameters);
String mappingInfo = getIndexMappingInfo(tmpParams);
tmpParams.put("mapping_info", mappingInfo);
tmpParams.putIfAbsent("indices", "");
tmpParams.putIfAbsent("chat_history", "");
tmpParams.putIfAbsent("question", defaultQuestion); // In case no question is provided, use a default question.
StringSubstitutor substitute = new StringSubstitutor(tmpParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT_TEMPLATE);
tmpParams.put("prompt", finalToolPrompt);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
String response = "";
if (dataMap.containsKey("response")) {
response = (String) dataMap.get("response");
Pattern jsonPattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL);
Matcher jsonBlockMatcher = jsonPattern.matcher(response);
if (jsonBlockMatcher.find()) {
response = jsonBlockMatcher.group(1);
response = response.replace("\\\"", "\"");
}
} else {
// LLM sometimes returns the tensor results as a json object directly instead of
// string response, and the json object is stored as a map.
response = StringUtils.toJson(dataMap);
}
if (!isJson(response)) {
throw new IllegalArgumentException(
LoggerMessageFormat.format(null, "The response from LLM is not a json: [{}]", response)
);
}
listener.onResponse((T) response);
}, e -> {
log.error("Failed to run model " + modelId, e);
listener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to call CreateAlertTool", e);
listener.onFailure(e);
}
}

private String getIndexMappingInfo(Map<String, String> parameters) throws InterruptedException, ExecutionException {
if (!parameters.containsKey("indices") || parameters.get("indices").isEmpty()) {
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException(
"No indices in the input parameter. Ask user to "
+ "provide index as your final answer directly without using any other tools"
);
}
String rawIndex = parameters.getOrDefault("indices", "");
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
List<String> indexList;
try {
indexList = gson.fromJson(rawIndex, new TypeToken<List<String>>() {
}.getType());
} catch (Exception e) {
// LLM sometimes returns the indices as a string instead of a json list, although we require that in the tool description.
indexList = Collections.singletonList(rawIndex);
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
}
if (indexList.isEmpty()) {
throw new IllegalArgumentException(
"The input indices is empty. Ask user to " + "provide index as your final answer directly without using any other tools"
);
} else if (indexList.stream().anyMatch(index -> index.startsWith("."))) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"The provided indices [{}] contains system index, which is not allowed. Ask user to "
+ "check the provided indices as your final answer without using any other.",
rawIndex
)
);
}
final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY);
final GetIndexRequest getIndexRequest = new GetIndexRequest()
.indices(indices)
.indicesOptions(IndicesOptions.strictExpand())
.local(Boolean.parseBoolean(parameters.getOrDefault("local", "true")))
.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT);
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
GetIndexResponse getIndexResponse = client.admin().indices().getIndex(getIndexRequest).get();
if (getIndexResponse.indices().length == 0) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Cannot find provided indices {}. Ask "
+ "user to check the provided indices as your final answer without using any other "
+ "tools",
rawIndex
)
);
}
StringBuilder sb = new StringBuilder();
for (String index : getIndexResponse.indices()) {
sb.append("index: ").append(index).append("\n\n");

MappingMetadata mapping = getIndexResponse.mappings().get(index);
if (mapping != null) {
sb.append("mappings:\n");
for (Entry<String, Object> entry : mapping.sourceAsMap().entrySet()) {
sb.append(entry.getKey()).append("=").append(entry.getValue()).append('\n');
}
sb.append("\n\n");
}
}
return sb.toString();
}

public static class Factory implements Tool.Factory<CreateAlertTool> {

private Client client;

private static Factory INSTANCE;

public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (CreateAlertTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client) {
this.client = client;
}

@Override
public CreateAlertTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID);
if (Strings.isBlank(modelId)) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());
return new CreateAlertTool(client, modelId, modelType);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.agent.tools.utils;

import java.util.Locale;

public class ToolConstants {
// Detector state is not cleanly defined on the backend plugin. So, we persist a standard
// set of states here for users to interface with when fetching and filtering detectors.
Expand All @@ -17,6 +19,15 @@ public static enum DetectorStateString {
Initializing
}

public enum ModelType {
CLAUDE,
OPENAI;

public static ModelType from(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
}
}

// System indices constants are not cleanly exposed from the AD & Alerting plugins, so we persist our
// own constants here.
public static final String AD_RESULTS_INDEX_PATTERN = ".opendistro-anomaly-results*";
Expand Down
29 changes: 29 additions & 0 deletions src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,38 @@

package org.opensearch.agent.tools.utils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import com.google.gson.Gson;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class ToolHelper {
private static final Gson gson = new Gson();
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved

/**
* Load prompt from the resource file of the invoking class
* @param source class which calls this function
* @param fileName the resource file name of prompt
* @return the LLM request prompt template.
*/
public static Map<String, String> loadDefaultPromptDictFromFile(Class<?> source, String fileName) {
try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) {
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
return gson.fromJson(defaultPromptContent, Map.class);
}
} catch (IOException e) {
log.error("Failed to load default prompt dict from file: {}", fileName, e);
}
return new HashMap<>();
}

/**
* Flatten all the fields in the mappings, insert the field to fieldType mapping to a map
* @param mappingSource the mappings of an index
Expand Down
Loading
Loading