diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index cfff7da26..086c26f91 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -91,7 +91,7 @@ public void onSearchDetectorResponse(Detector detector, Finding finding) { onAutoCorrelations(detector, finding, Map.of()); } } catch (IOException ex) { - correlateFindingAction.onFailures(ex); + onFailure(ex); } } @@ -114,102 +114,86 @@ private void generateAutoCorrelations(Detector detector, Finding finding) throws SearchRequest request = new SearchRequest(); request.source(searchSourceBuilder); - logTypeService.searchLogTypes(request, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - MultiSearchRequest mSearchRequest = new MultiSearchRequest(); - SearchHit[] logTypes = response.getHits().getHits(); - List logTypeNames = new ArrayList<>(); - for (SearchHit logType: logTypes) { - String logTypeName = logType.getSourceAsMap().get("name").toString(); - logTypeNames.add(logTypeName); - - RangeQueryBuilder queryBuilder = QueryBuilders.rangeQuery("timestamp") - .gte(findingTimestamp - corrTimeWindow) - .lte(findingTimestamp + corrTimeWindow); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.size(10000); - searchSourceBuilder.fetchField("queries"); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName)); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - mSearchRequest.add(searchRequest); - } + logTypeService.searchLogTypes(request, ActionListener.wrap(response -> { + MultiSearchRequest mSearchRequest = new MultiSearchRequest(); + SearchHit[] logTypes = response.getHits().getHits(); + List logTypeNames = new ArrayList<>(); + for (SearchHit logType: logTypes) { + String logTypeName = logType.getSourceAsMap().get("name").toString(); + logTypeNames.add(logTypeName); + + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery("timestamp") + .gte(findingTimestamp - corrTimeWindow) + .lte(findingTimestamp + corrTimeWindow); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(rangeQueryBuilder); + sourceBuilder.size(10000); + sourceBuilder.fetchField("queries"); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName)); + searchRequest.source(sourceBuilder); + searchRequest.preference(Preference.PRIMARY_FIRST.type()); + mSearchRequest.add(searchRequest); + } + + if (!mSearchRequest.requests().isEmpty()) { + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + + Map> autoCorrelationsMap = new HashMap<>(); + int idx = 0; + for (MultiSearchResponse.Item item : responses) { + if (item.isFailure()) { + log.info(item.getFailureMessage()); + continue; + } + String logTypeName = logTypeNames.get(idx); + + SearchHit[] findings = item.getResponse().getHits().getHits(); + + for (SearchHit foundFinding : findings) { + if (!foundFinding.getId().equals(finding.getId())) { + Set findingTags = new HashSet<>(); + List> queries = (List>) foundFinding.getSourceAsMap().get("queries"); + for (Map query : queries) { + List queryTags = (List) query.get("tags"); + findingTags.addAll(queryTags.stream().filter(queryTag -> queryTag.startsWith("attack.")).collect(Collectors.toList())); + } - if (!mSearchRequest.requests().isEmpty()) { - client.multiSearch(mSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - - Map> autoCorrelationsMap = new HashMap<>(); - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; + boolean canCorrelate = false; + for (String tag: tags) { + if (findingTags.contains(tag)) { + canCorrelate = true; + break; + } } - String logTypeName = logTypeNames.get(idx); - - SearchHit[] findings = response.getResponse().getHits().getHits(); - - for (SearchHit foundFinding : findings) { - if (!foundFinding.getId().equals(finding.getId())) { - Set findingTags = new HashSet<>(); - List> queries = (List>) foundFinding.getSourceAsMap().get("queries"); - for (Map query : queries) { - List queryTags = (List) query.get("tags"); - findingTags.addAll(queryTags.stream().filter(queryTag -> queryTag.startsWith("attack.")).collect(Collectors.toList())); - } - - boolean canCorrelate = false; - for (String tag: tags) { - if (findingTags.contains(tag)) { - canCorrelate = true; - break; - } - } - - Set foundIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, findingTags); - for (String validIntrusionSet: validIntrusionSets) { - if (foundIntrusionSets.contains(validIntrusionSet)) { - canCorrelate = true; - break; - } - } - - if (canCorrelate) { - if (autoCorrelationsMap.containsKey(logTypeName)) { - autoCorrelationsMap.get(logTypeName).add(foundFinding.getId()); - } else { - List autoCorrelatedFindings = new ArrayList<>(); - autoCorrelatedFindings.add(foundFinding.getId()); - autoCorrelationsMap.put(logTypeName, autoCorrelatedFindings); - } - } + + Set foundIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, findingTags); + for (String validIntrusionSet: validIntrusionSets) { + if (foundIntrusionSets.contains(validIntrusionSet)) { + canCorrelate = true; + break; } } - ++idx; - } - onAutoCorrelations(detector, finding, autoCorrelationsMap); - } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + if (canCorrelate) { + if (autoCorrelationsMap.containsKey(logTypeName)) { + autoCorrelationsMap.get(logTypeName).add(foundFinding.getId()); + } else { + List autoCorrelatedFindings = new ArrayList<>(); + autoCorrelatedFindings.add(foundFinding.getId()); + autoCorrelationsMap.put(logTypeName, autoCorrelatedFindings); + } + } + } } - }); - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + ++idx; + } + onAutoCorrelations(detector, finding, autoCorrelationsMap); + }, this::onFailure)); } - }); + }, this::onFailure)); } private void onAutoCorrelations(Detector detector, Finding finding, Map> autoCorrelations) { @@ -231,39 +215,34 @@ private void onAutoCorrelations(Detector detector, Finding finding, Map() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - Iterator hits = response.getHits().iterator(); - List correlationRules = new ArrayList<>(); - while (hits.hasNext()) { - try { - SearchHit hit = hits.next(); - - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() - ); - - CorrelationRule rule = CorrelationRule.parse(xcp, hit.getId(), hit.getVersion()); - correlationRules.add(rule); - } catch (IOException e) { - correlateFindingAction.onFailures(e); - } - } - - getValidDocuments(detectorType, indices, correlationRules, relatedDocIds, autoCorrelations); + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - @Override - public void onFailure(Exception e) { - getValidDocuments(detectorType, indices, List.of(), List.of(), autoCorrelations); + Iterator hits = response.getHits().iterator(); + List correlationRules = new ArrayList<>(); + while (hits.hasNext()) { + try { + SearchHit hit = hits.next(); + + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString()); + + CorrelationRule rule = CorrelationRule.parse(xcp, hit.getId(), hit.getVersion()); + correlationRules.add(rule); + } catch (IOException e) { + onFailure(e); + } } - }); + getValidDocuments(detectorType, indices, correlationRules, relatedDocIds, autoCorrelations); + }, e -> { + log.error("[CORRELATIONS] Exception encountered while searching correlation rule index for finding id {}", + finding.getId(), e); + getValidDocuments(detectorType, indices, List.of(), List.of(), autoCorrelations); + })); } /** @@ -306,84 +285,72 @@ private void getValidDocuments(String detectorType, List indices, List() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - List filteredCorrelationRules = new ArrayList<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + List filteredCorrelationRules = new ArrayList<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + continue; + } - if (response.getResponse().getHits().getTotalHits().value > 0L) { - filteredCorrelationRules.add(new FilteredCorrelationRule(validCorrelationRules.get(idx), - response.getResponse().getHits().getHits(), validFields.get(idx))); - } - ++idx; + if (response.getResponse().getHits().getTotalHits().value > 0L) { + filteredCorrelationRules.add(new FilteredCorrelationRule(validCorrelationRules.get(idx), + response.getResponse().getHits().getHits(), validFields.get(idx))); } + ++idx; + } - Map> categoryToQueriesMap = new HashMap<>(); - Map categoryToTimeWindowMap = new HashMap<>(); - for (FilteredCorrelationRule rule: filteredCorrelationRules) { - List queries = rule.correlationRule.getCorrelationQueries(); - Long timeWindow = rule.correlationRule.getCorrTimeWindow(); - - for (CorrelationQuery query: queries) { - List correlationQueries; - if (categoryToQueriesMap.containsKey(query.getCategory())) { - correlationQueries = categoryToQueriesMap.get(query.getCategory()); - } else { - correlationQueries = new ArrayList<>(); - } - if (categoryToTimeWindowMap.containsKey(query.getCategory())) { - categoryToTimeWindowMap.put(query.getCategory(), Math.max(timeWindow, categoryToTimeWindowMap.get(query.getCategory()))); - } else { - categoryToTimeWindowMap.put(query.getCategory(), timeWindow); - } + Map> categoryToQueriesMap = new HashMap<>(); + Map categoryToTimeWindowMap = new HashMap<>(); + for (FilteredCorrelationRule rule: filteredCorrelationRules) { + List queries = rule.correlationRule.getCorrelationQueries(); + Long timeWindow = rule.correlationRule.getCorrTimeWindow(); - if (query.getField() == null) { - correlationQueries.add(query); - } else { - SearchHit[] hits = rule.filteredDocs; - StringBuilder qb = new StringBuilder(query.getField()).append(":("); - for (int i = 0; i < hits.length; ++i) { - String value = hits[i].field(rule.field).getValue(); - qb.append(value); - if (i < hits.length-1) { - qb.append(" OR "); - } else { - qb.append(")"); - } - } - if (query.getQuery() != null) { - qb.append(" AND ").append(query.getQuery()); + for (CorrelationQuery query: queries) { + List correlationQueries; + if (categoryToQueriesMap.containsKey(query.getCategory())) { + correlationQueries = categoryToQueriesMap.get(query.getCategory()); + } else { + correlationQueries = new ArrayList<>(); + } + if (categoryToTimeWindowMap.containsKey(query.getCategory())) { + categoryToTimeWindowMap.put(query.getCategory(), Math.max(timeWindow, categoryToTimeWindowMap.get(query.getCategory()))); + } else { + categoryToTimeWindowMap.put(query.getCategory(), timeWindow); + } + + if (query.getField() == null) { + correlationQueries.add(query); + } else { + SearchHit[] hits = rule.filteredDocs; + StringBuilder qb = new StringBuilder(query.getField()).append(":("); + for (int i = 0; i < hits.length; ++i) { + String value = hits[i].field(rule.field).getValue(); + qb.append(value); + if (i < hits.length-1) { + qb.append(" OR "); + } else { + qb.append(")"); } - correlationQueries.add(new CorrelationQuery(query.getIndex(), qb.toString(), query.getCategory(), null)); } - categoryToQueriesMap.put(query.getCategory(), correlationQueries); + if (query.getQuery() != null) { + qb.append(" AND ").append(query.getQuery()); + } + correlationQueries.add(new CorrelationQuery(query.getIndex(), qb.toString(), query.getCategory(), null)); } + categoryToQueriesMap.put(query.getCategory(), correlationQueries); } - searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap, - filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), - autoCorrelations - ); - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); } - }); + searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap, + filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), + autoCorrelations + ); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of()); - } + getTimestampFeature(detectorType, List.of(), autoCorrelations); } } @@ -415,50 +382,38 @@ private void searchFindingsByTimestamp(String detectorType, Map() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - Map relatedDocsMap = new HashMap<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } - - List relatedDocIds = new ArrayList<>(); - SearchHit[] hits = response.getResponse().getHits().getHits(); - for (SearchHit hit : hits) { - relatedDocIds.addAll(hit.getFields().get("correlated_doc_ids").getValues().stream() - .map(Object::toString).collect(Collectors.toList())); - } + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + Map relatedDocsMap = new HashMap<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + continue; + } - List correlationQueries = categoryToQueriesPairs.get(idx).getValue(); - List indices = correlationQueries.stream().map(CorrelationQuery::getIndex).collect(Collectors.toList()); - List queries = correlationQueries.stream().map(CorrelationQuery::getQuery).collect(Collectors.toList()); - relatedDocsMap.put(categoryToQueriesPairs.get(idx).getKey(), - new DocSearchCriteria( - indices, - queries, - relatedDocIds)); - ++idx; + List relatedDocIds = new ArrayList<>(); + SearchHit[] hits = response.getResponse().getHits().getHits(); + for (SearchHit hit : hits) { + relatedDocIds.addAll(hit.getFields().get("correlated_doc_ids").getValues().stream() + .map(Object::toString).collect(Collectors.toList())); } - searchDocsWithFilterKeys(detectorType, relatedDocsMap, categoryToTimeWindowMap, correlationRules, autoCorrelations); - } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + List correlationQueries = categoryToQueriesPairs.get(idx).getValue(); + List indices = correlationQueries.stream().map(CorrelationQuery::getIndex).collect(Collectors.toList()); + List queries = correlationQueries.stream().map(CorrelationQuery::getQuery).collect(Collectors.toList()); + relatedDocsMap.put(categoryToQueriesPairs.get(idx).getKey(), + new DocSearchCriteria( + indices, + queries, + relatedDocIds)); + ++idx; } - }); + searchDocsWithFilterKeys(detectorType, relatedDocsMap, categoryToTimeWindowMap, correlationRules, autoCorrelations); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); - } + getTimestampFeature(detectorType, correlationRules, autoCorrelations); } } @@ -492,42 +447,30 @@ private void searchDocsWithFilterKeys(String detectorType, Map() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - Map> filteredRelatedDocIds = new HashMap<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } + client.multiSearch(mSearchRequest, ActionListener.wrap( items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + Map> filteredRelatedDocIds = new HashMap<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + continue; + } - SearchHit[] hits = response.getResponse().getHits().getHits(); - List docIds = new ArrayList<>(); + SearchHit[] hits = response.getResponse().getHits().getHits(); + List docIds = new ArrayList<>(); - for (SearchHit hit : hits) { - docIds.add(hit.getId()); - } - filteredRelatedDocIds.put(categories.get(idx), docIds); - ++idx; + for (SearchHit hit : hits) { + docIds.add(hit.getId()); } - getCorrelatedFindings(detectorType, filteredRelatedDocIds, categoryToTimeWindowMap, correlationRules, autoCorrelations); + filteredRelatedDocIds.put(categories.get(idx), docIds); + ++idx; } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); + getCorrelatedFindings(detectorType, filteredRelatedDocIds, categoryToTimeWindowMap, correlationRules, autoCorrelations); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); - } + getTimestampFeature(detectorType, correlationRules, autoCorrelations); } } @@ -565,59 +508,59 @@ private void getCorrelatedFindings(String detectorType, Map } if (!mSearchRequest.requests().isEmpty()) { - client.multiSearch(mSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - Map> correlatedFindings = new HashMap<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - ++idx; - continue; - } - - SearchHit[] hits = response.getResponse().getHits().getHits(); - List findings = new ArrayList<>(); + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + Map> correlatedFindings = new HashMap<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + ++idx; + continue; + } - for (SearchHit hit : hits) { - findings.add(hit.getId()); - } + SearchHit[] hits = response.getResponse().getHits().getHits(); + List findings = new ArrayList<>(); - if (!findings.isEmpty()) { - correlatedFindings.put(categories.get(idx), findings); - } - ++idx; + for (SearchHit hit : hits) { + findings.add(hit.getId()); } - for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { - if (correlatedFindings.containsKey(autoCorrelation.getKey())) { - Set alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey())); - alreadyCorrelatedFindings.addAll(autoCorrelation.getValue()); - correlatedFindings.put(autoCorrelation.getKey(), new ArrayList<>(alreadyCorrelatedFindings)); - } else { - correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); - } + if (!findings.isEmpty()) { + correlatedFindings.put(categories.get(idx), findings); } - correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + ++idx; } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { + if (correlatedFindings.containsKey(autoCorrelation.getKey())) { + Set alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey())); + alreadyCorrelatedFindings.addAll(autoCorrelation.getValue()); + correlatedFindings.put(autoCorrelation.getKey(), new ArrayList<>(alreadyCorrelatedFindings)); + } else { + correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); + } } - }); + correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); - } + getTimestampFeature(detectorType, correlationRules, autoCorrelations); } } + private void getTimestampFeature(String detectorType, List correlationRules, Map> autoCorrelations) { + if (!autoCorrelations.isEmpty()) { + correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); + } else { + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); + } + } + + private void onFailure(Exception e) { + correlateFindingAction.onFailures(e); + } + static class DocSearchCriteria { List indices; List queries; diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java index cab8798f2..c1232b8d2 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java @@ -11,13 +11,10 @@ import org.opensearch.cluster.routing.Preference; import org.opensearch.core.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; @@ -32,7 +29,6 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryBuilder; import org.opensearch.securityanalytics.model.CustomLogType; -import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction; import org.opensearch.securityanalytics.util.CorrelationIndices; @@ -62,213 +58,203 @@ public VectorEmbeddingsEngine(Client client, TimeValue indexTimeout, long corrTi } public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List correlatedFindings, float timestampFeature, List correlationRules, Map logTypes) { + SearchRequest searchRequest = getSearchMetadataIndexRequest(detectorType, finding, logTypes); Map tags = logTypes.get(detectorType).getTags(); String correlationId = tags.get("correlation_id").toString(); long findingTimestamp = finding.getTimestamp().toEpochMilli(); - MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( - "root", true - ); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - if (response.getHits().getHits().length == 0) { - correlateFindingAction.onFailures( - new ResourceNotFoundException("Failed to find hits in metadata index for finding id {}", finding.getId())); - } - - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - long counter = Long.parseLong(hitSource.get("counter").toString()); - - MultiSearchRequest mSearchRequest = new MultiSearchRequest(); - - for (String correlatedFinding: correlatedFindings) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.matchQuery( - "finding1", correlatedFinding - )).must(QueryBuilders.matchQuery( - "finding2", "" - ))/*.must(QueryBuilders.matchQuery( - "counter", counter - ))*/; - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - mSearchRequest.add(searchRequest); - } - - client.multiSearch(mSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - long prevCounter = -1L; - long totalNeighbors = 0L; - for (MultiSearchResponse.Item response: responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - long totalHits = response.getResponse().getHits().getHits().length; - totalNeighbors += totalHits; + if (response.getHits().getHits().length == 0) { + onFailure( + new ResourceNotFoundException("Failed to find hits in metadata index for finding id {}", finding.getId())); + } - for (int idx = 0; idx < totalHits; ++idx) { - SearchHit hit = response.getResponse().getHits().getHits()[idx]; - Map hitSource = hit.getSourceAsMap(); - long neighborCounter = Long.parseLong(hitSource.get("counter").toString()); - String correlatedFinding = hitSource.get("finding1").toString(); + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); + long counter = Long.parseLong(hitSource.get("counter").toString()); + + MultiSearchRequest mSearchRequest = new MultiSearchRequest(); + + for (String correlatedFinding: correlatedFindings) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.matchQuery( + "finding1", correlatedFinding + )).must(QueryBuilders.matchQuery( + "finding2", "" + ))/*.must(QueryBuilders.matchQuery( + "counter", counter + ))*/; + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(10000); + SearchRequest request = new SearchRequest(); + request.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); + request.source(searchSourceBuilder); + request.preference(Preference.PRIMARY_FIRST.type()); + + mSearchRequest.add(request); + } - try { - float[] corrVector = new float[3]; - if (counter != prevCounter) { - for (int i = 0; i < 2; ++i) { - corrVector[i] = ((float) counter) - 50.0f; - } + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + BulkRequest bulkRequest = new BulkRequest(); + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + long prevCounter = -1L; + long totalNeighbors = 0L; + for (MultiSearchResponse.Item item: responses) { + if (item.isFailure()) { + log.info(item.getFailureMessage()); + continue; + } - corrVector[0] = (float) counter; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", counter); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", correlationId); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout); - bulkRequest.add(indexRequest); - } + long totalHits = item.getResponse().getHits().getHits().length; + totalNeighbors += totalHits; - corrVector = new float[3]; - for (int i = 0; i < 2; ++i) { - corrVector[i] = ((float) counter) - 50.0f; - } - corrVector[0] = (2.0f * ((float) counter) - 50.0f) / 2.0f; - corrVector[1] = (2.0f * ((float) neighborCounter) - 50.0f) / 2.0f; - corrVector[2] = timestampFeature; + for (int idx = 0; idx < totalHits; ++idx) { + SearchHit hit = item.getResponse().getHits().getHits()[idx]; + Map sourceAsMap = hit.getSourceAsMap(); + long neighborCounter = Long.parseLong(sourceAsMap.get("counter").toString()); + String correlatedFinding = sourceAsMap.get("finding1").toString(); - XContentBuilder corrBuilder = XContentFactory.jsonBuilder().startObject(); - corrBuilder.field("root", false); - corrBuilder.field("counter", (long) ((2.0f * ((float) counter) - 50.0f) / 2.0f)); - corrBuilder.field("finding1", finding.getId()); - corrBuilder.field("finding2", correlatedFinding); - corrBuilder.field("logType", String.format(Locale.ROOT, "%s-%s", detectorType, logType)); - corrBuilder.field("timestamp", findingTimestamp); - corrBuilder.field("corr_vector", corrVector); - corrBuilder.field("recordType", "finding-finding"); - corrBuilder.field("scoreTimestamp", 0L); - corrBuilder.field("corrRules", correlationRules); - corrBuilder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(corrBuilder) - .timeout(indexTimeout); - bulkRequest.add(indexRequest); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); + try { + float[] corrVector = new float[3]; + if (counter != prevCounter) { + for (int i = 0; i < 2; ++i) { + corrVector[i] = ((float) counter) - 50.0f; } - prevCounter = counter; - } - } - if (totalNeighbors > 0L) { - client.bulk(bulkRequest, new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (response.hasFailures()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Correlation of finding failed", RestStatus.INTERNAL_SERVER_ERROR)); - } - correlateFindingAction.onOperation(); - } + corrVector[0] = (float) counter; + corrVector[2] = timestampFeature; + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", false); + builder.field("counter", counter); + builder.field("finding1", finding.getId()); + builder.field("finding2", ""); + builder.field("logType", correlationId); + builder.field("timestamp", findingTimestamp); + builder.field("corr_vector", corrVector); + builder.field("recordType", "finding"); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) + .source(builder) + .timeout(indexTimeout); + bulkRequest.add(indexRequest); + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } else { - insertOrphanFindings(detectorType, finding, timestampFeature, logTypes); + corrVector = new float[3]; + for (int i = 0; i < 2; ++i) { + corrVector[i] = ((float) counter) - 50.0f; + } + corrVector[0] = (2.0f * ((float) counter) - 50.0f) / 2.0f; + corrVector[1] = (2.0f * ((float) neighborCounter) - 50.0f) / 2.0f; + corrVector[2] = timestampFeature; + + XContentBuilder corrBuilder = XContentFactory.jsonBuilder().startObject(); + corrBuilder.field("root", false); + corrBuilder.field("counter", (long) ((2.0f * ((float) counter) - 50.0f) / 2.0f)); + corrBuilder.field("finding1", finding.getId()); + corrBuilder.field("finding2", correlatedFinding); + corrBuilder.field("logType", String.format(Locale.ROOT, "%s-%s", detectorType, logType)); + corrBuilder.field("timestamp", findingTimestamp); + corrBuilder.field("corr_vector", corrVector); + corrBuilder.field("recordType", "finding-finding"); + corrBuilder.field("scoreTimestamp", 0L); + corrBuilder.field("corrRules", correlationRules); + corrBuilder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) + .source(corrBuilder) + .timeout(indexTimeout); + bulkRequest.add(indexRequest); + } catch (IOException ex) { + onFailure(ex); } + prevCounter = counter; } + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); + if (totalNeighbors > 0L) { + client.bulk(bulkRequest, ActionListener.wrap( bulkResponse -> { + if (bulkResponse.hasFailures()) { + onFailure(new OpenSearchStatusException("Correlation of finding failed", RestStatus.INTERNAL_SERVER_ERROR)); + } + correlateFindingAction.onOperation(); + }, this::onFailure)); + } else { + insertOrphanFindings(detectorType, finding, timestampFeature, logTypes); + } + }, this::onFailure)); + }, this::onFailure)); } public void insertOrphanFindings(String detectorType, Finding finding, float timestampFeature, Map logTypes) { - if (logTypes.get(detectorType) == null) { - log.error("LogTypes Index is missing the detector type {}", detectorType); - correlateFindingAction.onFailures(new OpenSearchStatusException("LogTypes Index is missing the detector type", RestStatus.INTERNAL_SERVER_ERROR)); - } - + SearchRequest searchRequest = getSearchMetadataIndexRequest(detectorType, finding, logTypes); Map tags = logTypes.get(detectorType).getTags(); String correlationId = tags.get("correlation_id").toString(); - long findingTimestamp = finding.getTimestamp().toEpochMilli(); - MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( - "root", true - ); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - try { - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - String id = response.getHits().getHits()[0].getId(); - long counter = Long.parseLong(hitSource.get("counter").toString()); - long timestamp = Long.parseLong(hitSource.get("timestamp").toString()); - if (counter == 0L) { + try { + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); + String id = response.getHits().getHits()[0].getId(); + long counter = Long.parseLong(hitSource.get("counter").toString()); + long timestamp = Long.parseLong(hitSource.get("timestamp").toString()); + if (counter == 0L) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", true); + builder.field("counter", 50L); + builder.field("finding1", ""); + builder.field("finding2", ""); + builder.field("logType", ""); + builder.field("timestamp", findingTimestamp); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + .id(id) + .source(builder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse.status().equals(RestStatus.OK)) { + try { + float[] corrVector = new float[3]; + corrVector[0] = 50.0f; + corrVector[2] = timestampFeature; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + xContentBuilder.field("root", false); + xContentBuilder.field("counter", 50L); + xContentBuilder.field("finding1", finding.getId()); + xContentBuilder.field("finding2", ""); + xContentBuilder.field("logType", correlationId); + xContentBuilder.field("timestamp", findingTimestamp); + xContentBuilder.field("corr_vector", corrVector); + xContentBuilder.field("recordType", "finding"); + xContentBuilder.field("scoreTimestamp", 0L); + xContentBuilder.endObject(); + + indexCorrelatedFindings(xContentBuilder); + } catch (IOException ex) { + onFailure(ex); + } + } + }, this::onFailure)); + } else { + if (findingTimestamp - timestamp > corrTimeWindow) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder.field("root", true); builder.field("counter", 50L); @@ -285,308 +271,190 @@ public void onResponse(SearchResponse response) { .timeout(indexTimeout) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.OK)) { - try { - float[] corrVector = new float[3]; - corrVector[0] = 50.0f; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", 50L); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", correlationId); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse.status().equals(RestStatus.OK)) { + correlateFindingAction.onOperation(); + try { + float[] corrVector = new float[3]; + corrVector[0] = 50.0f; + corrVector[2] = timestampFeature; - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } + XContentBuilder contentBuilder = XContentFactory.jsonBuilder().startObject(); + contentBuilder.field("root", false); + contentBuilder.field("counter", 50L); + contentBuilder.field("finding1", finding.getId()); + contentBuilder.field("finding2", ""); + contentBuilder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); + contentBuilder.field("timestamp", findingTimestamp); + contentBuilder.field("corr_vector", corrVector); + contentBuilder.field("recordType", "finding"); + contentBuilder.field("scoreTimestamp", 0L); + contentBuilder.endObject(); + + indexCorrelatedFindings(contentBuilder); + } catch (IOException ex) { + onFailure(ex); } } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); + }, this::onFailure)); } else { - if (findingTimestamp - timestamp > corrTimeWindow) { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", true); - builder.field("counter", 50L); - builder.field("finding1", ""); - builder.field("finding2", ""); - builder.field("logType", ""); - builder.field("timestamp", findingTimestamp); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.OK)) { - correlateFindingAction.onOperation(); - try { - float[] corrVector = new float[3]; - corrVector[0] = 50.0f; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", 50L); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } + float[] query = new float[3]; + for (int i = 0; i < 2; ++i) { + query[i] = (2.0f * ((float) counter) - 50.0f) / 2.0f; + } + query[2] = timestampFeature; + + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder("corr_vector", query, 100, QueryBuilders.boolQuery() + .mustNot(QueryBuilders.matchQuery( + "finding1", "" + )).mustNot(QueryBuilders.matchQuery( + "finding2", "" + )).filter(QueryBuilders.rangeQuery("timestamp") + .gte(findingTimestamp - corrTimeWindow) + .lte(findingTimestamp + corrTimeWindow))); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(correlationQueryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(1); + SearchRequest request = new SearchRequest(); + request.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); + request.source(searchSourceBuilder); + request.preference(Preference.PRIMARY_FIRST.type()); + + client.search(request, ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } - } - } + long totalHits = searchResponse.getHits().getTotalHits().value; + SearchHit hit = totalHits > 0? searchResponse.getHits().getHits()[0]: null; + long existCounter = 0L; - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } else { - float[] query = new float[3]; - for (int i = 0; i < 2; ++i) { - query[i] = (2.0f * ((float) counter) - 50.0f) / 2.0f; + if (hit != null) { + Map sourceAsMap = searchResponse.getHits().getHits()[0].getSourceAsMap(); + existCounter = Long.parseLong(sourceAsMap.get("counter").toString()); } - query[2] = timestampFeature; - - CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder("corr_vector", query, 100, QueryBuilders.boolQuery() - .mustNot(QueryBuilders.matchQuery( - "finding1", "" - )).mustNot(QueryBuilders.matchQuery( - "finding2", "" - )).filter(QueryBuilders.rangeQuery("timestamp") - .gte(findingTimestamp - corrTimeWindow) - .lte(findingTimestamp + corrTimeWindow))); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(correlationQueryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - long totalHits = response.getHits().getTotalHits().value; - SearchHit hit = totalHits > 0? response.getHits().getHits()[0]: null; - long existCounter = 0L; - - if (hit != null) { - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - existCounter = Long.parseLong(hitSource.get("counter").toString()); + if (totalHits == 0L || existCounter != ((long) (2.0f * ((float) counter) - 50.0f) / 2.0f)) { + try { + float[] corrVector = new float[3]; + for (int i = 0; i < 2; ++i) { + corrVector[i] = ((float) counter) - 50.0f; } + corrVector[0] = (float) counter; + corrVector[2] = timestampFeature; - if (totalHits == 0L || existCounter != ((long) (2.0f * ((float) counter) - 50.0f) / 2.0f)) { - try { - float[] corrVector = new float[3]; - for (int i = 0; i < 2; ++i) { - corrVector[i] = ((float) counter) - 50.0f; - } - corrVector[0] = (float) counter; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", counter); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } - } else { - try { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", true); - builder.field("counter", counter + 50L); - builder.field("finding1", ""); - builder.field("finding2", ""); - builder.field("logType", ""); - builder.field("timestamp", findingTimestamp); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.OK)) { - try { - float[] corrVector = new float[3]; - for (int i = 0; i < 2; ++i) { - corrVector[i] = (float) counter; - } - corrVector[0] = counter + 50.0f; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", counter + 50L); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", false); + builder.field("counter", counter); + builder.field("finding1", finding.getId()); + builder.field("finding2", ""); + builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); + builder.field("timestamp", findingTimestamp); + builder.field("corr_vector", corrVector); + builder.field("recordType", "finding"); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + indexCorrelatedFindings(builder); + } catch (IOException ex) { + onFailure(ex); + } + } else { + try { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", true); + builder.field("counter", counter + 50L); + builder.field("finding1", ""); + builder.field("finding2", ""); + builder.field("logType", ""); + builder.field("timestamp", findingTimestamp); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + .id(id) + .source(builder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse.status().equals(RestStatus.OK)) { + try { + float[] corrVector = new float[3]; + for (int i = 0; i < 2; ++i) { + corrVector[i] = (float) counter; } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); + corrVector[0] = counter + 50.0f; + corrVector[2] = timestampFeature; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + xContentBuilder.field("root", false); + xContentBuilder.field("counter", counter + 50L); + xContentBuilder.field("finding1", finding.getId()); + xContentBuilder.field("finding2", ""); + xContentBuilder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); + xContentBuilder.field("timestamp", findingTimestamp); + xContentBuilder.field("corr_vector", corrVector); + xContentBuilder.field("recordType", "finding"); + xContentBuilder.field("scoreTimestamp", 0L); + xContentBuilder.endObject(); + + indexCorrelatedFindings(xContentBuilder); + } catch (IOException ex) { + onFailure(ex); + } } - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + }, this::onFailure)); + } catch (IOException ex) { + onFailure(ex); } - }); - } + } + }, this::onFailure)); } - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); } + } catch (IOException ex) { + onFailure(ex); } + }, this::onFailure)); + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + private void indexCorrelatedFindings(XContentBuilder builder) { + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) + .source(builder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(response -> { + if (response.status().equals(RestStatus.CREATED)) { + correlateFindingAction.onOperation(); + } else { + onFailure(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); } - }); + }, this::onFailure)); + } + + private SearchRequest getSearchMetadataIndexRequest(String detectorType, Finding finding, Map logTypes) { + if (logTypes.get(detectorType) == null) { + throw new OpenSearchStatusException("LogTypes Index is missing the detector type", RestStatus.INTERNAL_SERVER_ERROR); + } + + Map tags = logTypes.get(detectorType).getTags(); + MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( + "root", true + ); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(1); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); + searchRequest.source(searchSourceBuilder); + searchRequest.preference(Preference.PRIMARY_FIRST.type()); + return searchRequest; + } + + private void onFailure(Exception e) { + correlateFindingAction.onFailures(e); } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 63c31f99b..9288ca7f8 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -135,67 +135,43 @@ protected void doExecute(Task task, ActionRequest request, ActionListener() { - @Override - public void onResponse(CreateIndexResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationIndexUpdated(); - if (IndexUtils.correlationIndexUpdated) { - IndexUtils.lastUpdatedCorrelationHistoryIndex = IndexUtils.getIndexNameWithAlias( - clusterService.state(), - CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX - ); - } + AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); + this.correlationIndices.initCorrelationIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + IndexUtils.correlationIndexUpdated(); + if (IndexUtils.correlationIndexUpdated) { + IndexUtils.lastUpdatedCorrelationHistoryIndex = IndexUtils.getIndexNameWithAlias( + clusterService.state(), + CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX + ); + } - if (!correlationIndices.correlationMetadataIndexExists()) { - try { - correlationIndices.initCorrelationMetadataIndex(new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationMetadataIndexUpdated(); - - correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (response.hasFailures()) { - log.error(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - - AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); - correlateFindingAction.start(); - } - @Override - public void onFailure(Exception e) { - log.error(e); - } - }); - } else { - log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); - } - } + if (!correlationIndices.correlationMetadataIndexExists()) { + try { + correlationIndices.initCorrelationMetadataIndex(ActionListener.wrap(createIndexResponse -> { + if (createIndexResponse.isAcknowledged()) { + IndexUtils.correlationMetadataIndexUpdated(); - @Override - public void onFailure(Exception e) { - - } - }); - } catch (Exception ex) { - onFailure(ex); - } + correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, ActionListener.wrap(bulkResponse -> { + if (bulkResponse.hasFailures()) { + log.error(new OpenSearchStatusException(createIndexResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + } + correlateFindingAction.start(); + }, correlateFindingAction::onFailures)); + } else { + log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, correlateFindingAction::onFailures)); + } catch (Exception ex) { + correlateFindingAction.onFailures(ex); } - } else { - log.error(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } + } else { + log.error(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } - - @Override - public void onFailure(Exception e) { - log.error(e); - } - }); + }, correlateFindingAction::onFailures)); } catch (IOException ex) { log.error(ex); } @@ -254,39 +230,31 @@ void start() { searchRequest.source(searchSourceBuilder); searchRequest.preference(Preference.PRIMARY_FIRST.type()); - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - SearchHits hits = response.getHits(); - // Detectors Index hits count could be more even if we fetch one - if (hits.getTotalHits().value >= 1 && hits.getHits().length > 0) { - try { - SearchHit hit = hits.getAt(0); - - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() - ); - Detector detector = Detector.docParse(xcp, hit.getId(), hit.getVersion()); - joinEngine.onSearchDetectorResponse(detector, finding); - } catch (IOException e) { - log.error("IOException for request {}", searchRequest.toString(), e); - onFailures(e); - } - } else { - onFailures(new OpenSearchStatusException("detector not found given monitor id", RestStatus.INTERNAL_SERVER_ERROR)); - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - @Override - public void onFailure(Exception e) { - onFailures(e); + SearchHits hits = response.getHits(); + // Detectors Index hits count could be more even if we fetch one + if (hits.getTotalHits().value >= 1 && hits.getHits().length > 0) { + try { + SearchHit hit = hits.getAt(0); + + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() + ); + Detector detector = Detector.docParse(xcp, hit.getId(), hit.getVersion()); + joinEngine.onSearchDetectorResponse(detector, finding); + } catch (IOException e) { + log.error("IOException for request {}", searchRequest.toString(), e); + onFailures(e); + } + } else { + onFailures(new OpenSearchStatusException("detector not found given monitor id", RestStatus.INTERNAL_SERVER_ERROR)); } - }); + }, this::onFailures)); } else { onFailures(new SecurityAnalyticsException(String.format(Locale.getDefault(), "Detector index %s doesnt exist", Detector.DETECTORS_INDEX), RestStatus.INTERNAL_SERVER_ERROR, new RuntimeException())); } @@ -298,22 +266,14 @@ public void initCorrelationIndex(String detectorType, Map> IndexUtils.updateIndexMapping( CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX, CorrelationIndices.correlationMappings(), clusterService.state(), client.admin().indices(), - new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationIndexUpdated(); - getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); - } else { - onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); - } + ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + IndexUtils.correlationIndexUpdated(); + getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); + } else { + onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }, + }, this::onFailures), true ); } else { @@ -327,332 +287,205 @@ public void onFailure(Exception e) { public void getTimestampFeature(String detectorType, Map> correlatedFindings, Finding orphanFinding, List correlationRules) { if (!correlationIndices.correlationMetadataIndexExists()) { try { - correlationIndices.initCorrelationMetadataIndex(new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationMetadataIndexUpdated(); - - correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (response.hasFailures()) { - log.error(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } + correlationIndices.initCorrelationMetadataIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + IndexUtils.correlationMetadataIndexUpdated(); - long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - String id = response.getHits().getHits()[0].getId(); - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - - if (findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL > scoreTimestamp) { - try { - XContentBuilder scoreBuilder = XContentFactory.jsonBuilder().startObject(); - scoreBuilder.field("scoreTimestamp", findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL); - scoreBuilder.field("root", false); - scoreBuilder.endObject(); - - IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(scoreBuilder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(scoreIndexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } - - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } catch (Exception ex) { - onFailures(ex); - } - } else { - float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } - - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), timestampFeature, logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - timestampFeature, correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature, logTypes); - } - } + correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, ActionListener.wrap(bulkResponse -> { + if (bulkResponse.hasFailures()) { + log.error(new OpenSearchStatusException(bulkResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } - } + long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); + SearchRequest searchMetadataIndexRequest = getSearchMetadataIndexRequest(); - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } + client.search(searchMetadataIndexRequest, ActionListener.wrap(searchMetadataResponse -> { + String id = searchMetadataResponse.getHits().getHits()[0].getId(); + Map hitSource = searchMetadataResponse.getHits().getHits()[0].getSourceAsMap(); + long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - @Override - public void onFailure(Exception e) { - log.error(e); - } - }); - } else { - log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); - } - } + long newScoreTimestamp = findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL; + if (newScoreTimestamp > scoreTimestamp) { + try { + IndexRequest scoreIndexRequest = getCorrelationMetadataIndexRequest(id, newScoreTimestamp); + + client.index(scoreIndexRequest, ActionListener.wrap(indexResponse -> { + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } + + SearchHit[] hits = searchResponse.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), + new CustomLogType(sourceMap)); + } - @Override - public void onFailure(Exception e) { + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + }, this::onFailures)); + }, this::onFailures)); + } catch (Exception ex) { + onFailures(ex); + } + } else { + float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + insertFindings(timestampFeature, searchRequest, correlatedFindings, detectorType, correlationRules, orphanFinding); + } + }, this::onFailures)); + }, this::onFailures)); + } else { + log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); } - }); + }, this::onFailures)); } catch (Exception ex) { onFailures(ex); } } else { long findingTimestamp = this.request.getFinding().getTimestamp().toEpochMilli(); - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.getHits().getHits().length == 0) { - onFailures(new ResourceNotFoundException( - "Failed to find hits in metadata index for finding id {}", request.getFinding().getId())); - } + SearchRequest searchMetadataIndexRequest = getSearchMetadataIndexRequest(); - String id = response.getHits().getHits()[0].getId(); - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - - if (findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL > scoreTimestamp) { - try { - XContentBuilder scoreBuilder = XContentFactory.jsonBuilder().startObject(); - scoreBuilder.field("scoreTimestamp", findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL); - scoreBuilder.field("root", false); - scoreBuilder.endObject(); - - IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(scoreBuilder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(scoreIndexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } - - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - } + client.search(searchMetadataIndexRequest, ActionListener.wrap(response -> { + if (response.getHits().getHits().length == 0) { + onFailures(new ResourceNotFoundException( + "Failed to find hits in metadata index for finding id {}", request.getFinding().getId())); + } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } + String id = response.getHits().getHits()[0].getId(); + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); + long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } catch (Exception ex) { - onFailures(ex); - } - } else { - float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + long newScoreTimestamp = findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL; + if (newScoreTimestamp > scoreTimestamp) { + try { + IndexRequest scoreIndexRequest = getCorrelationMetadataIndexRequest(id, newScoreTimestamp); - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } + client.index(scoreIndexRequest, ActionListener.wrap(indexResponse -> { + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), timestampFeature, logTypes); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - timestampFeature, correlationRules, logTypes); + + SearchHit[] hits = searchResponse.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), + new CustomLogType(sourceMap)); } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature, logTypes); - } - } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + }, this::onFailures)); + }, this::onFailures)); + } catch (Exception ex) { + onFailures(ex); } - } + } else { + float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - @Override - public void onFailure(Exception e) { - onFailures(e); + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + insertFindings(timestampFeature, searchRequest, correlatedFindings, detectorType, correlationRules, orphanFinding); } - }); + }, this::onFailures)); } } + private SearchRequest getSearchLogTypeIndexRequest() { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.existsQuery("source")); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(10000); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); + searchRequest.source(searchSourceBuilder); + return searchRequest; + } + + private IndexRequest getCorrelationMetadataIndexRequest(String id, long newScoreTimestamp) throws IOException { + XContentBuilder scoreBuilder = XContentFactory.jsonBuilder().startObject(); + scoreBuilder.field("scoreTimestamp", newScoreTimestamp); + scoreBuilder.field("root", false); + scoreBuilder.endObject(); + + IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + .id(id) + .source(scoreBuilder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + return scoreIndexRequest; + } + private void insertFindings(float timestampFeature, SearchRequest searchRequest, Map> correlatedFindings, String detectorType, List correlationRules, Finding orphanFinding) { + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } + + SearchHit[] hits = response.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), + new CustomLogType(sourceMap)); + } + + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), timestampFeature, logTypes); + } + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + timestampFeature, correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature, logTypes); + } + }, this::onFailures)); + } + + private SearchRequest getSearchMetadataIndexRequest() { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(1); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); + searchRequest.source(searchSourceBuilder); + searchRequest.preference(Preference.PRIMARY_FIRST.type()); + + return searchRequest; + } + public void onOperation() { this.response.set(RestStatus.OK); if (counter.compareAndSet(false, true)) {