From 088a4aa02d21f1b1022d7efc398f2516989c87a9 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:15:33 -0700 Subject: [PATCH] Sigma Aggregation rule fixes (#622) (#640) Signed-off-by: Subhobrata Dey --- .../findings/FindingsService.java | 13 +- .../securityanalytics/model/Rule.java | 1 + .../rules/aggregation/AggregationItem.java | 10 ++ .../rules/backend/OSQueryBackend.java | 4 +- .../rules/backend/QueryBackend.java | 1 + .../rules/objects/SigmaDetections.java | 16 +- .../TransportIndexDetectorAction.java | 6 +- .../securityanalytics/util/RuleIndices.java | 2 +- .../securityanalytics/TestHelpers.java | 30 ++++ .../resthandler/DetectorMonitorRestApiIT.java | 15 +- .../resthandler/DetectorRestApiIT.java | 151 +++++++++++++++++- .../resthandler/RuleRestApiIT.java | 2 +- .../aggregation/AggregationBackendTests.java | 10 +- .../rules/condition/ConditionTests.java | 6 +- .../rules/objects/SigmaDetectionsTests.java | 2 +- .../rules/objects/SigmaRuleTests.java | 6 +- 16 files changed, 240 insertions(+), 35 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java index 83d8ffbb3..755b124db 100644 --- a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java +++ b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java @@ -59,7 +59,6 @@ public void getFindingsByDetectorId(String detectorId, Table table, ActionListen public void onResponse(GetDetectorResponse getDetectorResponse) { // Get all monitor ids from detector Detector detector = getDetectorResponse.getDetector(); - List monitorIds = detector.getMonitorIds(); ActionListener getFindingsResponseListener = new ActionListener<>() { @Override public void onResponse(GetFindingsResponse resp) { @@ -87,12 +86,20 @@ public void onFailure(Exception e) { // monitor --> detectorId mapping Map monitorToDetectorMapping = new HashMap<>(); detector.getMonitorIds().forEach( - monitorId -> monitorToDetectorMapping.put(monitorId, detector) + monitorId -> { + if (detector.getRuleIdMonitorIdMap().containsKey("chained_findings_monitor")) { + if (!detector.getRuleIdMonitorIdMap().get("chained_findings_monitor").equals(monitorId)) { + monitorToDetectorMapping.put(monitorId, detector); + } + } else { + monitorToDetectorMapping.put(monitorId, detector); + } + } ); // Get findings for all monitor ids FindingsService.this.getFindingsByMonitorIds( monitorToDetectorMapping, - monitorIds, + new ArrayList<>(monitorToDetectorMapping.keySet()), DetectorMonitorConfig.getAllFindingsIndicesPattern(detector.getDetectorType()), table, getFindingsResponseListener diff --git a/src/main/java/org/opensearch/securityanalytics/model/Rule.java b/src/main/java/org/opensearch/securityanalytics/model/Rule.java index fcbd95349..4131252ee 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Rule.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Rule.java @@ -481,6 +481,7 @@ public List getAggregationItemsFromRule () throws SigmaError { for (SigmaCondition condition: sigmaRule.getDetection().getParsedCondition()) { Pair parsedItems = condition.parsed(); AggregationItem aggItem = parsedItems.getRight(); + aggItem.setTimeframe(sigmaRule.getDetection().getTimeframe()); aggregationItems.add(aggItem); } return aggregationItems; diff --git a/src/main/java/org/opensearch/securityanalytics/rules/aggregation/AggregationItem.java b/src/main/java/org/opensearch/securityanalytics/rules/aggregation/AggregationItem.java index c25ed9b43..0d9e8ae7a 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/aggregation/AggregationItem.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/aggregation/AggregationItem.java @@ -20,6 +20,8 @@ public class AggregationItem implements Serializable { private Double threshold; + private String timeframe; + public void setAggFunction(String aggFunction) { this.aggFunction = aggFunction; } @@ -59,4 +61,12 @@ public void setThreshold(Double threshold) { public Double getThreshold() { return threshold; } + + public void setTimeframe(String timeframe) { + this.timeframe = timeframe; + } + + public String getTimeframe() { + return timeframe; + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java index ea26804c5..021b3de3b 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java @@ -381,10 +381,10 @@ public AggregationQueries convertAggregation(AggregationItem aggregation) { fmtAggQuery = String.format(Locale.getDefault(), aggCountQuery, "result_agg", aggregation.getGroupByField()); } aggBuilder.field(fieldName); - fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, "_cnt", "_cnt", "result_agg", "_cnt", aggregation.getCompOperator(), aggregation.getThreshold()); + fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, "_cnt", "_count", "result_agg", "_cnt", aggregation.getCompOperator(), aggregation.getThreshold()); Script script = new Script(String.format(Locale.getDefault(), bucketTriggerScript, "_cnt", aggregation.getCompOperator(), aggregation.getThreshold())); - condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap("_cnt", "_cnt"), script, "result_agg", null); + condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap("_cnt", "_count"), script, "result_agg", null); } else { fmtAggQuery = String.format(Locale.getDefault(), aggQuery, "result_agg", aggregation.getGroupByField(), aggregation.getAggField(), aggregation.getAggFunction(), aggregation.getAggField()); fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, aggregation.getAggField(), aggregation.getAggField(), "result_agg", aggregation.getAggField(), aggregation.getCompOperator(), aggregation.getThreshold()); diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java index 8a3e4d17c..c63dce05d 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java @@ -97,6 +97,7 @@ public List convertRule(SigmaRule rule) throws SigmaError { } queries.add(query); if (aggItem != null) { + aggItem.setTimeframe(rule.getDetection().getTimeframe()); queries.add(convertAggregation(aggItem)); } } diff --git a/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetections.java b/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetections.java index 235a7df9b..7937ac9fd 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetections.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetections.java @@ -21,11 +21,14 @@ public class SigmaDetections { private List condition; + private String timeframe; + private List parsedCondition; - public SigmaDetections(Map detections, List condition) throws SigmaDetectionError { + public SigmaDetections(Map detections, List condition, String timeframe) throws SigmaDetectionError { this.detections = detections; this.condition = condition; + this.timeframe = timeframe; if (this.detections.isEmpty()) { throw new SigmaDetectionError("No detections defined in Sigma rule"); @@ -55,7 +58,12 @@ protected static SigmaDetections fromDict(Map detectionMap) thro } } - return new SigmaDetections(detections, conditionList); + String timeframe = null; + if (detectionMap.containsKey("timeframe")) { + timeframe = detectionMap.get("timeframe").toString(); + } + + return new SigmaDetections(detections, conditionList, timeframe); } public Map getDetections() { @@ -69,4 +77,8 @@ public List getCondition() { public List getParsedCondition() { return parsedCondition; } + + public String getTimeframe() { + return timeframe; + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 02d6f438d..0e731153d 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -90,6 +90,7 @@ import org.opensearch.securityanalytics.model.DetectorTrigger; import org.opensearch.securityanalytics.model.Rule; import org.opensearch.securityanalytics.model.Value; +import org.opensearch.securityanalytics.rules.aggregation.AggregationItem; import org.opensearch.securityanalytics.rules.backend.OSQueryBackend; import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; import org.opensearch.securityanalytics.rules.backend.QueryBackend; @@ -784,7 +785,8 @@ private IndexMonitorRequest createBucketLevelMonitorRequest( List indices = detector.getInputs().get(0).getIndices(); - AggregationQueries aggregationQueries = queryBackend.convertAggregation(rule.getAggregationItemsFromRule().get(0)); + AggregationItem aggItem = rule.getAggregationItemsFromRule().get(0); + AggregationQueries aggregationQueries = queryBackend.convertAggregation(aggItem); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .seqNoAndPrimaryTerm(true) @@ -814,7 +816,7 @@ private IndexMonitorRequest createBucketLevelMonitorRequest( ? new BoolQueryBuilder() : QueryBuilders.boolQuery().must(searchSourceBuilder.query()); RangeQueryBuilder timeRangeFilter = QueryBuilders.rangeQuery(TIMESTAMP_FIELD_ALIAS) - .gt("{{period_end}}||-1h") + .gt("{{period_end}}||-" + (aggItem.getTimeframe() != null? aggItem.getTimeframe(): "1h")) .lte("{{period_end}}") .format("epoch_millis"); boolQueryBuilder.must(timeRangeFilter); diff --git a/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java b/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java index be5b9f1db..d38e82d90 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java +++ b/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java @@ -300,7 +300,7 @@ private List getQueries(QueryBackend backend, String category, List(queryFieldNames), ruleStr ); diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index 5333df17a..1aaa09086 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -88,6 +88,11 @@ public static Detector randomDetectorWithTriggers(List rules, List rules, List triggers, Schedule schedule, boolean enabled) { + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), Collections.emptyList(), + rules.stream().map(DetectorRule::new).collect(Collectors.toList())); + return randomDetector(null, null, null, List.of(input), triggers, schedule, enabled, null, null); + } public static Detector randomDetectorWithTriggers(List rules, List triggers, String detectorType, DetectorInput input) { return randomDetector(null, detectorType, null, List.of(input), triggers, null, null, null, null); @@ -352,6 +357,7 @@ public static String productIndexMaxAggRule() { public static String randomProductDocument(){ return "{\n" + + " \"name\": \"laptop\",\n" + " \"fieldA\": 123,\n" + " \"mappedB\": 111,\n" + " \"fieldC\": \"valueC\"\n" + @@ -563,6 +569,9 @@ public static String netFlowMappings() { public static String productIndexMapping(){ return "\"properties\":{\n" + + " \"name\":{\n" + + " \"type\":\"keyword\"\n" + + " },\n" + " \"fieldA\":{\n" + " \"type\":\"long\"\n" + " },\n" + @@ -591,6 +600,7 @@ public static String productIndexAvgAggRule(){ " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: 123\n" + " fieldB: 111\n" + @@ -598,6 +608,24 @@ public static String productIndexAvgAggRule(){ " condition: sel | avg(fieldA) by fieldC > 110"; } + public static String productIndexCountAggRule(){ + return " title: Test\n" + + " id: 39f918f3-981b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " timeframe: 5m\n" + + " sel:\n" + + " name: laptop\n" + + " condition: sel | count(*) by name > 2"; + } + public static String randomAggregationRule(String aggFunction, String signAndValue) { String rule = "title: Remote Encrypting File System Abuse\n" + "id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" + @@ -619,6 +647,7 @@ public static String randomAggregationRule(String aggFunction, String signAndVa " category: application\n" + " definition: 'Requirements: install and apply the RPC Firewall to all processes with \"audit:true action:block uuid:df1941c5-fe89-4e79-bf10-463657acf44d or c681d488-d850-11d0-8c52-00c04fd90f7e'\n" + "detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " Opcode: Info\n" + " condition: sel | %s(SeverityValue) by Version %s\n" + @@ -649,6 +678,7 @@ public static String randomAggregationRule(String aggFunction, String signAndVa " category: application\n" + " definition: 'Requirements: install and apply the RPC Firewall to all processes with \"audit:true action:block uuid:df1941c5-fe89-4e79-bf10-463657acf44d or c681d488-d850-11d0-8c52-00c04fd90f7e'\n" + "detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " Opcode: %s\n" + " condition: sel | %s(SeverityValue) by Version %s\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index df22c6eb9..11414cd62 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -1483,7 +1483,7 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor Map getFindingsBody = entityAsMap(getFindingsResponse); assertNotNull(getFindingsBody); - assertEquals(10, getFindingsBody.get("total_findings")); + assertEquals(6, getFindingsBody.get("total_findings")); String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); @@ -1495,7 +1495,6 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor List> findings = (List)getFindingsBody.get("findings"); Set docLevelRules = new HashSet<>(List.of(randomDocRuleId)); - List bucketLevelMonitorFindingDocs = new ArrayList<>(); for(Map finding : findings) { List> queries = (List>) finding.get("queries"); Set findingRules = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); @@ -1504,16 +1503,10 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor docLevelFinding.addAll((List) finding.get("related_doc_ids")); } else { List findingDocs = (List) finding.get("related_doc_ids"); - if (((Map) ((List) finding.get("queries")).get(0)).get("query").equals("_id:*")) { - Assert.assertEquals(1, findingDocs.size()); - bucketLevelMonitorFindingDocs.addAll(findingDocs); - } else { - Assert.assertEquals(4, findingDocs.size()); - assertTrue(Arrays.asList("1", "2", "3", "4").containsAll(findingDocs)); - } + Assert.assertEquals(4, findingDocs.size()); + assertTrue(Arrays.asList("1", "2", "3", "4").containsAll(findingDocs)); } } - assertTrue(bucketLevelMonitorFindingDocs.containsAll(Arrays.asList("1", "2", "3", "4"))); // Verify doc level finding assertTrue(Arrays.asList("1", "2", "3", "4", "5").containsAll(docLevelFinding)); } @@ -1652,7 +1645,7 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve // Assert findings assertNotNull(getFindingsBody); - assertEquals(33, getFindingsBody.get("total_findings")); + assertEquals(19, getFindingsBody.get("total_findings")); } diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java index 622656007..f6cd2a4e9 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics.resthandler; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -14,11 +15,13 @@ import org.apache.http.entity.StringEntity; import org.apache.http.message.BasicHeader; import org.junit.Assert; +import org.junit.Ignore; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; import org.opensearch.client.ResponseException; +import org.opensearch.commons.alerting.model.IntervalSchedule; import org.opensearch.commons.alerting.model.Monitor.MonitorType; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -38,6 +41,7 @@ import java.util.stream.Collectors; import org.opensearch.securityanalytics.model.DetectorTrigger; +import static org.junit.Assert.assertNotNull; import static org.opensearch.securityanalytics.TestHelpers.*; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; @@ -165,6 +169,83 @@ public void testCreatingADetector() throws IOException { Assert.assertEquals(5, noOfSigmaRuleMatches); } + @Ignore + public void testCreatingADetectorScheduledJobFinding() throws IOException, InterruptedException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + + Detector detector = randomDetectorWithTriggersAndScheduleAndEnabled(getRandomPrePackagedRules(), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of())), + new IntervalSchedule(1, ChronoUnit.MINUTES, null), true); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + int createdVersion = Integer.parseInt(responseBody.get("_version").toString()); + Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdId); + Assert.assertTrue("incorrect version", createdVersion > 0); + Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, createdId), createResponse.getHeader("Location")); + Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("rule_topic_index")); + Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("findings_index")); + Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("alert_index")); + + String detectorTypeInResponse = (String) ((Map)responseBody.get("detector")).get("detector_type"); + Assert.assertEquals("Detector type incorrect", randomDetectorType().toLowerCase(Locale.ROOT), detectorTypeInResponse); + + Thread.sleep(30000); + indexDoc(index, "1", randomDoc()); + Thread.sleep(70000); + + // Call GetFindings API + Map params = new HashMap<>(); + params.put("detector_id", createdId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + Assert.assertEquals(1, getFindingsBody.get("total_findings")); + + // Call GetAlerts API + params = new HashMap<>(); + params.put("detector_id", createdId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + // TODO enable asserts here when able + Assert.assertEquals(1, getAlertsBody.get("total_alerts")); + + Thread.sleep(30000); + indexDoc(index, "2", randomDoc()); + Thread.sleep(70000); + + // Call GetFindings API + params = new HashMap<>(); + params.put("detector_id", createdId); + getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + getFindingsBody = entityAsMap(getFindingsResponse); + Assert.assertEquals(2, getFindingsBody.get("total_findings")); + + // Call GetAlerts API + params = new HashMap<>(); + params.put("detector_id", createdId); + getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + getAlertsBody = asMap(getAlertsResponse); + // TODO enable asserts here when able + Assert.assertEquals(2, getAlertsBody.get("total_alerts")); + } + @SuppressWarnings("unchecked") public void test_searchDetectors_detectorsIndexNotExists() throws IOException { try { @@ -182,7 +263,7 @@ public void test_searchDetectors_detectorsIndexNotExists() throws IOException { HttpEntity requestEntity = new StringEntity(request, ContentType.APPLICATION_JSON); Response searchResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + "_search", Collections.emptyMap(), requestEntity); Map searchResponseBody = asMap(searchResponse); - Assert.assertNotNull("response is not null", searchResponseBody); + assertNotNull("response is not null", searchResponseBody); Map searchResponseHits = (Map) searchResponseBody.get("hits"); Map searchResponseTotal = (Map) searchResponseHits.get("total"); Assert.assertEquals(0, searchResponseTotal.get("value")); @@ -409,7 +490,7 @@ public void testGettingADetector() throws IOException { Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + createdId, Collections.emptyMap(), null); Map responseBody = asMap(getResponse); Assert.assertEquals(createdId, responseBody.get("_id")); - Assert.assertNotNull(responseBody.get("detector")); + assertNotNull(responseBody.get("detector")); String detectorTypeInResponse = (String) ((Map)responseBody.get("detector")).get("detector_type"); Assert.assertEquals("Detector type incorrect", randomDetectorType().toLowerCase(Locale.ROOT), detectorTypeInResponse); @@ -445,7 +526,7 @@ public void testSearchingDetectors() throws IOException { HttpEntity requestEntity = new NStringEntity(queryJson, ContentType.APPLICATION_JSON); Response searchResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + "_search", Collections.emptyMap(), requestEntity); Map searchResponseBody = asMap(searchResponse); - Assert.assertNotNull("response is not null", searchResponseBody); + assertNotNull("response is not null", searchResponseBody); Map searchResponseHits = (Map) searchResponseBody.get("hits"); Map searchResponseTotal = (Map) searchResponseHits.get("total"); Assert.assertEquals(1, searchResponseTotal.get("value")); @@ -613,11 +694,73 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { HashMap docLevelQuery = (HashMap) ((List) finding.get("queries")).get(0); String ruleId = docLevelQuery.get("id").toString(); // Verify if the rule id in bucket level finding is the same as rule used for bucket monitor creation - assertEquals(customAvgRuleId, ruleId); + Assert.assertEquals(customAvgRuleId, ruleId); Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), null); String getDetectorResponseString = new String(getResponse.getEntity().getContent().readAllBytes()); Assert.assertTrue(getDetectorResponseString.contains(ruleId)); } + + public void testAggRuleCount() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + String customAggRule = createRule(productIndexCountAggRule()); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(customAggRule)), + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + Map detectorAsMap = (Map) hit.getSourceAsMap().get("detector"); + + String bucketLevelMonitorId = ((List) (detectorAsMap).get("monitor_id")).get(1); + // condition: sel | count(*) by name > 2 + indexDoc(index, "1", randomProductDocument()); + indexDoc(index, "2", randomProductDocument()); + // Verify that 2 documents aren't enough to satisfy trigger condition + Map executeResults = entityAsMap(executeAlertingMonitor(bucketLevelMonitorId, Collections.emptyMap())); + Map trigger = (Map) ((Map)executeResults.get("trigger_results")).entrySet().iterator().next().getValue(); + assertEquals(0, ((Map)(trigger.get("agg_result_buckets"))).size() ); + // 3 will be fine + indexDoc(index, "3", randomProductDocument()); + + executeResults = entityAsMap(executeAlertingMonitor(bucketLevelMonitorId, Collections.emptyMap())); + trigger = (Map) ((Map)executeResults.get("trigger_results")).entrySet().iterator().next().getValue(); + assertEquals(1, ((Map)(trigger.get("agg_result_buckets"))).size() ); + // verify bucket level monitor findings + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + assertNotNull(getFindingsBody); + Assert.assertEquals(1, getFindingsBody.get("total_findings")); + List findings = (List) getFindingsBody.get("findings"); + Assert.assertEquals(findings.size(), 1); + HashMap finding = (HashMap) findings.get(0); + Assert.assertTrue(finding.containsKey("queries")); + HashMap docLevelQuery = (HashMap) ((List) finding.get("queries")).get(0); + String ruleId = docLevelQuery.get("id").toString(); + // Verify if the rule id in bucket level finding is the same as rule used for bucket monitor creation + Assert.assertEquals(customAggRule, ruleId); + Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), null); + String getDetectorResponseString = new String(getResponse.getEntity().getContent().readAllBytes()); + Assert.assertTrue(getDetectorResponseString.contains(ruleId)); + } + public void testUpdateADetector() throws IOException { String index = createTestIndex(randomIndex(), windowsIndexMapping()); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java index 439c7c9cb..b0ed9dc32 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java @@ -153,7 +153,7 @@ public void testCreatingAggregationRule() throws SigmaError, IOException { Rule result = Rule.docParse(xcp, null, null); Assert.assertEquals(1, result.getAggregationQueries().size()); - String expected = "{\"aggQuery\":\"{\\\"result_agg\\\":{\\\"terms\\\":{\\\"field\\\":\\\"_index\\\"}}}\",\"bucketTriggerQuery\":\"{\\\"buckets_path\\\":{\\\"_cnt\\\":\\\"_cnt\\\"},\\\"parent_bucket_path\\\":\\\"result_agg\\\",\\\"script\\\":{\\\"source\\\":\\\"params._cnt > 1.0\\\",\\\"lang\\\":\\\"painless\\\"}}\"}"; + String expected = "{\"aggQuery\":\"{\\\"result_agg\\\":{\\\"terms\\\":{\\\"field\\\":\\\"_index\\\"}}}\",\"bucketTriggerQuery\":\"{\\\"buckets_path\\\":{\\\"_cnt\\\":\\\"_count\\\"},\\\"parent_bucket_path\\\":\\\"result_agg\\\",\\\"script\\\":{\\\"source\\\":\\\"params._cnt > 1.0\\\",\\\"lang\\\":\\\"painless\\\"}}\"}"; Assert.assertEquals(expected, result.getAggregationQueries().get(0).getValue()); } diff --git a/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java b/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java index 43db549c8..395f15a79 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java @@ -36,6 +36,7 @@ public void testCountAggregation() throws SigmaError, IOException { " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: valueA\n" + " fieldB: valueB\n" + @@ -50,7 +51,7 @@ public void testCountAggregation() throws SigmaError, IOException { String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"_index\"}}}", aggQuery); - Assert.assertEquals("{\"buckets_path\":{\"_cnt\":\"_cnt\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params._cnt > 1.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); + Assert.assertEquals("{\"buckets_path\":{\"_cnt\":\"_count\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params._cnt > 1.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } public void testCountAggregationWithGroupBy() throws IOException, SigmaError { @@ -67,6 +68,7 @@ public void testCountAggregationWithGroupBy() throws IOException, SigmaError { " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: valueA\n" + " fieldB: valueB\n" + @@ -81,7 +83,7 @@ public void testCountAggregationWithGroupBy() throws IOException, SigmaError { String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"}}}", aggQuery); - Assert.assertEquals("{\"buckets_path\":{\"_cnt\":\"_cnt\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params._cnt > 1.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); + Assert.assertEquals("{\"buckets_path\":{\"_cnt\":\"_count\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params._cnt > 1.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } public void testSumAggregationWithGroupBy() throws IOException, SigmaError { @@ -98,6 +100,7 @@ public void testSumAggregationWithGroupBy() throws IOException, SigmaError { " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: valueA\n" + " fieldB: valueB\n" + @@ -132,6 +135,7 @@ public void testMinAggregationWithGroupBy() throws IOException, SigmaError { " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: valueA\n" + " fieldB: valueB\n" + @@ -163,6 +167,7 @@ public void testMaxAggregationWithGroupBy() throws IOException, SigmaError { " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: valueA\n" + " fieldB: valueB\n" + @@ -194,6 +199,7 @@ public void testAvgAggregationWithGroupBy() throws IOException, SigmaError { " category: test_category\n" + " product: test_product\n" + " detection:\n" + + " timeframe: 5m\n" + " sel:\n" + " fieldA: valueA\n" + " fieldB: valueB\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/rules/condition/ConditionTests.java b/src/test/java/org/opensearch/securityanalytics/rules/condition/ConditionTests.java index 5d778406e..cd1f2d969 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/condition/ConditionTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/condition/ConditionTests.java @@ -345,7 +345,7 @@ private SigmaDetections sigmaSimpleDetections() throws SigmaError { detections.put("other", other); - return new SigmaDetections(detections, Collections.emptyList()); + return new SigmaDetections(detections, Collections.emptyList(), null); } private SigmaDetections sigmaDetections() throws SigmaError { @@ -394,7 +394,7 @@ private SigmaDetections sigmaDetections() throws SigmaError { SigmaDetection detection7 = new SigmaDetection(List.of(Either.left(detectionItem11)), null); detections.put("empty-field", detection7); - return new SigmaDetections(detections, Collections.emptyList()); + return new SigmaDetections(detections, Collections.emptyList(), null); } private SigmaDetections sigmaInvalidDetections() throws SigmaError { @@ -405,6 +405,6 @@ private SigmaDetections sigmaInvalidDetections() throws SigmaError { SigmaDetection detection = new SigmaDetection(List.of(Either.left(detectionItem)), null); detections.put("null-keyword", detection); - return new SigmaDetections(detections, Collections.emptyList()); + return new SigmaDetections(detections, Collections.emptyList(), null); } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionsTests.java b/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionsTests.java index dba65b1d0..dcab12544 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionsTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionsTests.java @@ -50,7 +50,7 @@ public void testSigmaDetectionsFromDict() throws SigmaError{ SigmaDetection detection = new SigmaDetection(List.of(Either.left(detectionItem1), Either.left(detectionItem2), Either.left(detectionItem4)), Either.right(ConditionOR.class)); - SigmaDetections expectedSigmaDetections = new SigmaDetections(Collections.singletonMap("selection", detection), Collections.singletonList("selection")); + SigmaDetections expectedSigmaDetections = new SigmaDetections(Collections.singletonMap("selection", detection), Collections.singletonList("selection"), null); Assert.assertEquals(expectedSigmaDetections.getCondition().size(), actualSigmaDetections.getCondition().size()); Assert.assertEquals(expectedSigmaDetections.getCondition().get(0), actualSigmaDetections.getCondition().get(0)); diff --git a/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaRuleTests.java b/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaRuleTests.java index 385d31ecb..2246404e5 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaRuleTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/objects/SigmaRuleTests.java @@ -134,7 +134,7 @@ public void testSigmaRuleNoneToList() throws SigmaRegularExpressionError, SigmaV SigmaDetection detection = new SigmaDetection(Collections.singletonList(Either.left(detectionItem)), Either.right(ConditionOR.class)); - SigmaDetections detections = new SigmaDetections(Collections.singletonMap("selection", detection), Collections.singletonList("selection")); + SigmaDetections detections = new SigmaDetections(Collections.singletonMap("selection", detection), Collections.singletonList("selection"), null); SigmaRule rule = new SigmaRule("Test", logSource, detections, null, null, null, null, null, null, null, null, null, null, null); @@ -195,7 +195,7 @@ public void testSigmaRuleFromYaml() throws SigmaError, ParseException { public void testEmptyDetection() { Exception exception = assertThrows(SigmaDetectionError.class, () -> { - new SigmaDetections(Collections.emptyMap(), Collections.emptyList()); + new SigmaDetections(Collections.emptyMap(), Collections.emptyList(), null); }); String expectedMessage = "No detections defined in Sigma rule"; @@ -216,7 +216,7 @@ private SigmaRule sigmaRule() throws SigmaRegularExpressionError, SigmaValueErro SigmaDetection detection = new SigmaDetection(List.of(Either.left(detectionItem1), Either.left(detectionItem2), Either.left(detectionItem4)), Either.right(ConditionOR.class)); - SigmaDetections detections = new SigmaDetections(Collections.singletonMap("selection", detection), Collections.singletonList("selection")); + SigmaDetections detections = new SigmaDetections(Collections.singletonMap("selection", detection), Collections.singletonList("selection"), null); SimpleDateFormat formatter = new SimpleDateFormat("yyyy/MM/dd", Locale.getDefault()); Date ruleDate = formatter.parse("2017/05/15");