Skip to content

Commit

Permalink
Sigma Aggregation rule fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
petardz authored and sbcd90 committed Sep 29, 2023
1 parent 3c9b23a commit add3527
Show file tree
Hide file tree
Showing 14 changed files with 141 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ public List<AggregationItem> getAggregationItemsFromRule () throws SigmaError {
for (SigmaCondition condition: sigmaRule.getDetection().getParsedCondition()) {
Pair<ConditionItem, AggregationItem> parsedItems = condition.parsed();
AggregationItem aggItem = parsedItems.getRight();
aggItem.setTimeframe(sigmaRule.getDetection().getTimeframe());

Check warning on line 484 in src/main/java/org/opensearch/securityanalytics/model/Rule.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/securityanalytics/model/Rule.java#L484

Added line #L484 was not covered by tests
aggregationItems.add(aggItem);
}
return aggregationItems;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class AggregationItem implements Serializable {

private Double threshold;

private String timeframe;

public void setAggFunction(String aggFunction) {
this.aggFunction = aggFunction;
}
Expand Down Expand Up @@ -59,4 +61,12 @@ public void setThreshold(Double threshold) {
public Double getThreshold() {
return threshold;
}

public void setTimeframe(String timeframe) {
this.timeframe = timeframe;
}

public String getTimeframe() {
return timeframe;

Check warning on line 70 in src/main/java/org/opensearch/securityanalytics/rules/aggregation/AggregationItem.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/securityanalytics/rules/aggregation/AggregationItem.java#L70

Added line #L70 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,10 @@ public AggregationQueries convertAggregation(AggregationItem aggregation) {
fmtAggQuery = String.format(Locale.getDefault(), aggCountQuery, "result_agg", aggregation.getGroupByField());
}
aggBuilder.field(fieldName);
fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, "_cnt", "_cnt", "result_agg", "_cnt", aggregation.getCompOperator(), aggregation.getThreshold());
fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, "_cnt", "_count", "result_agg", "_cnt", aggregation.getCompOperator(), aggregation.getThreshold());

Script script = new Script(String.format(Locale.getDefault(), bucketTriggerScript, "_cnt", aggregation.getCompOperator(), aggregation.getThreshold()));
condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap("_cnt", "_cnt"), script, "result_agg", null);
condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap("_cnt", "_count"), script, "result_agg", null);
} else {
fmtAggQuery = String.format(Locale.getDefault(), aggQuery, "result_agg", aggregation.getGroupByField(), aggregation.getAggField(), aggregation.getAggFunction(), aggregation.getAggField());
fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, aggregation.getAggField(), aggregation.getAggField(), "result_agg", aggregation.getAggField(), aggregation.getCompOperator(), aggregation.getThreshold());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public List<Object> convertRule(SigmaRule rule) throws SigmaError {
}
queries.add(query);
if (aggItem != null) {
aggItem.setTimeframe(rule.getDetection().getTimeframe());
queries.add(convertAggregation(aggItem));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ public class SigmaDetections {

private List<String> condition;

private String timeframe;

private List<SigmaCondition> parsedCondition;

public SigmaDetections(Map<String, SigmaDetection> detections, List<String> condition) throws SigmaDetectionError {
public SigmaDetections(Map<String, SigmaDetection> detections, List<String> condition, String timeframe) throws SigmaDetectionError {
this.detections = detections;
this.condition = condition;
this.timeframe = timeframe;

if (this.detections.isEmpty()) {
throw new SigmaDetectionError("No detections defined in Sigma rule");
Expand Down Expand Up @@ -55,7 +58,12 @@ protected static SigmaDetections fromDict(Map<String, Object> detectionMap) thro
}
}

return new SigmaDetections(detections, conditionList);
String timeframe = null;
if (detectionMap.containsKey("timeframe")) {
timeframe = detectionMap.get("timeframe").toString();
}

return new SigmaDetections(detections, conditionList, timeframe);
}

public Map<String, SigmaDetection> getDetections() {
Expand All @@ -69,4 +77,8 @@ public List<String> getCondition() {
public List<SigmaCondition> getParsedCondition() {
return parsedCondition;
}

public String getTimeframe() {
return timeframe;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
import org.opensearch.securityanalytics.model.DetectorTrigger;
import org.opensearch.securityanalytics.model.Rule;
import org.opensearch.securityanalytics.model.Value;
import org.opensearch.securityanalytics.rules.aggregation.AggregationItem;
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend;
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries;
import org.opensearch.securityanalytics.rules.backend.QueryBackend;
Expand Down Expand Up @@ -784,7 +785,8 @@ private IndexMonitorRequest createBucketLevelMonitorRequest(

List<String> indices = detector.getInputs().get(0).getIndices();

AggregationQueries aggregationQueries = queryBackend.convertAggregation(rule.getAggregationItemsFromRule().get(0));
AggregationItem aggItem = rule.getAggregationItemsFromRule().get(0);
AggregationQueries aggregationQueries = queryBackend.convertAggregation(aggItem);

Check warning on line 789 in src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java#L788-L789

Added lines #L788 - L789 were not covered by tests

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.seqNoAndPrimaryTerm(true)
Expand Down Expand Up @@ -814,7 +816,7 @@ private IndexMonitorRequest createBucketLevelMonitorRequest(
? new BoolQueryBuilder()
: QueryBuilders.boolQuery().must(searchSourceBuilder.query());
RangeQueryBuilder timeRangeFilter = QueryBuilders.rangeQuery(TIMESTAMP_FIELD_ALIAS)
.gt("{{period_end}}||-1h")
.gt("{{period_end}}||-" + aggItem.getTimeframe())

Check warning on line 819 in src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java#L819

Added line #L819 was not covered by tests
.lte("{{period_end}}")
.format("epoch_millis");
boolQueryBuilder.must(timeRangeFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ private List<Rule> getQueries(QueryBackend backend, String category, List<String

Rule ruleModel = new Rule(
rule.getId().toString(), NO_VERSION, rule, category,
ruleQueries,
ruleQueries.stream().map(Object::toString).collect(Collectors.toList()),

Check warning on line 298 in src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java#L298

Added line #L298 was not covered by tests
new ArrayList<>(queryFieldNames),
ruleStr
);
Expand Down
25 changes: 25 additions & 0 deletions src/test/java/org/opensearch/securityanalytics/TestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ public static String productIndexMaxAggRule() {

public static String randomProductDocument(){
return "{\n" +
" \"name\": \"laptop\",\n" +
" \"fieldA\": 123,\n" +
" \"mappedB\": 111,\n" +
" \"fieldC\": \"valueC\"\n" +
Expand Down Expand Up @@ -560,6 +561,9 @@ public static String netFlowMappings() {

public static String productIndexMapping(){
return "\"properties\":{\n" +
" \"name\":{\n" +
" \"type\":\"keyword\"\n" +
" },\n" +
" \"fieldA\":{\n" +
" \"type\":\"long\"\n" +
" },\n" +
Expand Down Expand Up @@ -588,13 +592,32 @@ public static String productIndexAvgAggRule(){
" category: test_category\n" +
" product: test_product\n" +
" detection:\n" +
" timeframe: 5m\n" +
" sel:\n" +
" fieldA: 123\n" +
" fieldB: 111\n" +
" fieldC: valueC\n" +
" condition: sel | avg(fieldA) by fieldC > 110";
}

public static String productIndexCountAggRule(){
return " title: Test\n" +
" id: 39f918f3-981b-4e6f-a975-8af7e507ef2b\n" +
" status: test\n" +
" level: critical\n" +
" description: Detects QuarksPwDump clearing access history in hive\n" +
" author: Florian Roth\n" +
" date: 2017/05/15\n" +
" logsource:\n" +
" category: test_category\n" +
" product: test_product\n" +
" detection:\n" +
" timeframe: 5m\n" +
" sel:\n" +
" name: laptop\n" +
" condition: sel | count(*) by name > 2";
}

public static String randomAggregationRule(String aggFunction, String signAndValue) {
String rule = "title: Remote Encrypting File System Abuse\n" +
"id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" +
Expand All @@ -616,6 +639,7 @@ public static String randomAggregationRule(String aggFunction, String signAndVa
" category: application\n" +
" definition: 'Requirements: install and apply the RPC Firewall to all processes with \"audit:true action:block uuid:df1941c5-fe89-4e79-bf10-463657acf44d or c681d488-d850-11d0-8c52-00c04fd90f7e'\n" +
"detection:\n" +
" timeframe: 5m\n" +
" sel:\n" +
" Opcode: Info\n" +
" condition: sel | %s(SeverityValue) by Version %s\n" +
Expand Down Expand Up @@ -646,6 +670,7 @@ public static String randomAggregationRule(String aggFunction, String signAndVa
" category: application\n" +
" definition: 'Requirements: install and apply the RPC Firewall to all processes with \"audit:true action:block uuid:df1941c5-fe89-4e79-bf10-463657acf44d or c681d488-d850-11d0-8c52-00c04fd90f7e'\n" +
"detection:\n" +
" timeframe: 5m\n" +
" sel:\n" +
" Opcode: %s\n" +
" condition: sel | %s(SeverityValue) by Version %s\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.stream.Collectors;
import org.opensearch.securityanalytics.model.DetectorTrigger;

import static org.junit.Assert.assertNotNull;
import static org.opensearch.securityanalytics.TestHelpers.*;
import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE;

Expand Down Expand Up @@ -182,7 +183,7 @@ public void test_searchDetectors_detectorsIndexNotExists() throws IOException {
HttpEntity requestEntity = new StringEntity(request, ContentType.APPLICATION_JSON);
Response searchResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + "_search", Collections.emptyMap(), requestEntity);
Map<String, Object> searchResponseBody = asMap(searchResponse);
Assert.assertNotNull("response is not null", searchResponseBody);
assertNotNull("response is not null", searchResponseBody);
Map<String, Object> searchResponseHits = (Map) searchResponseBody.get("hits");
Map<String, Object> searchResponseTotal = (Map) searchResponseHits.get("total");
Assert.assertEquals(0, searchResponseTotal.get("value"));
Expand Down Expand Up @@ -409,7 +410,7 @@ public void testGettingADetector() throws IOException {
Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + createdId, Collections.emptyMap(), null);
Map<String, Object> responseBody = asMap(getResponse);
Assert.assertEquals(createdId, responseBody.get("_id"));
Assert.assertNotNull(responseBody.get("detector"));
assertNotNull(responseBody.get("detector"));

String detectorTypeInResponse = (String) ((Map<String, Object>)responseBody.get("detector")).get("detector_type");
Assert.assertEquals("Detector type incorrect", randomDetectorType().toLowerCase(Locale.ROOT), detectorTypeInResponse);
Expand Down Expand Up @@ -445,7 +446,7 @@ public void testSearchingDetectors() throws IOException {
HttpEntity requestEntity = new StringEntity(queryJson, ContentType.APPLICATION_JSON);
Response searchResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + "_search", Collections.emptyMap(), requestEntity);
Map<String, Object> searchResponseBody = asMap(searchResponse);
Assert.assertNotNull("response is not null", searchResponseBody);
assertNotNull("response is not null", searchResponseBody);
Map<String, Object> searchResponseHits = (Map) searchResponseBody.get("hits");
Map<String, Object> searchResponseTotal = (Map) searchResponseHits.get("total");
Assert.assertEquals(1, searchResponseTotal.get("value"));
Expand Down Expand Up @@ -613,11 +614,73 @@ public void testCreatingADetectorWithAggregationRules() throws IOException {
HashMap<String, Object> docLevelQuery = (HashMap<String, Object>) ((List<?>) finding.get("queries")).get(0);
String ruleId = docLevelQuery.get("id").toString();
// Verify if the rule id in bucket level finding is the same as rule used for bucket monitor creation
assertEquals(customAvgRuleId, ruleId);
Assert.assertEquals(customAvgRuleId, ruleId);
Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), null);
String getDetectorResponseString = new String(getResponse.getEntity().getContent().readAllBytes());
Assert.assertTrue(getDetectorResponseString.contains(ruleId));
}

public void testAggRuleCount() throws IOException {
String index = createTestIndex(randomIndex(), productIndexMapping());

String customAggRule = createRule(productIndexCountAggRule());

DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(customAggRule)),
getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()));
Detector detector = randomDetectorWithInputs(List.of(input));

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));
Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));

Map<String, Object> responseBody = asMap(createResponse);
String detectorId = responseBody.get("_id").toString();

String request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + detectorId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);

Map<String, Object> detectorAsMap = (Map<String, Object>) hit.getSourceAsMap().get("detector");

String bucketLevelMonitorId = ((List<String>) (detectorAsMap).get("monitor_id")).get(1);
// condition: sel | count(*) by name > 2
indexDoc(index, "1", randomProductDocument());
indexDoc(index, "2", randomProductDocument());
// Verify that 2 documents aren't enough to satisfy trigger condition
Map<String, Object> executeResults = entityAsMap(executeAlertingMonitor(bucketLevelMonitorId, Collections.emptyMap()));
Map<String, Object> trigger = (Map<String, Object>) ((Map<String, Object>)executeResults.get("trigger_results")).entrySet().iterator().next().getValue();
assertEquals(0, ((Map)(trigger.get("agg_result_buckets"))).size() );
// 3 will be fine
indexDoc(index, "3", randomProductDocument());

executeResults = entityAsMap(executeAlertingMonitor(bucketLevelMonitorId, Collections.emptyMap()));
trigger = (Map<String, Object>) ((Map<String, Object>)executeResults.get("trigger_results")).entrySet().iterator().next().getValue();
assertEquals(1, ((Map)(trigger.get("agg_result_buckets"))).size() );
// verify bucket level monitor findings
Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);
assertNotNull(getFindingsBody);
Assert.assertEquals(1, getFindingsBody.get("total_findings"));
List<?> findings = (List<?>) getFindingsBody.get("findings");
Assert.assertEquals(findings.size(), 1);
HashMap<String, Object> finding = (HashMap<String, Object>) findings.get(0);
Assert.assertTrue(finding.containsKey("queries"));
HashMap<String, Object> docLevelQuery = (HashMap<String, Object>) ((List<?>) finding.get("queries")).get(0);
String ruleId = docLevelQuery.get("id").toString();
// Verify if the rule id in bucket level finding is the same as rule used for bucket monitor creation
Assert.assertEquals(customAggRule, ruleId);
Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), null);
String getDetectorResponseString = new String(getResponse.getEntity().getContent().readAllBytes());
Assert.assertTrue(getDetectorResponseString.contains(ruleId));
}

public void testUpdateADetector() throws IOException {
String index = createTestIndex(randomIndex(), windowsIndexMapping());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public void testCreatingAggregationRule() throws SigmaError, IOException {
Rule result = Rule.docParse(xcp, null, null);

Assert.assertEquals(1, result.getAggregationQueries().size());
String expected = "{\"aggQuery\":\"{\\\"result_agg\\\":{\\\"terms\\\":{\\\"field\\\":\\\"_index\\\"}}}\",\"bucketTriggerQuery\":\"{\\\"buckets_path\\\":{\\\"_cnt\\\":\\\"_cnt\\\"},\\\"parent_bucket_path\\\":\\\"result_agg\\\",\\\"script\\\":{\\\"source\\\":\\\"params._cnt > 1.0\\\",\\\"lang\\\":\\\"painless\\\"}}\"}";
String expected = "{\"aggQuery\":\"{\\\"result_agg\\\":{\\\"terms\\\":{\\\"field\\\":\\\"_index\\\"}}}\",\"bucketTriggerQuery\":\"{\\\"buckets_path\\\":{\\\"_cnt\\\":\\\"_count\\\"},\\\"parent_bucket_path\\\":\\\"result_agg\\\",\\\"script\\\":{\\\"source\\\":\\\"params._cnt > 1.0\\\",\\\"lang\\\":\\\"painless\\\"}}\"}";
Assert.assertEquals(expected, result.getAggregationQueries().get(0).getValue());
}

Expand Down
Loading

0 comments on commit add3527

Please sign in to comment.