Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize sigma aggregation rule based detectors execution workflow #1418

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.transport.TransportIndexDetectorAction.CHAINED_FINDINGS_MONITOR_STRING;

/**
* Alerts Service implements operations involving interaction with Alerting Plugin
*/
Expand Down Expand Up @@ -84,12 +86,21 @@ public void onResponse(GetDetectorResponse getDetectorResponse) {
// monitor --> detectorId mapping
Map<String, String> monitorToDetectorMapping = new HashMap<>();
detector.getMonitorIds().forEach(
monitorId -> monitorToDetectorMapping.put(monitorId, detector.getId())
monitorId -> {
if (detector.getRuleIdMonitorIdMap().containsKey(CHAINED_FINDINGS_MONITOR_STRING)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this?

if (detector.getRuleIdMonitorIdMap().get(CHAINED_FINDINGS_MONITOR_STRING).equals(monitorId) ||
(detector.getRuleIdMonitorIdMap().containsKey("-1") && detector.getRuleIdMonitorIdMap().get("-1").equals(monitorId))) {
monitorToDetectorMapping.put(monitorId, detector.getId());
}
} else {
monitorToDetectorMapping.put(monitorId, detector.getId());
}
}
);
// Get alerts for all monitor ids
AlertsService.this.getAlertsByMonitorIds(
monitorToDetectorMapping,
monitorIds,
new ArrayList<>(monitorToDetectorMapping.keySet()),
DetectorMonitorConfig.getAllAlertsIndicesPattern(detector.getDetectorType()),
table,
severityLevel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ private Monitor buildThreatIntelMonitor(IndexThreatIntelMonitorRequest request)
Collections.emptyMap(),
new DataSources(),
false,
null,
PLUGIN_OWNER_FIELD
);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List<Pair<String, Rule>
detector.getAlertsHistoryIndex(),
detector.getAlertsHistoryIndexPattern(),
DetectorMonitorConfig.getRuleIndexMappingsByType(),
true), enableDetectorWithDedicatedQueryIndices, PLUGIN_OWNER_FIELD);
true), enableDetectorWithDedicatedQueryIndices, null, PLUGIN_OWNER_FIELD);

return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null);
}
Expand Down Expand Up @@ -902,7 +902,7 @@ private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
detector.getAlertsHistoryIndex(),
detector.getAlertsHistoryIndexPattern(),
DetectorMonitorConfig.getRuleIndexMappingsByType(),
true), enableDetectorWithDedicatedQueryIndices, PLUGIN_OWNER_FIELD);
true), enableDetectorWithDedicatedQueryIndices, true, PLUGIN_OWNER_FIELD);

return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null);
}
Expand Down Expand Up @@ -1078,7 +1078,7 @@ public void onResponse(GetIndexMappingsResponse getIndexMappingsResponse) {
detector.getAlertsHistoryIndex(),
detector.getAlertsHistoryIndexPattern(),
DetectorMonitorConfig.getRuleIndexMappingsByType(),
true), false, PLUGIN_OWNER_FIELD);
true), false, null, PLUGIN_OWNER_FIELD);

listener.onResponse(new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ public void upsertWorkflow(
}
cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses));
} else if (updatedMonitorResponses != null && updatedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) {
List<IndexMonitorResponse> monitorResponses = new ArrayList<>(updatedMonitorResponses);
monitorResponses.addAll(updatedMonitorResponses);
cmfMonitorId = updatedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses));
}

IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public void testGetAlerts_success() {
Map.of(),
new DataSources(),
true,
null,
TransportIndexDetectorAction.PLUGIN_OWNER_FIELD
),
new DocumentLevelTrigger("trigger_id_1", "my_trigger", "severity_low", List.of(), new Script("")),
Expand Down Expand Up @@ -131,6 +132,7 @@ public void testGetAlerts_success() {
Map.of(),
new DataSources(),
true,
null,
TransportIndexDetectorAction.PLUGIN_OWNER_FIELD
),
new DocumentLevelTrigger("trigger_id_1", "my_trigger", "severity_low", List.of(), new Script("")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException

Response createMappingResponse = client().performRequest(createMappingRequest);

assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());
assertEquals(org.apache.http.HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());

String infoOpCode = "Info";

Expand Down Expand Up @@ -850,28 +850,11 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException
SearchHit hit = hits.get(0);
Map<String, List> updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

List<String> monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));
String workflowId = ((List<String>) (updatedDetectorMap).get("workflow_ids")).get(0);

indexDoc(index, "1", randomDoc(2, 4, infoOpCode));
indexDoc(index, "2", randomDoc(3, 4, infoOpCode));

Map<String, Integer> numberOfMonitorTypes = new HashMap<>();

for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);
if (Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) {
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
assertEquals(5, noOfSigmaRuleMatches);
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());
executeAlertingWorkflow(workflowId, Collections.emptyMap());

Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
Expand Down Expand Up @@ -911,15 +894,15 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException
Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null);
Map<String, Object> getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert
Assert.assertEquals(1, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert

input = new DetectorInput("updated", List.of("windows"), detectorRules,
Collections.emptyList());
Detector updatedDetector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger("updated", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))
);
/** update detector and verify chained findings monitor should still exist*/
Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector));
makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector));
hits = executeSearch(Detector.DETECTORS_INDEX, request);
hit = hits.get(0);
updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));
Expand All @@ -932,29 +915,48 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException
hit = hits.get(0);
updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));
numberOfMonitorTypes = new HashMap<>();
for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);

if (Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type"))) {
ArrayList triggerResults = new ArrayList(((Map<String, Object>) executeResults.get("trigger_results")).values());
assertEquals(triggerResults.size(), 1);
Map<String, Object> triggerResult = (Map<String, Object>) triggerResults.get(0);
assertTrue(triggerResult.containsKey("agg_result_buckets"));
HashMap<String, Object> aggResultBuckets = (HashMap<String, Object>) triggerResult.get("agg_result_buckets");
assertTrue(aggResultBuckets.containsKey("4"));
assertTrue(aggResultBuckets.containsKey("5"));
workflowId = ((List<String>) (updatedDetectorMap).get("workflow_ids")).get(0);
executeAlertingWorkflow(workflowId, Collections.emptyMap());

params = new HashMap<>();
params.put("detector_id", detectorId);
getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
getFindingsBody = entityAsMap(getFindingsResponse);

assertNotNull(getFindingsBody);
assertEquals(2, getFindingsBody.get("total_findings"));

findingDetectorId = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString();
assertEquals(detectorId, findingDetectorId);

findingIndex = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString();
assertEquals(index, findingIndex);

docLevelFinding = new ArrayList<>();
findings = (List) getFindingsBody.get("findings");


for (Map<String, Object> finding : findings) {
List<Map<String, Object>> queries = (List<Map<String, Object>>) finding.get("queries");
Set<String> findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet());

// In the case of bucket level monitors, queries will always contain one value
String aggRuleId = findingRuleIds.iterator().next();
List<String> findingDocs = (List<String>) finding.get("related_doc_ids");

if (aggRuleId.equals(sumRuleId)) {
assertTrue(List.of("1", "2", "3", "4", "5", "6", "7").containsAll(findingDocs));
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());
assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding));

params1 = new HashMap<>();
params1.put("detector_id", detectorId);
getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null);
getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(2, getAlertsBody.get("total_alerts"));
}

@Ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public void testThreatInputSerde() throws IOException {
emptyMap(),
new DataSources(),
false,
null,
"security_analytics"
);
BytesStreamOutput monitorOut = new BytesStreamOutput();
Expand Down
Loading