Skip to content

Commit

Permalink
add field based rules support in correlation engine
Browse files Browse the repository at this point in the history
Signed-off-by: Subhobrata Dey <[email protected]>
  • Loading branch information
sbcd90 committed Nov 30, 2023
1 parent 5d5d6dd commit 27f320f
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private void getValidDocuments(String detectorType, List<String> indices, List<C
@Override
public void onResponse(MultiSearchResponse items) {
MultiSearchResponse.Item[] responses = items.getResponses();
List<Triple<CorrelationRule, SearchHit[], String>> filteredCorrelationRules = new ArrayList<>();
List<FilteredCorrelationRule> filteredCorrelationRules = new ArrayList<>();

int idx = 0;
for (MultiSearchResponse.Item response : responses) {
Expand All @@ -320,17 +320,17 @@ public void onResponse(MultiSearchResponse items) {
}

if (response.getResponse().getHits().getTotalHits().value > 0L) {
filteredCorrelationRules.add(Triple.of(validCorrelationRules.get(idx),
filteredCorrelationRules.add(new FilteredCorrelationRule(validCorrelationRules.get(idx),
response.getResponse().getHits().getHits(), validFields.get(idx)));
}
++idx;
}

Map<String, List<CorrelationQuery>> categoryToQueriesMap = new HashMap<>();
Map<String, Long> categoryToTimeWindowMap = new HashMap<>();
for (Triple<CorrelationRule, SearchHit[], String> rule: filteredCorrelationRules) {
List<CorrelationQuery> queries = rule.getLeft().getCorrelationQueries();
Long timeWindow = rule.getLeft().getCorrTimeWindow();
for (FilteredCorrelationRule rule: filteredCorrelationRules) {
List<CorrelationQuery> queries = rule.correlationRule.getCorrelationQueries();
Long timeWindow = rule.correlationRule.getCorrTimeWindow();

for (CorrelationQuery query: queries) {
List<CorrelationQuery> correlationQueries;
Expand All @@ -348,10 +348,10 @@ public void onResponse(MultiSearchResponse items) {
if (query.getField() == null) {
correlationQueries.add(query);
} else {
SearchHit[] hits = rule.getMiddle();
SearchHit[] hits = rule.filteredDocs;
StringBuilder qb = new StringBuilder(query.getField()).append(":(");
for (int i = 0; i < hits.length; ++i) {
String value = hits[i].field(rule.getRight()).getValue();
String value = hits[i].field(rule.field).getValue();
qb.append(value);
if (i < hits.length-1) {
qb.append(" OR ");
Expand All @@ -368,7 +368,7 @@ public void onResponse(MultiSearchResponse items) {
}
}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap,
filteredCorrelationRules.stream().map(Triple::getLeft).map(CorrelationRule::getId).collect(Collectors.toList()),
filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()),
autoCorrelations
);
}
Expand Down Expand Up @@ -630,15 +630,15 @@ public DocSearchCriteria(List<String> indices, List<String> queries, List<String
}
}

static class ParentJoinCriteria {
String category;
String index;
String parentJoinQuery;
static class FilteredCorrelationRule {
CorrelationRule correlationRule;
SearchHit[] filteredDocs;
String field;

public ParentJoinCriteria(String category, String index, String parentJoinQuery) {
this.category = category;
this.index = index;
this.parentJoinQuery = parentJoinQuery;
public FilteredCorrelationRule(CorrelationRule correlationRule, SearchHit[] filteredDocs, String field) {
this.correlationRule = correlationRule;
this.filteredDocs = filteredDocs;
this.field = field;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@ public class SecurityAnalyticsSettings {
Setting.Property.NodeScope, Setting.Property.Dynamic
);

/**
* Setting which enables auto correlations
*/
public static final Setting<Boolean> ENABLE_AUTO_CORRELATIONS = Setting.boolSetting(
"plugins.security_analytics.enable_auto_correlations",
"plugins.security_analytics.auto_correlations_enabled",
false,
Setting.Property.NodeScope, Setting.Property.Dynamic
);
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/mappings/finding_mapping.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"dynamic": "strict",
"_meta" : {
"schema_version": 3
"schema_version": 4
},
"properties": {
"schema_version": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public static String randomRule() {
"level: high";
}

public static String randomRuleForCorrelations(String value) {
public static String randomCloudtrailRuleForCorrelations(String value) {
return "id: 5f92fff9-82e2-48ab-8fc1-8b133556a551\n" +
"logsource:\n" +
" product: cloudtrail\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.securityanalytics.correlation;

import org.apache.hc.core5.http.HttpStatus;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Assert;
Expand All @@ -13,8 +14,10 @@
import org.opensearch.search.SearchHit;
import org.opensearch.securityanalytics.SecurityAnalyticsPlugin;
import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase;
import org.opensearch.securityanalytics.TestHelpers;
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.CustomLogType;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.model.DetectorInput;
import org.opensearch.securityanalytics.model.DetectorRule;
Expand Down Expand Up @@ -553,14 +556,14 @@ public void testBasicCorrelationEngineWorkflowWithFieldBasedRules() throws IOExc
Response response = client().performRequest(createMappingRequest);
assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode());

String rule1 = randomRuleForCorrelations("CreateUser");
String rule1 = randomCloudtrailRuleForCorrelations("CreateUser");
Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"),
new StringEntity(rule1), new BasicHeader("Content-Type", "application/json"));
Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String createdId1 = responseBody.get("_id").toString();

String rule2 = randomRuleForCorrelations("DeleteUser");
String rule2 = randomCloudtrailRuleForCorrelations("DeleteUser");
createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"),
new StringEntity(rule2), new BasicHeader("Content-Type", "application/json"));
Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse));
Expand Down Expand Up @@ -722,14 +725,14 @@ public void testBasicCorrelationEngineWorkflowWithFieldBasedRulesAndDynamicTimeW
Response response = client().performRequest(createMappingRequest);
assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode());

String rule1 = randomRuleForCorrelations("CreateUser");
String rule1 = randomCloudtrailRuleForCorrelations("CreateUser");
Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"),
new StringEntity(rule1), new BasicHeader("Content-Type", "application/json"));
Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String createdId1 = responseBody.get("_id").toString();

String rule2 = randomRuleForCorrelations("DeleteUser");
String rule2 = randomCloudtrailRuleForCorrelations("DeleteUser");
createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"),
new StringEntity(rule2), new BasicHeader("Content-Type", "application/json"));
Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse));
Expand Down Expand Up @@ -805,6 +808,114 @@ public void testBasicCorrelationEngineWorkflowWithFieldBasedRulesAndDynamicTimeW
Assert.assertEquals(2, count);
}

public void testBasicCorrelationEngineWorkflowWithCustomLogTypes() throws IOException, InterruptedException {
LogIndices indices = new LogIndices();
indices.vpcFlowsIndex = createTestIndex("vpc_flow1", vpcFlowMappings());

String vpcFlowMonitorId = createVpcFlowDetector(indices.vpcFlowsIndex);
String index = createTestIndex(randomIndex(), windowsIndexMapping());

CustomLogType customLogType = TestHelpers.randomCustomLogType(null, null, null, "Custom");
Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.CUSTOM_LOG_TYPE_URI, Collections.emptyMap(), toHttpEntity(customLogType));
Assert.assertEquals("Create custom log type failed", RestStatus.CREATED, restStatus(createResponse));

// Execute CreateMappingsAction to add alias mapping for index
Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
// both req params and req body are supported
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index + "\"," +
" \"rule_topic\":\"" + customLogType.getName() + "\", " +
" \"partial\":true, " +
" \"alias_mappings\":{}" +
"}"
);

Response response = client().performRequest(createMappingRequest);
assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode());

String rule = randomRule();

createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", customLogType.getName()),
new StringEntity(rule), new BasicHeader("Content-Type", "application/json"));
Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse));

Map<String, Object> responseBody = asMap(createResponse);
String createdId = responseBody.get("_id").toString();

DetectorInput input = new DetectorInput("custom log type detector for security analytics", List.of(index), List.of(new DetectorRule(createdId)),
List.of());
Detector detector = randomDetectorWithInputs(List.of(input), customLogType.getName());

createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));
Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));

responseBody = asMap(createResponse);
createdId = responseBody.get("_id").toString();

String detectorTypeInResponse = (String) ((Map<String, Object>)responseBody.get("detector")).get("detector_type");
Assert.assertEquals("Detector type incorrect", customLogType.getName(), detectorTypeInResponse);

String request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + createdId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);

String monitorId = ((List<String>) ((Map<String, Object>) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0);
String ruleId = createNetworkToCustomLogTypeFieldBasedRule(indices, customLogType.getName(), index);

indexDoc(index, "1", randomDoc());
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());
Map<String, Object> executeResults = entityAsMap(executeResponse);
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
Assert.assertEquals(1, noOfSigmaRuleMatches);

indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc());
executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap());
executeResults = entityAsMap(executeResponse);
noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
Assert.assertEquals(1, noOfSigmaRuleMatches);
Thread.sleep(5000);

Map<String, String> params = new HashMap<>();
params.put("detectorType", customLogType.getName());
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);
String finding = ((List<Map<String, Object>>) getFindingsBody.get("findings")).get(0).get("id").toString();

int count = 0;
while (true) {
try {
List<Map<String, Object>> correlatedFindings = searchCorrelatedFindings(finding, customLogType.getName(), 300000L, 10);
if (correlatedFindings.size() == 1) {
Assert.assertTrue(true);

Assert.assertTrue(correlatedFindings.get(0).get("rules") instanceof List);

for (var correlatedFinding: correlatedFindings) {
if (correlatedFinding.get("detector_type").equals("network")) {
Assert.assertEquals(1, ((List<String>) correlatedFinding.get("rules")).size());
Assert.assertTrue(((List<String>) correlatedFinding.get("rules")).contains(ruleId));
}
}
break;
}
} catch (Exception ex) {
// suppress ex
}
++count;
Thread.sleep(5000);
if (count >= 12) {
Assert.assertTrue(false);
break;
}
}
}

private LogIndices createIndices() throws IOException {
LogIndices indices = new LogIndices();
indices.adLdapLogsIndex = createTestIndex("ad_logs", adLdapLogMappings());
Expand All @@ -828,6 +939,19 @@ private String createNetworkToWindowsFieldBasedRule(LogIndices indices) throws I
return entityAsMap(response).get("_id").toString();
}

private String createNetworkToCustomLogTypeFieldBasedRule(LogIndices indices, String customLogTypeName, String customLogTypeIndex) throws IOException {
CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr");
CorrelationQuery query4 = new CorrelationQuery(customLogTypeIndex, null, customLogTypeName, "SourceIp");

CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to custom log type", List.of(query1, query4), 300000L);
Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules");
request.setJsonEntity(toJsonString(rule));
Response response = client().performRequest(request);

Assert.assertEquals(201, response.getStatusLine().getStatusCode());
return entityAsMap(response).get("_id").toString();
}

private String createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOException {
CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "dstaddr:4.5.6.7", "network", null);
CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap", null);
Expand Down

0 comments on commit 27f320f

Please sign in to comment.