diff --git a/build.gradle b/build.gradle index 2e16c6b70..681287dbb 100644 --- a/build.gradle +++ b/build.gradle @@ -69,6 +69,7 @@ opensearchplugin { name 'opensearch-security-analytics' description 'OpenSearch Security Analytics plugin' classname 'org.opensearch.securityanalytics.SecurityAnalyticsPlugin' + extendedPlugins = ['opensearch-job-scheduler'] } javaRestTest { @@ -142,12 +143,6 @@ repositories { sourceSets.main.java.srcDirs = ['src/main/generated','src/main/java'] configurations { zipArchive - - all { - resolutionStrategy { - force "com.google.guava:guava:32.0.1-jre" - } - } } dependencies { @@ -158,17 +153,14 @@ dependencies { api "org.opensearch:common-utils:${common_utils_version}@jar" api "org.opensearch.client:opensearch-rest-client:${opensearch_version}" implementation "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" + compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + implementation "org.apache.commons:commons-csv:1.10.0" // Needed for integ tests zipArchive group: 'org.opensearch.plugin', name:'alerting', version: "${opensearch_build}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-notifications-core', version: "${opensearch_build}" zipArchive group: 'org.opensearch.plugin', name:'notifications', version: "${opensearch_build}" - - //spotless - implementation('com.google.googlejavaformat:google-java-format:1.17.0') { - exclude group: 'com.google.guava' - } - implementation 'com.google.guava:guava:32.0.1-jre' + zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" } // RPM & Debian build @@ -289,6 +281,22 @@ testClusters.integTest { } } })) + plugin(provider({ + new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.matching { + include '**/opensearch-job-scheduler*' + }.singleFile + } + } + })) + nodes.each { node -> + def plugins = node.plugins + def firstPlugin = plugins.get(0) + plugins.remove(0) + plugins.add(firstPlugin) + } } run { diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index 2c60321df..81fc4be38 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -8,16 +8,14 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.function.Supplier; +import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.cluster.routing.Preference; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; @@ -38,18 +36,21 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; -import org.opensearch.index.query.QueryBuilders; +import org.opensearch.indices.SystemIndexDescriptor; +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ClusterPlugin; import org.opensearch.plugins.EnginePlugin; import org.opensearch.plugins.MapperPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SearchPlugin; +import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.action.*; import org.opensearch.securityanalytics.correlation.index.codec.CorrelationCodecService; import org.opensearch.securityanalytics.correlation.index.mapper.CorrelationVectorFieldMapper; @@ -60,7 +61,18 @@ import org.opensearch.securityanalytics.mapper.IndexTemplateManager; import org.opensearch.securityanalytics.mapper.MapperService; import org.opensearch.securityanalytics.model.CustomLogType; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.securityanalytics.resthandler.*; +import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedDataService; +import org.opensearch.securityanalytics.threatIntel.action.PutTIFJobAction; +import org.opensearch.securityanalytics.threatIntel.action.TransportPutTIFJobAction; +import org.opensearch.securityanalytics.threatIntel.common.TIFLockService; +import org.opensearch.securityanalytics.threatIntel.feedMetadata.BuiltInTIFMetadataLoader; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameterService; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobRunner; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobUpdateService; import org.opensearch.securityanalytics.transport.*; import org.opensearch.securityanalytics.model.Rule; import org.opensearch.securityanalytics.model.Detector; @@ -75,7 +87,9 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; -public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, MapperPlugin, SearchPlugin, EnginePlugin, ClusterPlugin { +import static org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter.THREAT_INTEL_DATA_INDEX_NAME_PREFIX; + +public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, MapperPlugin, SearchPlugin, EnginePlugin, ClusterPlugin, SystemIndexPlugin, JobSchedulerExtension { private static final Logger log = LogManager.getLogger(SecurityAnalyticsPlugin.class); @@ -91,6 +105,8 @@ public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, Map public static final String CORRELATION_RULES_BASE_URI = PLUGINS_BASE_URI + "/correlation/rules"; public static final String CUSTOM_LOG_TYPE_URI = PLUGINS_BASE_URI + "/logtype"; + public static final String JOB_INDEX_NAME = ".opensearch-sap--job"; + public static final Map TIF_JOB_INDEX_SETTING = Map.of(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1, IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-all", IndexMetadata.SETTING_INDEX_HIDDEN, true); private CorrelationRuleIndices correlationRuleIndices; @@ -113,8 +129,12 @@ public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, Map private BuiltinLogTypeLoader builtinLogTypeLoader; private LogTypeService logTypeService; + @Override + public Collection getSystemIndexDescriptors(Settings settings){ + return Collections.singletonList(new SystemIndexDescriptor(THREAT_INTEL_DATA_INDEX_NAME_PREFIX, "System index used for threat intel data")); + } + - private Client client; @Override public Collection createComponents(Client client, @@ -128,7 +148,9 @@ public Collection createComponents(Client client, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { + builtinLogTypeLoader = new BuiltinLogTypeLoader(); + BuiltInTIFMetadataLoader builtInTIFMetadataLoader = new BuiltInTIFMetadataLoader(); logTypeService = new LogTypeService(client, clusterService, xContentRegistry, builtinLogTypeLoader); detectorIndices = new DetectorIndices(client.admin(), clusterService, threadPool); ruleTopicIndices = new RuleTopicIndices(client, clusterService, logTypeService); @@ -138,12 +160,18 @@ public Collection createComponents(Client client, mapperService = new MapperService(client, clusterService, indexNameExpressionResolver, indexTemplateManager, logTypeService); ruleIndices = new RuleIndices(logTypeService, client, clusterService, threadPool); correlationRuleIndices = new CorrelationRuleIndices(client, clusterService); - this.client = client; + ThreatIntelFeedDataService threatIntelFeedDataService = new ThreatIntelFeedDataService(clusterService, client, indexNameExpressionResolver, xContentRegistry); + DetectorThreatIntelService detectorThreatIntelService = new DetectorThreatIntelService(threatIntelFeedDataService, client, xContentRegistry); + TIFJobParameterService tifJobParameterService = new TIFJobParameterService(client, clusterService); + TIFJobUpdateService tifJobUpdateService = new TIFJobUpdateService(clusterService, tifJobParameterService, threatIntelFeedDataService, builtInTIFMetadataLoader); + TIFLockService threatIntelLockService = new TIFLockService(clusterService, client); + + TIFJobRunner.getJobRunnerInstance().initialize(clusterService, tifJobUpdateService, tifJobParameterService, threatIntelLockService, threadPool, detectorThreatIntelService); return List.of( detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices, - mapperService, indexTemplateManager, builtinLogTypeLoader - ); + mapperService, indexTemplateManager, builtinLogTypeLoader, builtInTIFMetadataLoader, threatIntelFeedDataService, detectorThreatIntelService, + tifJobUpdateService, tifJobParameterService, threatIntelLockService); } @Override @@ -187,13 +215,34 @@ public List getRestHandlers(Settings settings, ); } + @Override + public String getJobType() { + return "opensearch_sap_job"; + } + + @Override + public String getJobIndex() { + return JOB_INDEX_NAME; + } + + @Override + public ScheduledJobRunner getJobRunner() { + return TIFJobRunner.getJobRunnerInstance(); + } + + @Override + public ScheduledJobParser getJobParser() { + return (parser, id, jobDocVersion) -> TIFJobParameter.PARSER.parse(parser, null); + } + @Override public List getNamedXContent() { return List.of( Detector.XCONTENT_REGISTRY, DetectorInput.XCONTENT_REGISTRY, Rule.XCONTENT_REGISTRY, - CustomLogType.XCONTENT_REGISTRY + CustomLogType.XCONTENT_REGISTRY, + ThreatIntelFeedData.XCONTENT_REGISTRY ); } @@ -243,7 +292,10 @@ public List> getSettings() { SecurityAnalyticsSettings.IS_CORRELATION_INDEX_SETTING, SecurityAnalyticsSettings.CORRELATION_TIME_WINDOW, SecurityAnalyticsSettings.DEFAULT_MAPPING_SCHEMA, - SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE + SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE, + SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL, + SecurityAnalyticsSettings.BATCH_SIZE, + SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT ); } @@ -274,7 +326,8 @@ public List> getSettings() { new ActionPlugin.ActionHandler<>(SearchCorrelationRuleAction.INSTANCE, TransportSearchCorrelationRuleAction.class), new ActionHandler<>(IndexCustomLogTypeAction.INSTANCE, TransportIndexCustomLogTypeAction.class), new ActionHandler<>(SearchCustomLogTypeAction.INSTANCE, TransportSearchCustomLogTypeAction.class), - new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class) + new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class), + new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class) ); } @@ -292,5 +345,5 @@ public void onFailure(Exception e) { log.warn("Failed to initialize LogType config index and builtin log types"); } }); - } + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java b/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java index 3e4fc68d1..0d700b88c 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java @@ -68,6 +68,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(Detector.INPUTS_FIELD, detector.getInputs()) .field(Detector.LAST_UPDATE_TIME_FIELD, detector.getLastUpdateTime()) .field(Detector.ENABLED_TIME_FIELD, detector.getEnabledTime()) + .field(Detector.THREAT_INTEL_ENABLED_FIELD, detector.getThreatIntelEnabled()) .endObject(); return builder.endObject(); } diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetMappingsViewResponse.java b/src/main/java/org/opensearch/securityanalytics/action/GetMappingsViewResponse.java index e242e69c4..7606d029f 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetMappingsViewResponse.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetMappingsViewResponse.java @@ -4,37 +4,49 @@ */ package org.opensearch.securityanalytics.action; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.securityanalytics.mapper.MapperUtils; +import org.opensearch.securityanalytics.model.LogType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; public class GetMappingsViewResponse extends ActionResponse implements ToXContentObject { public static final String UNMAPPED_INDEX_FIELDS = "unmapped_index_fields"; public static final String UNMAPPED_FIELD_ALIASES = "unmapped_field_aliases"; + public static final String THREAT_INTEL_FIELD_ALIASES = "threat_intel_field_aliases"; private Map aliasMappings; List unmappedIndexFields; List unmappedFieldAliases; + /** This field sheds information on the list of field aliases that need to be mapped for a given IoC. + * For ex. one element for windows logtype would be + *{"ioc": "ip", "fields": ["destination.ip","source.ip"]} where "ip" is the IoC and the required field aliases to be mapped for + * threat intel based detection are "destination.ip","source.ip".*/ + private List threatIntelFieldAliases; + public GetMappingsViewResponse( Map aliasMappings, List unmappedIndexFields, - List unmappedFieldAliases + List unmappedFieldAliases, + List threatIntelFieldAliases ) { this.aliasMappings = aliasMappings; this.unmappedIndexFields = unmappedIndexFields; this.unmappedFieldAliases = unmappedFieldAliases; + this.threatIntelFieldAliases = threatIntelFieldAliases; } public GetMappingsViewResponse(StreamInput in) throws IOException { @@ -56,6 +68,7 @@ public GetMappingsViewResponse(StreamInput in) throws IOException { unmappedFieldAliases.add(in.readString()); } } + this.threatIntelFieldAliases = in.readList(LogType.IocFields::readFrom); } @Override @@ -82,6 +95,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeVInt(0); } + if(threatIntelFieldAliases!=null) { + out.writeBoolean(true); + out.writeCollection(threatIntelFieldAliases); + } else { + out.writeBoolean(false); + } } @Override @@ -96,6 +115,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (unmappedFieldAliases != null && unmappedFieldAliases.size() > 0) { builder.field(UNMAPPED_FIELD_ALIASES, unmappedFieldAliases); } + if(threatIntelFieldAliases != null && false == threatIntelFieldAliases.isEmpty()) { + builder.field(THREAT_INTEL_FIELD_ALIASES, threatIntelFieldAliases); + } return builder.endObject(); } diff --git a/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java b/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java index 6a7c268c1..67fe36f0b 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java +++ b/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java @@ -64,6 +64,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(Detector.TRIGGERS_FIELD, detector.getTriggers()) .field(Detector.LAST_UPDATE_TIME_FIELD, detector.getLastUpdateTime()) .field(Detector.ENABLED_TIME_FIELD, detector.getEnabledTime()) + .field(Detector.THREAT_INTEL_ENABLED_FIELD, detector.getThreatIntelEnabled()) .endObject(); return builder.endObject(); } diff --git a/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java b/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java index fe1402e59..bec6ef8ae 100644 --- a/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java +++ b/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java @@ -10,6 +10,7 @@ import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -660,6 +661,13 @@ public void getRuleFieldMappings(String logType, ActionListener getIocFieldsList(String logType) { + LogType logTypeByName = builtinLogTypeLoader.getLogTypeByName(logType); + if(logTypeByName == null) + return Collections.emptyList(); + return logTypeByName.getIocFieldsList(); + } + public void getRuleFieldMappingsAllSchemas(String logType, ActionListener> listener) { if (builtinLogTypeLoader.logTypeExists(logType)) { diff --git a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java index 26f9c1602..3aedc0c8f 100644 --- a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java +++ b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java @@ -5,21 +5,10 @@ package org.opensearch.securityanalytics.mapper; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.admin.indices.get.GetIndexRequest; import org.opensearch.action.admin.indices.get.GetIndexResponse; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; @@ -33,8 +22,9 @@ import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.securityanalytics.action.GetIndexMappingsResponse; import org.opensearch.securityanalytics.action.GetMappingsViewResponse; import org.opensearch.securityanalytics.logtype.LogTypeService; @@ -43,6 +33,16 @@ import org.opensearch.securityanalytics.util.IndexUtils; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import static org.opensearch.securityanalytics.mapper.MapperUtils.PATH; import static org.opensearch.securityanalytics.mapper.MapperUtils.PROPERTIES; @@ -57,7 +57,8 @@ public class MapperService { private IndexTemplateManager indexTemplateManager; private LogTypeService logTypeService; - public MapperService() {} + public MapperService() { + } public MapperService(Client client, ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver, IndexTemplateManager indexTemplateManager, LogTypeService logTypeService) { this.indicesClient = client.admin().indices(); @@ -122,7 +123,7 @@ public void onFailure(Exception e) { } private void applyAliasMappings(Map indexMappings, String logType, String aliasMappings, boolean partial, ActionListener> actionListener) { - int numOfIndices = indexMappings.size(); + int numOfIndices = indexMappings.size(); GroupedActionListener doCreateMappingActionsListener = new GroupedActionListener(new ActionListener>() { @Override @@ -150,12 +151,13 @@ public void onFailure(Exception e) { /** * Applies alias mappings to index. - * @param indexName Index name + * + * @param indexName Index name * @param mappingMetadata Index mappings - * @param logType Rule topic spcifying specific alias templates - * @param aliasMappings User-supplied alias mappings - * @param partial Partial flag indicating if we should apply mappings partially, in case source index doesn't have all paths specified in alias mappings - * @param actionListener actionListener used to return response/error + * @param logType Rule topic spcifying specific alias templates + * @param aliasMappings User-supplied alias mappings + * @param partial Partial flag indicating if we should apply mappings partially, in case source index doesn't have all paths specified in alias mappings + * @param actionListener actionListener used to return response/error */ private void doCreateMapping( String indexName, @@ -224,7 +226,7 @@ public void onResponse(List mappings) { List indexFields = MapperUtils.extractAllFieldsFlat(mappingMetadata); Map> aliasMappingFields = new HashMap<>(); XContentBuilder aliasMappingsObj = XContentFactory.jsonBuilder().startObject(); - for (LogType.Mapping mapping: mappings) { + for (LogType.Mapping mapping : mappings) { if (indexFields.contains(mapping.getRawField())) { aliasMappingFields.put(mapping.getEcs(), Map.of("type", "alias", "path", mapping.getRawField())); } else if (indexFields.contains(mapping.getOcsf())) { @@ -293,7 +295,7 @@ public void onFailure(Exception e) { } }); } - } catch(IOException | IllegalArgumentException e){ + } catch (IOException | IllegalArgumentException e) { actionListener.onFailure(e); } } @@ -308,7 +310,7 @@ private Map filterNonApplicableAliases( Map filteredAliasMappings = mappingsTraverser.traverseAndCopyAsFlat(); List> propertiesToSkip = new ArrayList<>(); - if(missingPathsInIndex.size() > 0) { + if (missingPathsInIndex.size() > 0) { // Filter out missing paths from alias mappings so that our PutMappings request succeeds propertiesToSkip.addAll( missingPathsInIndex.stream() @@ -423,6 +425,7 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { } }, actionListener::onFailure)); } + @Override public void onFailure(Exception e) { actionListener.onFailure(e); @@ -457,9 +460,10 @@ public void onFailure(Exception e) { /** * Constructs Mappings View of index - * @param logType Log Type + * + * @param logType Log Type * @param actionListener Action Listener - * @param concreteIndex Concrete Index name for which we're computing Mappings View + * @param concreteIndex Concrete Index name for which we're computing Mappings View */ private void doGetMappingsView(String logType, ActionListener actionListener, String concreteIndex) { GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(concreteIndex); @@ -479,7 +483,7 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { // List of unapplayable aliases List unmappedFieldAliases = new ArrayList<>(); - for (LogType.Mapping requiredField: requiredFields) { + for (LogType.Mapping requiredField : requiredFields) { String alias = requiredField.getEcs(); String rawPath = requiredField.getRawField(); String ocsfPath = requiredField.getOcsf(); @@ -494,7 +498,7 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { } else if (allFieldsFromIndex.contains(ocsfPath)) { applyableAliases.add(alias); pathsOfApplyableAliases.add(ocsfPath); - } else if ((alias == null && allFieldsFromIndex.contains(rawPath) == false) || allFieldsFromIndex.contains(alias) == false) { + } else if ((alias == null && allFieldsFromIndex.contains(rawPath) == false) || allFieldsFromIndex.contains(alias) == false) { if (alias != null) { // we don't want to send back aliases which have same name as existing field in index unmappedFieldAliases.add(alias); @@ -506,7 +510,7 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { Map> aliasMappingFields = new HashMap<>(); XContentBuilder aliasMappingsObj = XContentFactory.jsonBuilder().startObject(); - for (LogType.Mapping mapping: requiredFields) { + for (LogType.Mapping mapping : requiredFields) { if (allFieldsFromIndex.contains(mapping.getOcsf())) { aliasMappingFields.put(mapping.getEcs(), Map.of("type", "alias", "path", mapping.getOcsf())); } else if (mapping.getEcs() != null) { @@ -523,15 +527,15 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { .stream() .filter(e -> pathsOfApplyableAliases.contains(e) == false) .collect(Collectors.toList()); - actionListener.onResponse( - new GetMappingsViewResponse(aliasMappings, unmappedIndexFields, unmappedFieldAliases) + new GetMappingsViewResponse(aliasMappings, unmappedIndexFields, unmappedFieldAliases, logTypeService.getIocFieldsList(logType)) ); } catch (Exception e) { actionListener.onFailure(e); } }, actionListener::onFailure)); } + @Override public void onFailure(Exception e) { actionListener.onFailure(e); @@ -542,7 +546,8 @@ public void onFailure(Exception e) { /** * Given index name, resolves it to single concrete index, depending on what initial indexName is. * In case of Datastream or Alias, WriteIndex would be returned. In case of index pattern, newest index by creation date would be returned. - * @param indexName Datastream, Alias, index patter or concrete index + * + * @param indexName Datastream, Alias, index patter or concrete index * @param actionListener Action Listener * @throws IOException */ @@ -583,6 +588,7 @@ public void onFailure(Exception e) { void setIndicesAdminClient(IndicesAdminClient client) { this.indicesClient = client; } + void setClusterService(ClusterService clusterService) { this.clusterService = clusterService; } diff --git a/src/main/java/org/opensearch/securityanalytics/model/Detector.java b/src/main/java/org/opensearch/securityanalytics/model/Detector.java index ff832d1e7..5a8e2f32b 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Detector.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Detector.java @@ -25,14 +25,11 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Objects; -import java.util.stream.Collectors; - public class Detector implements Writeable, ToXContentObject { private static final Logger log = LogManager.getLogger(Detector.class); @@ -51,6 +48,7 @@ public class Detector implements Writeable, ToXContentObject { public static final String TRIGGERS_FIELD = "triggers"; public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ENABLED_TIME_FIELD = "enabled_time"; + public static final String THREAT_INTEL_ENABLED_FIELD = "threat_intel_enabled"; public static final String ALERTING_MONITOR_ID = "monitor_id"; public static final String ALERTING_WORKFLOW_ID = "workflow_ids"; @@ -82,6 +80,8 @@ public class Detector implements Writeable, ToXContentObject { private String name; + private Boolean threatIntelEnabled; + private Boolean enabled; private Schedule schedule; @@ -122,7 +122,8 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule Instant lastUpdateTime, Instant enabledTime, String logType, User user, List inputs, List triggers, List monitorIds, String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern, - String findingsIndex, String findingsIndexPattern, Map rulePerMonitor, List workflowIds) { + String findingsIndex, String findingsIndexPattern, Map rulePerMonitor, + List workflowIds, Boolean threatIntelEnabled) { this.type = DETECTOR_TYPE; this.id = id != null ? id : NO_ID; @@ -145,6 +146,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule this.ruleIdMonitorIdMap = rulePerMonitor; this.logType = logType; this.workflowIds = workflowIds != null ? workflowIds : null; + this.threatIntelEnabled = threatIntelEnabled != null && threatIntelEnabled; if (enabled) { Objects.requireNonNull(enabledTime); @@ -172,7 +174,8 @@ public Detector(StreamInput sin) throws IOException { sin.readString(), sin.readString(), sin.readMap(StreamInput::readString, StreamInput::readString), - sin.readStringList() + sin.readStringList(), + sin.readBoolean() ); } @@ -211,6 +214,7 @@ public void writeTo(StreamOutput out) throws IOException { if (workflowIds != null) { out.writeStringCollection(workflowIds); } + out.writeBoolean(threatIntelEnabled); } public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { @@ -239,6 +243,7 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten } } + builder.field(THREAT_INTEL_ENABLED_FIELD, threatIntelEnabled); builder.field(ENABLED_FIELD, enabled); if (enabledTime == null) { @@ -280,7 +285,6 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten builder.field(FINDINGS_INDEX, findingsIndex); builder.field(FINDINGS_INDEX_PATTERN, findingsIndexPattern); - if (params.paramAsBoolean("with_type", false)) { builder.endObject(); } @@ -327,6 +331,7 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws String alertsHistoryIndexPattern = null; String findingsIndex = null; String findingsIndexPattern = null; + Boolean enableThreatIntel = false; XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -350,6 +355,9 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws case ENABLED_FIELD: enabled = xcp.booleanValue(); break; + case THREAT_INTEL_ENABLED_FIELD: + enableThreatIntel = xcp.booleanValue(); + break; case SCHEDULE_FIELD: schedule = Schedule.parse(xcp); break; @@ -459,7 +467,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws findingsIndex, findingsIndexPattern, rulePerMonitor, - workflowIds + workflowIds, + enableThreatIntel ); } @@ -600,6 +609,10 @@ public void setWorkflowIds(List workflowIds) { this.workflowIds = workflowIds; } + public void setThreatIntelEnabled(boolean threatIntelEnabled) { + this.threatIntelEnabled = threatIntelEnabled; + } + public List getWorkflowIds() { return workflowIds; } @@ -612,6 +625,10 @@ public boolean isWorkflowSupported() { return workflowIds != null && !workflowIds.isEmpty(); } + public Boolean getThreatIntelEnabled() { + return threatIntelEnabled; + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java index b74a71048..ed74ea9e0 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java +++ b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java @@ -49,13 +49,23 @@ public class DetectorTrigger implements Writeable, ToXContentObject { private List actions; + /** + * detection type is a list of values that tells us what queries is the trigger trying to match - rules-based or threat_intel-based or both + */ + private List detectionTypes; // todo make it enum supports 'rules', 'threat_intel' + private static final String ID_FIELD = "id"; + private static final String SEVERITY_FIELD = "severity"; private static final String RULE_TYPES_FIELD = "types"; private static final String RULE_IDS_FIELD = "ids"; private static final String RULE_SEV_LEVELS_FIELD = "sev_levels"; private static final String RULE_TAGS_FIELD = "tags"; private static final String ACTIONS_FIELD = "actions"; + private static final String DETECTION_TYPES_FIELD = "detection_types"; + + public static final String RULES_DETECTION_TYPE = "rules"; + public static final String THREAT_INTEL_DETECTION_TYPE = "threat_intel"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( DetectorTrigger.class, @@ -63,17 +73,29 @@ public class DetectorTrigger implements Writeable, ToXContentObject { DetectorTrigger::parse ); - public DetectorTrigger(String id, String name, String severity, List ruleTypes, List ruleIds, List ruleSeverityLevels, List tags, List actions) { - this.id = id == null? UUIDs.base64UUID(): id; + public DetectorTrigger(String id, + String name, + String severity, + List ruleTypes, + List ruleIds, + List ruleSeverityLevels, + List tags, + List actions, + List detectionTypes) { + this.id = id == null ? UUIDs.base64UUID() : id; this.name = name; this.severity = severity; this.ruleTypes = ruleTypes.stream() - .map( e -> e.toLowerCase(Locale.ROOT)) + .map(e -> e.toLowerCase(Locale.ROOT)) .collect(Collectors.toList()); this.ruleIds = ruleIds; this.ruleSeverityLevels = ruleSeverityLevels; this.tags = tags; this.actions = actions; + this.detectionTypes = detectionTypes; + if(this.detectionTypes.isEmpty()) { + this.detectionTypes = Collections.singletonList(RULES_DETECTION_TYPE); // for backward compatibility + } } public DetectorTrigger(StreamInput sin) throws IOException { @@ -85,7 +107,8 @@ public DetectorTrigger(StreamInput sin) throws IOException { sin.readStringList(), sin.readStringList(), sin.readStringList(), - sin.readList(Action::readFrom) + sin.readList(Action::readFrom), + sin.readStringList() ); } @@ -95,7 +118,8 @@ public Map asTemplateArg() { RULE_IDS_FIELD, ruleIds, RULE_SEV_LEVELS_FIELD, ruleSeverityLevels, RULE_TAGS_FIELD, tags, - ACTIONS_FIELD, actions.stream().map(Action::asTemplateArg) + ACTIONS_FIELD, actions.stream().map(Action::asTemplateArg), + DETECTION_TYPES_FIELD, detectionTypes ); } @@ -109,6 +133,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(ruleSeverityLevels); out.writeStringCollection(tags); out.writeCollection(actions); + out.writeStringCollection(detectionTypes); } @Override @@ -128,6 +153,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws Action[] actionArray = new Action[]{}; actionArray = actions.toArray(actionArray); + String[] detectionTypesArray = new String[]{}; + detectionTypesArray = detectionTypes.toArray(detectionTypesArray); + return builder.startObject() .field(ID_FIELD, id) .field(Detector.NAME_FIELD, name) @@ -137,6 +165,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(RULE_SEV_LEVELS_FIELD, ruleSevLevelArray) .field(RULE_TAGS_FIELD, tagArray) .field(ACTIONS_FIELD, actionArray) + .field(DETECTION_TYPES_FIELD, detectionTypesArray) .endObject(); } @@ -149,6 +178,7 @@ public static DetectorTrigger parse(XContentParser xcp) throws IOException { List ruleSeverityLevels = new ArrayList<>(); List tags = new ArrayList<>(); List actions = new ArrayList<>(); + List detectionTypes = new ArrayList<>(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -193,6 +223,13 @@ public static DetectorTrigger parse(XContentParser xcp) throws IOException { tags.add(tag); } break; + case DETECTION_TYPES_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + String dt = xcp.text(); + detectionTypes.add(dt); + } + break; case ACTIONS_FIELD: XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { @@ -204,8 +241,10 @@ public static DetectorTrigger parse(XContentParser xcp) throws IOException { xcp.skipChildren(); } } - - return new DetectorTrigger(id, name, severity, ruleTypes, ruleNames, ruleSeverityLevels, tags, actions); + if(detectionTypes.isEmpty()) { + detectionTypes.add(RULES_DETECTION_TYPE); // for backward compatibility + } + return new DetectorTrigger(id, name, severity, ruleTypes, ruleNames, ruleSeverityLevels, tags, actions, detectionTypes); } public static DetectorTrigger readFrom(StreamInput sin) throws IOException { @@ -227,71 +266,83 @@ public int hashCode() { public Script convertToCondition() { StringBuilder condition = new StringBuilder(); + boolean triggerFlag = false; - StringBuilder ruleTypeBuilder = new StringBuilder(); - int size = ruleTypes.size(); - for (int idx = 0; idx < size; ++idx) { - ruleTypeBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleTypes.get(idx))); - if (idx < size - 1) { - ruleTypeBuilder.append(" || "); + int size = 0; + if (detectionTypes.contains(RULES_DETECTION_TYPE)) { // trigger should match rules based queries based on conditions + StringBuilder ruleTypeBuilder = new StringBuilder(); + size = ruleTypes.size(); + for (int idx = 0; idx < size; ++idx) { + ruleTypeBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleTypes.get(idx))); + if (idx < size - 1) { + ruleTypeBuilder.append(" || "); + } + } + if (size > 0) { + condition.append("(").append(ruleTypeBuilder).append(")"); + triggerFlag = true; } - } - if (size > 0) { - condition.append("(").append(ruleTypeBuilder).append(")"); - triggerFlag = true; - } - StringBuilder ruleNameBuilder = new StringBuilder(); - size = ruleIds.size(); - for (int idx = 0; idx < size; ++idx) { - ruleNameBuilder.append(String.format(Locale.getDefault(), "query[name=%s]", ruleIds.get(idx))); - if (idx < size - 1) { - ruleNameBuilder.append(" || "); + StringBuilder ruleNameBuilder = new StringBuilder(); + size = ruleIds.size(); + for (int idx = 0; idx < size; ++idx) { + ruleNameBuilder.append(String.format(Locale.getDefault(), "query[name=%s]", ruleIds.get(idx))); + if (idx < size - 1) { + ruleNameBuilder.append(" || "); + } } - } - if (size > 0) { - if (triggerFlag) { - condition.append(" && ").append("(").append(ruleNameBuilder).append(")"); - } else { - condition.append("(").append(ruleNameBuilder).append(")"); - triggerFlag = true; + if (size > 0) { + if (triggerFlag) { + condition.append(" && ").append("(").append(ruleNameBuilder).append(")"); + } else { + condition.append("(").append(ruleNameBuilder).append(")"); + triggerFlag = true; + } } - } - StringBuilder ruleSevLevelBuilder = new StringBuilder(); - size = ruleSeverityLevels.size(); - for (int idx = 0; idx < size; ++idx) { - ruleSevLevelBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleSeverityLevels.get(idx))); - if (idx < size - 1) { - ruleSevLevelBuilder.append(" || "); + StringBuilder ruleSevLevelBuilder = new StringBuilder(); + size = ruleSeverityLevels.size(); + for (int idx = 0; idx < size; ++idx) { + ruleSevLevelBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleSeverityLevels.get(idx))); + if (idx < size - 1) { + ruleSevLevelBuilder.append(" || "); + } } - } - if (size > 0) { - if (triggerFlag) { - condition.append(" && ").append("(").append(ruleSevLevelBuilder).append(")"); - } else { - condition.append("(").append(ruleSevLevelBuilder).append(")"); - triggerFlag = true; + if (size > 0) { + if (triggerFlag) { + condition.append(" && ").append("(").append(ruleSevLevelBuilder).append(")"); + } else { + condition.append("(").append(ruleSevLevelBuilder).append(")"); + triggerFlag = true; + } } - } - StringBuilder tagBuilder = new StringBuilder(); - size = tags.size(); - for (int idx = 0; idx < size; ++idx) { - tagBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", tags.get(idx))); - if (idx < size - 1) { - ruleSevLevelBuilder.append(" || "); + StringBuilder tagBuilder = new StringBuilder(); + size = tags.size(); + for (int idx = 0; idx < size; ++idx) { + tagBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", tags.get(idx))); + if (idx < size - 1) { + ruleSevLevelBuilder.append(" || "); + } } - } - if (size > 0) { - if (triggerFlag) { - condition.append(" && ").append("(").append(tagBuilder).append(")"); - } else { - condition.append("(").append(tagBuilder).append(")"); + if (size > 0) { + if (triggerFlag) { + condition.append(" && ").append("(").append(tagBuilder).append(")"); + } else { + condition.append("(").append(tagBuilder).append(")"); + } + } + } + if(detectionTypes.contains(THREAT_INTEL_DETECTION_TYPE)) { + StringBuilder threatIntelClauseBuilder = new StringBuilder(); + threatIntelClauseBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", "threat_intel")); + if (condition.length() > 0) { + condition.append(" || "); } + condition.append("(").append(threatIntelClauseBuilder).append(")"); } return new Script(condition.toString()); @@ -321,6 +372,10 @@ public List getRuleSeverityLevels() { return ruleSeverityLevels; } + public List getDetectionTypes() { + return detectionTypes; + } + public List getTags() { return tags; } @@ -329,8 +384,8 @@ public List getActions() { List transformedActions = new ArrayList<>(); if (actions != null) { - for (Action action: actions) { - String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode(): ""; + for (Action action : actions) { + String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode() : ""; subjectTemplate = subjectTemplate.replace("{{ctx.detector", "{{ctx.monitor"); action.getMessageTemplate(); diff --git a/src/main/java/org/opensearch/securityanalytics/model/LogType.java b/src/main/java/org/opensearch/securityanalytics/model/LogType.java index 7acc0d1f3..f70a462e2 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/LogType.java +++ b/src/main/java/org/opensearch/securityanalytics/model/LogType.java @@ -4,17 +4,19 @@ */ package org.opensearch.securityanalytics.model; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + public class LogType implements Writeable { private static final String ID = "id"; @@ -25,12 +27,16 @@ public class LogType implements Writeable { private static final String RAW_FIELD = "raw_field"; public static final String ECS = "ecs"; public static final String OCSF = "ocsf"; + public static final String IOC_FIELDS = "ioc_fields"; + public static final String IOC = "ioc"; + public static final String FIELDS = "fields"; private String id; private String name; private String description; private Boolean isBuiltIn; private List mappings; + private List iocFieldsList; public LogType(StreamInput sin) throws IOException { this.id = sin.readString(); @@ -38,14 +44,16 @@ public LogType(StreamInput sin) throws IOException { this.name = sin.readString(); this.description = sin.readString(); this.mappings = sin.readList(Mapping::readFrom); + this.iocFieldsList = sin.readList(IocFields::readFrom); } - public LogType(String id, String name, String description, boolean isBuiltIn, List mappings) { + public LogType(String id, String name, String description, boolean isBuiltIn, List mappings, List iocFieldsList) { this.id = id; this.name = name; this.description = description; this.isBuiltIn = isBuiltIn; this.mappings = mappings == null ? List.of() : mappings; + this.iocFieldsList = iocFieldsList == null ? List.of() : iocFieldsList; } public LogType(Map logTypeAsMap) { @@ -55,13 +63,21 @@ public LogType(Map logTypeAsMap) { if (logTypeAsMap.containsKey(IS_BUILTIN)) { this.isBuiltIn = (Boolean) logTypeAsMap.get(IS_BUILTIN); } - List> mappings = (List>)logTypeAsMap.get(MAPPINGS); + List> mappings = (List>) logTypeAsMap.get(MAPPINGS); if (mappings.size() > 0) { this.mappings = new ArrayList<>(mappings.size()); this.mappings = mappings.stream().map(e -> new Mapping(e.get(RAW_FIELD), e.get(ECS), e.get(OCSF)) ).collect(Collectors.toList()); } + if (logTypeAsMap.containsKey(IOC_FIELDS)) { + List> iocFieldsList = (List>) logTypeAsMap.get(IOC_FIELDS); + this.iocFieldsList = iocFieldsList.stream().map(e -> + new IocFields(e.get(IOC).toString(), (List) e.get(FIELDS)) + ).collect(Collectors.toList()); + } else { + iocFieldsList = Collections.emptyList(); + } } public String getName() { @@ -72,7 +88,13 @@ public String getDescription() { return description; } - public boolean getIsBuiltIn() { return isBuiltIn; } + public boolean getIsBuiltIn() { + return isBuiltIn; + } + + public List getIocFieldsList() { + return iocFieldsList; + } public List getMappings() { return mappings; @@ -85,6 +107,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeString(description); out.writeCollection(mappings); + out.writeCollection(iocFieldsList); } @Override @@ -134,4 +157,54 @@ public static Mapping readFrom(StreamInput sin) throws IOException { } } + /** + * stores information of list of field names that contain information for given IoC (Indicator of Compromise). + */ + public static class IocFields implements Writeable, ToXContentObject { + + private final String ioc; + private final List fields; + + public IocFields(String ioc, List fields) { + this.ioc = ioc; + this.fields = fields; + } + + public IocFields(StreamInput sin) throws IOException { + this.ioc = sin.readString(); + this.fields = sin.readStringList(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(ioc); + out.writeStringCollection(fields); + } + + public String getIoc() { + return ioc; + } + + public List getFields() { + return fields; + } + + + public static IocFields readFrom(StreamInput sin) throws IOException { + return new IocFields(sin); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + String[] fieldsArray = new String[]{}; + fieldsArray = fields.toArray(fieldsArray); + builder.startObject() + .field(IOC, ioc) + .field(FIELDS, fieldsArray) + .endObject(); + return builder; + } + } + + } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/model/ThreatIntelFeedData.java b/src/main/java/org/opensearch/securityanalytics/model/ThreatIntelFeedData.java new file mode 100644 index 000000000..169270e9b --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/model/ThreatIntelFeedData.java @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.model; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; + +import java.io.IOException; +import java.time.Instant; +import java.util.Locale; +import java.util.Objects; + +/** + * Model for threat intel feed data stored in system index. + */ +public class ThreatIntelFeedData implements Writeable, ToXContentObject { + private static final Logger log = LogManager.getLogger(ThreatIntelFeedData.class); + private static final String FEED_TYPE = "feed"; + private static final String TYPE_FIELD = "type"; + private static final String IOC_TYPE_FIELD = "ioc_type"; + private static final String IOC_VALUE_FIELD = "ioc_value"; + private static final String FEED_ID_FIELD = "feed_id"; + private static final String TIMESTAMP_FIELD = "timestamp"; + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + ThreatIntelFeedData.class, + new ParseField(FEED_TYPE), + xcp -> parse(xcp, null, null) + ); + + private final String iocType; + private final String iocValue; + private final String feedId; + private final Instant timestamp; + private final String type; + + public ThreatIntelFeedData(String iocType, String iocValue, String feedId, Instant timestamp) { + this.type = FEED_TYPE; + + this.iocType = iocType; + this.iocValue = iocValue; + this.feedId = feedId; + this.timestamp = timestamp; + } + + public static ThreatIntelFeedData parse(XContentParser xcp, String id, Long version) throws IOException { + String iocType = null; + String iocValue = null; + String feedId = null; + Instant timestamp = null; + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + + switch (fieldName) { + case IOC_TYPE_FIELD: + iocType = xcp.text(); + break; + case IOC_VALUE_FIELD: + iocValue = xcp.text(); + break; + case FEED_ID_FIELD: + feedId = xcp.text(); + break; + case TIMESTAMP_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + timestamp = null; + } else if (xcp.currentToken().isValue()) { + timestamp = Instant.ofEpochMilli(xcp.longValue()); + } else { + XContentParserUtils.throwUnknownToken(xcp.currentToken(), xcp.getTokenLocation()); + timestamp = null; + } + break; + default: + xcp.skipChildren(); + } + } + return new ThreatIntelFeedData(iocType, iocValue, feedId, timestamp); + } + + public String getIocType() { + return iocType; + } + + public String getIocValue() { + return iocValue; + } + + public String getFeedId() { + return feedId; + } + + public Instant getTimestamp() { + return timestamp; + } + + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(iocType); + out.writeString(iocValue); + out.writeString(feedId); + out.writeInstant(timestamp); + } + + public ThreatIntelFeedData(StreamInput sin) throws IOException { + this( + sin.readString(), + sin.readString(), + sin.readString(), + sin.readInstant() + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return createXContentBuilder(builder, params); + + } + + private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean("with_type", false)) { + builder.startObject(type); + } + builder + .field(TYPE_FIELD, type) + .field(IOC_TYPE_FIELD, iocType) + .field(IOC_VALUE_FIELD, iocValue) + .field(FEED_ID_FIELD, feedId) + .timeField(TIMESTAMP_FIELD, String.format(Locale.getDefault(), "%s_in_millis", TIMESTAMP_FIELD), timestamp.toEpochMilli()); + + return builder.endObject(); + } + + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ThreatIntelFeedData tif = (ThreatIntelFeedData) o; + return Objects.equals(iocType, tif.iocType) && Objects.equals(iocValue, tif.iocValue) && Objects.equals(feedId, tif.feedId); + } + + @Override + public int hashCode() { + return Objects.hash(); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestIndexDetectorAction.java index 489ce5ffb..6fac7a078 100644 --- a/src/main/java/org/opensearch/securityanalytics/resthandler/RestIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestIndexDetectorAction.java @@ -23,6 +23,7 @@ import org.opensearch.securityanalytics.action.IndexDetectorRequest; import org.opensearch.securityanalytics.action.IndexDetectorResponse; import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorTrigger; import org.opensearch.securityanalytics.util.DetectorUtils; import org.opensearch.securityanalytics.util.RestHandlerUtils; @@ -67,11 +68,26 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli Detector detector = Detector.parse(xcp, id, null); detector.setLastUpdateTime(Instant.now()); + validateDetectorTriggers(detector); IndexDetectorRequest indexDetectorRequest = new IndexDetectorRequest(id, refreshPolicy, request.method(), detector); return channel -> client.execute(IndexDetectorAction.INSTANCE, indexDetectorRequest, indexDetectorResponse(channel, request.method())); } + private static void validateDetectorTriggers(Detector detector) { + if(detector.getTriggers() != null) { + for (DetectorTrigger trigger : detector.getTriggers()) { + if(trigger.getDetectionTypes().isEmpty()) + throw new IllegalArgumentException(String.format(Locale.ROOT,"Trigger [%s] should mention at least one detection type but found none", trigger.getName())); + for (String detectionType : trigger.getDetectionTypes()) { + if(false == (DetectorTrigger.THREAT_INTEL_DETECTION_TYPE.equals(detectionType) || DetectorTrigger.RULES_DETECTION_TYPE.equals(detectionType))) { + throw new IllegalArgumentException(String.format(Locale.ROOT,"Trigger [%s] has unsupported detection type [%s]", trigger.getName(), detectionType)); + } + } + } + } + } + private RestResponseListener indexDetectorResponse(RestChannel channel, RestRequest.Method restMethod) { return new RestResponseListener<>(channel) { @Override diff --git a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java index 4085d7ae2..f8942e70e 100644 --- a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java +++ b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java @@ -4,10 +4,11 @@ */ package org.opensearch.securityanalytics.settings; -import java.util.concurrent.TimeUnit; import org.opensearch.common.settings.Setting; import org.opensearch.common.unit.TimeValue; -import org.opensearch.securityanalytics.model.FieldMappingDoc; + +import java.util.List; +import java.util.concurrent.TimeUnit; public class SecurityAnalyticsSettings { public static final String CORRELATION_INDEX = "index.correlation"; @@ -117,4 +118,43 @@ public class SecurityAnalyticsSettings { "ecs", Setting.Property.NodeScope, Setting.Property.Dynamic ); + + // threat intel settings + public static final Setting TIF_UPDATE_INTERVAL = Setting.timeSetting( + "plugins.security_analytics.threatintel.tifjob.update_interval", + TimeValue.timeValueMinutes(1440), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Bulk size for indexing threat intel feed data + */ + public static final Setting BATCH_SIZE = Setting.intSetting( + "plugins.security_analytics.threatintel.tifjob.batch_size", + 10000, + 1, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Timeout value for threat intel processor + */ + public static final Setting THREAT_INTEL_TIMEOUT = Setting.timeSetting( + "plugins.security_analytics.threat_intel_timeout", + TimeValue.timeValueSeconds(30), + TimeValue.timeValueSeconds(1), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Return all settings of threat intel feature + * @return a list of all settings for threat intel feature + */ + public static final List> settings() { + return List.of(BATCH_SIZE, THREAT_INTEL_TIMEOUT, TIF_UPDATE_INTERVAL); + } + } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java new file mode 100644 index 000000000..2565d8175 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java @@ -0,0 +1,203 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.commons.alerting.model.DocLevelQuery; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.securityanalytics.action.IndexDetectorAction; +import org.opensearch.securityanalytics.action.IndexDetectorRequest; +import org.opensearch.securityanalytics.action.SearchDetectorAction; +import org.opensearch.securityanalytics.action.SearchDetectorRequest; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.LogType; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.opensearch.securityanalytics.model.Detector.DETECTORS_INDEX; +import static org.opensearch.securityanalytics.util.DetectorUtils.getDetectors; + +/** + * Service that populates detectors with queries generated from threat intelligence data. + */ +public class DetectorThreatIntelService { + + private static final Logger log = LogManager.getLogger(DetectorThreatIntelService.class); + + private final ThreatIntelFeedDataService threatIntelFeedDataService; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + public DetectorThreatIntelService(ThreatIntelFeedDataService threatIntelFeedDataService, Client client, NamedXContentRegistry xContentRegistry) { + this.threatIntelFeedDataService = threatIntelFeedDataService; + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + + /** + * Convert the feed data IOCs into query string query format to create doc level queries. + */ + public List createDocLevelQueriesFromThreatIntelList( + List iocFieldList, List tifdList, Detector detector + ) { + List queries = new ArrayList<>(); + Set iocs = tifdList.stream().map(ThreatIntelFeedData::getIocValue).collect(Collectors.toSet()); + //ioc types supported by log type + List logTypeIocs = iocFieldList.stream().map(LogType.IocFields::getIoc).collect(Collectors.toList()); + // filter out ioc types not supported for given log types + Map> iocTypeToValues = tifdList.stream().filter(t -> logTypeIocs.contains(t.getIocType())) + .collect(Collectors.groupingBy( + ThreatIntelFeedData::getIocType, + Collectors.mapping(ThreatIntelFeedData::getIocValue, Collectors.toSet()) + )); + + for (Map.Entry> entry : iocTypeToValues.entrySet()) { + String query = buildQueryStringQueryWithIocList(iocs); + List fields = iocFieldList.stream().filter(t -> entry.getKey().matches(t.getIoc())).findFirst().get().getFields(); + + // create doc + for (String field : fields) { + queries.add(new DocLevelQuery( + constructId(detector, entry.getKey()), tifdList.get(0).getFeedId(), + Collections.emptyList(), + String.format(query, field), + List.of( + "threat_intel", + String.format("ioc_type:%s", entry.getKey()), + String.format("field:%s", field), + String.format("feed_name:%s", tifdList.get(0).getFeedId()) + ) + )); + } + } + return queries; + } + + private String buildQueryStringQueryWithIocList(Set iocs) { + StringBuilder sb = new StringBuilder(); + sb.append("%s"); + sb.append(":"); + sb.append("("); + for (String ioc : iocs) { + if (sb.length() > 4) { + sb.append(" OR "); + } + sb.append(ioc); + + } + sb.append(")"); + return sb.toString(); + } + + /** + * Fetches threat intel data and creates doc level queries from threat intel data + */ + public void createDocLevelQueryFromThreatIntel(List iocFieldList, Detector detector, ActionListener> listener) { + try { + if (false == detector.getThreatIntelEnabled() || iocFieldList.isEmpty()) { + listener.onResponse(Collections.emptyList()); + return; + } + + CountDownLatch latch = new CountDownLatch(1); + threatIntelFeedDataService.getThreatIntelFeedData(new ActionListener<>() { + @Override + public void onResponse(List threatIntelFeedData) { + if (threatIntelFeedData.isEmpty()) { + listener.onResponse(Collections.emptyList()); + } else { + listener.onResponse( + createDocLevelQueriesFromThreatIntelList(iocFieldList, threatIntelFeedData, detector) + ); + } + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to get threat intel feeds for doc level query creation", e); + listener.onFailure(e); + latch.countDown(); + } + }); + + latch.await(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + log.error("Failed to create doc level queries from threat intel feeds", e); + listener.onFailure(e); + } + + } + + private static String constructId(Detector detector, String iocType) { + return "threat_intel_" + UUID.randomUUID(); + } + + /** Updates all detectors having threat intel detection enabled with the latest threat intel feed data*/ + public void updateDetectorsWithLatestThreatIntelRules() { + try { + QueryBuilder queryBuilder = + QueryBuilders.nestedQuery("detector", + QueryBuilders.boolQuery().must( + QueryBuilders.matchQuery("detector.threat_intel_enabled", true) + ), ScoreMode.Avg); + SearchRequest searchRequest = new SearchRequest(DETECTORS_INDEX); + SearchSourceBuilder ssb = searchRequest.source(); + ssb.query(queryBuilder); + ssb.size(9999); + CountDownLatch countDownLatch = new CountDownLatch(1); + client.execute(SearchDetectorAction.INSTANCE, new SearchDetectorRequest(searchRequest), + ActionListener.wrap(searchResponse -> { + List detectors = getDetectors(searchResponse, xContentRegistry); + detectors.forEach(detector -> { + assert detector.getThreatIntelEnabled(); + client.execute(IndexDetectorAction.INSTANCE, new IndexDetectorRequest( + detector.getId(), WriteRequest.RefreshPolicy.IMMEDIATE, + RestRequest.Method.PUT, + detector), + ActionListener.wrap( + indexDetectorResponse -> { + log.debug("updated {} with latest threat intel info", indexDetectorResponse.getDetector().getId()); + countDownLatch.countDown(); + }, + e -> { + log.error(() -> new ParameterizedMessage("Failed to update detector {} with latest threat intel info", detector.getId()), e); + countDownLatch.countDown(); + })); + } + ); + }, e -> { + log.error("Failed to fetch detectors to update with threat intel queries.", e); + countDownLatch.countDown(); + })); + countDownLatch.await(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + log.error(""); + } + + + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java new file mode 100644 index 000000000..40bc7bc53 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java @@ -0,0 +1,293 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel; + +import org.apache.commons.csv.CSVRecord; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; +import org.opensearch.securityanalytics.threatIntel.action.PutTIFJobAction; +import org.opensearch.securityanalytics.threatIntel.action.PutTIFJobRequest; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; +import org.opensearch.securityanalytics.threatIntel.common.StashedThreadContext; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameterService; +import org.opensearch.securityanalytics.util.IndexUtils; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter.THREAT_INTEL_DATA_INDEX_NAME_PREFIX; + +/** + * Service to handle CRUD operations on Threat Intel Feed Data + */ +public class ThreatIntelFeedDataService { + + private static final Logger log = LogManager.getLogger(ThreatIntelFeedDataService.class); + + public static final String SETTING_INDEX_REFRESH_INTERVAL = "index.refresh_interval"; + private static final Map INDEX_SETTING_TO_CREATE = Map.of( + IndexMetadata.SETTING_NUMBER_OF_SHARDS, + 1, + IndexMetadata.SETTING_NUMBER_OF_REPLICAS, + 0, + SETTING_INDEX_REFRESH_INTERVAL, + -1, + IndexMetadata.SETTING_INDEX_HIDDEN, + true + ); + + private final ClusterService clusterService; + private final ClusterSettings clusterSettings; + private final NamedXContentRegistry xContentRegistry; + private final Client client; + private final IndexNameExpressionResolver indexNameExpressionResolver; + + public ThreatIntelFeedDataService( + ClusterService clusterService, + Client client, + IndexNameExpressionResolver indexNameExpressionResolver, + NamedXContentRegistry xContentRegistry) { + this.client = client; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.xContentRegistry = xContentRegistry; + this.clusterService = clusterService; + this.clusterSettings = clusterService.getClusterSettings(); + } + + public void getThreatIntelFeedData( + ActionListener> listener + ) { + try { + + String tifdIndex = getLatestIndexByCreationDate(); + if (tifdIndex == null) { + createThreatIntelFeedData(); + tifdIndex = getLatestIndexByCreationDate(); + } + SearchRequest searchRequest = new SearchRequest(tifdIndex); + searchRequest.source().size(9999); //TODO: convert to scroll + String finalTifdIndex = tifdIndex; + client.search(searchRequest, ActionListener.wrap(r -> listener.onResponse(ThreatIntelFeedDataUtils.getTifdList(r, xContentRegistry)), e -> { + log.error(String.format( + "Failed to fetch threat intel feed data from system index %s", finalTifdIndex), e); + listener.onFailure(e); + })); + } catch (InterruptedException e) { + log.error("Failed to get threat intel feed data", e); + listener.onFailure(e); + } + } + + private String getLatestIndexByCreationDate() { + return IndexUtils.getNewIndexByCreationDate( + this.clusterService.state(), + this.indexNameExpressionResolver, + THREAT_INTEL_DATA_INDEX_NAME_PREFIX + "*" + ); + } + + /** + * Create an index for a threat intel feed + *

+ * Index setting start with single shard, zero replica, no refresh interval, and hidden. + * Once the threat intel feed is indexed, do refresh and force merge. + * Then, change the index setting to expand replica to all nodes, and read only allow delete. + * + * @param indexName index name + */ + public void createIndexIfNotExists(final String indexName) { + if (clusterService.state().metadata().hasIndex(indexName) == true) { + return; + } + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(INDEX_SETTING_TO_CREATE) + .mapping(getIndexMapping()); + StashedThreadContext.run( + client, + () -> client.admin().indices().create(createIndexRequest).actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)) + ); + } + + /** + * Puts threat intel feed from CSVRecord iterator into a given index in bulk + * + * @param indexName Index name to save the threat intel feed + * @param iterator TIF data to insert + * @param renewLock Runnable to renew lock + */ + public void parseAndSaveThreatIntelFeedDataCSV( + final String indexName, + final Iterator iterator, + final Runnable renewLock, + final TIFMetadata tifMetadata + ) throws IOException { + if (indexName == null || iterator == null || renewLock == null) { + throw new IllegalArgumentException("Parameters cannot be null, failed to save threat intel feed data"); + } + + TimeValue timeout = clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT); + Integer batchSize = clusterSettings.get(SecurityAnalyticsSettings.BATCH_SIZE); + final BulkRequest bulkRequest = new BulkRequest(); + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + List tifdList = new ArrayList<>(); + while (iterator.hasNext()) { + CSVRecord record = iterator.next(); + String iocType = tifMetadata.getIocType(); + Integer colNum = tifMetadata.getIocCol(); + String iocValue = record.values()[colNum].split(" ")[0]; + if (iocType.equals("ip") && !isValidIp(iocValue)) { + log.info("Invalid IP address, skipping this ioc record."); + continue; + } + String feedId = tifMetadata.getFeedId(); + Instant timestamp = Instant.now(); + ThreatIntelFeedData threatIntelFeedData = new ThreatIntelFeedData(iocType, iocValue, feedId, timestamp); + tifdList.add(threatIntelFeedData); + } + for (ThreatIntelFeedData tifd : tifdList) { + XContentBuilder tifData = tifd.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + IndexRequest indexRequest = new IndexRequest(indexName); + indexRequest.source(tifData); + indexRequest.opType(DocWriteRequest.OpType.INDEX); + bulkRequest.add(indexRequest); + + if (bulkRequest.requests().size() == batchSize) { + saveTifds(bulkRequest, timeout); + } + } + saveTifds(bulkRequest, timeout); + renewLock.run(); + } + + public static boolean isValidIp(String ip) { + String ipPattern = "^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$"; + Pattern pattern = Pattern.compile(ipPattern); + Matcher matcher = pattern.matcher(ip); + return matcher.matches(); + } + + public void saveTifds(BulkRequest bulkRequest, TimeValue timeout) { + try { + BulkResponse response = StashedThreadContext.run(client, () -> { + return client.bulk(bulkRequest).actionGet(timeout); + }); + if (response.hasFailures()) { + throw new OpenSearchException( + "error occurred while ingesting threat intel feed data in {} with an error {}", + StringUtils.join(bulkRequest.getIndices()), + response.buildFailureMessage() + ); + } + bulkRequest.requests().clear(); + } catch (OpenSearchException e) { + log.error("failed to save threat intel feed data", e); + } + + } + + public void deleteThreatIntelDataIndex(final List indices) { + if (indices == null || indices.isEmpty()) { + return; + } + + Optional invalidIndex = indices.stream() + .filter(index -> index.startsWith(THREAT_INTEL_DATA_INDEX_NAME_PREFIX) == false) + .findAny(); + if (invalidIndex.isPresent()) { + throw new OpenSearchException( + "the index[{}] is not threat intel data index which should start with {}", + invalidIndex.get(), + THREAT_INTEL_DATA_INDEX_NAME_PREFIX + ); + } + + AcknowledgedResponse response = StashedThreadContext.run( + client, + () -> client.admin() + .indices() + .prepareDelete(indices.toArray(new String[0])) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN) + .execute() + .actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)) + ); + + if (response.isAcknowledged() == false) { + throw new OpenSearchException("failed to delete data[{}]", String.join(",", indices)); + } + } + + private void createThreatIntelFeedData() throws InterruptedException { + CountDownLatch countDownLatch = new CountDownLatch(1); + client.execute( + PutTIFJobAction.INSTANCE, + new PutTIFJobRequest("feed_updater", clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)), + new ActionListener() { + @Override + public void onResponse(AcknowledgedResponse acknowledgedResponse) { + log.debug("Acknowledged threat intel feed updater job created"); + countDownLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + log.debug("Failed to create threat intel feed updater job", e); + countDownLatch.countDown(); + } + } + ); + countDownLatch.await(); + } + + private String getIndexMapping() { + try { + try (InputStream is = TIFJobParameterService.class.getResourceAsStream("/mappings/threat_intel_feed_mapping.json")) { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) { + return reader.lines().map(String::trim).collect(Collectors.joining()); + } + } + } catch (IOException e) { + log.error("Runtime exception when getting the threat intel index mapping", e); + throw new SecurityAnalyticsException("Runtime exception when getting the threat intel index mapping", RestStatus.INTERNAL_SERVER_ERROR, e); + } + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataUtils.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataUtils.java new file mode 100644 index 000000000..a96558b50 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataUtils.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class ThreatIntelFeedDataUtils { + + private static final Logger log = LogManager.getLogger(ThreatIntelFeedDataUtils.class); + + public static List getTifdList(SearchResponse searchResponse, NamedXContentRegistry xContentRegistry) { + List list = new ArrayList<>(); + if (searchResponse.getHits().getHits().length != 0) { + Arrays.stream(searchResponse.getHits().getHits()).forEach(hit -> { + try { + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() + ); + xcp.nextToken(); + list.add(ThreatIntelFeedData.parse(xcp, hit.getId(), hit.getVersion())); + } catch (Exception e) { + log.error(() -> new ParameterizedMessage( + "Failed to parse Threat intel feed data doc from hit {}", hit), + e + ); + } + + }); + } + return list; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedParser.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedParser.java new file mode 100644 index 000000000..92a66ed12 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedParser.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.SpecialPermission; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.securityanalytics.threatIntel.common.Constants; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URL; +import java.net.URLConnection; +import java.security.AccessController; +import java.security.PrivilegedAction; + +//Parser helper class +public class ThreatIntelFeedParser { + private static final Logger log = LogManager.getLogger(ThreatIntelFeedParser.class); + + /** + * Create CSVParser of a threat intel feed + * + * @param tifMetadata Threat intel feed metadata + * @return parser for threat intel feed + */ + @SuppressForbidden(reason = "Need to connect to http endpoint to read threat intel feed database file") + public static CSVParser getThreatIntelFeedReaderCSV(final TIFMetadata tifMetadata) { + SpecialPermission.check(); + return AccessController.doPrivileged((PrivilegedAction) () -> { + try { + URL url = new URL(tifMetadata.getUrl()); + URLConnection connection = url.openConnection(); + connection.addRequestProperty(Constants.USER_AGENT_KEY, Constants.USER_AGENT_VALUE); + return new CSVParser(new BufferedReader(new InputStreamReader(connection.getInputStream())), CSVFormat.RFC4180); + } catch (IOException e) { + log.error("Exception: failed to read threat intel feed data from {}",tifMetadata.getUrl(), e); + throw new OpenSearchException("failed to read threat intel feed data from {}", tifMetadata.getUrl(), e); + } + }); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobAction.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobAction.java new file mode 100644 index 000000000..01863f862 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.action; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; + +/** + * Threat intel tif job creation action + */ +public class PutTIFJobAction extends ActionType { + /** + * Put tif job action instance + */ + public static final PutTIFJobAction INSTANCE = new PutTIFJobAction(); + /** + * Put tif job action name + */ + public static final String NAME = "cluster:admin/security_analytics/tifjob/put"; + + private PutTIFJobAction() { + super(NAME, AcknowledgedResponse::new); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobRequest.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobRequest.java new file mode 100644 index 000000000..5f58e5529 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.action; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.securityanalytics.threatIntel.common.ParameterValidator; + +import java.io.IOException; +import java.util.List; + +/** + * Threat intel tif job creation request + */ +public class PutTIFJobRequest extends ActionRequest { + private static final ParameterValidator VALIDATOR = new ParameterValidator(); + + /** + * @param name the tif job name + * @return the tif job name + */ + private String name; + + /** + * @param updateInterval update interval of a tif job + * @return update interval of a tif job + */ + private TimeValue updateInterval; + + public void setName(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public TimeValue getUpdateInterval() { + return this.updateInterval; + } + + /** + * Default constructor + * @param name name of a tif job + */ + public PutTIFJobRequest(final String name, final TimeValue updateInterval) { + this.name = name; + this.updateInterval = updateInterval; + } + + /** + * Constructor with stream input + * @param in the stream input + * @throws IOException IOException + */ + public PutTIFJobRequest(final StreamInput in) throws IOException { + super(in); + this.name = in.readString(); + this.updateInterval = in.readTimeValue(); + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(name); + out.writeTimeValue(updateInterval); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException errors = new ActionRequestValidationException(); + List errorMsgs = VALIDATOR.validateTIFJobName(name); + if (errorMsgs.isEmpty() == false) { + errorMsgs.stream().forEach(msg -> errors.addValidationError(msg)); + } + return errors.validationErrors().isEmpty() ? null : errors; + } + +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java new file mode 100644 index 000000000..1346da40c --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.StepListener; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.common.TIFLockService; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameterService; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobUpdateService; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.time.Instant; +import java.util.ConcurrentModificationException; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.securityanalytics.threatIntel.common.TIFLockService.LOCK_DURATION_IN_SECONDS; + +/** + * Transport action to create job to fetch threat intel feed data and save IoCs + */ +public class TransportPutTIFJobAction extends HandledTransportAction { + // TODO refactor this into a service class that creates feed updation job. This is not necessary to be a transport action + private static final Logger log = LogManager.getLogger(TransportPutTIFJobAction.class); + + private final ThreadPool threadPool; + private final TIFJobParameterService tifJobParameterService; + private final TIFJobUpdateService tifJobUpdateService; + private final TIFLockService lockService; + + /** + * Default constructor + * @param transportService the transport service + * @param actionFilters the action filters + * @param threadPool the thread pool + * @param tifJobParameterService the tif job parameter service facade + * @param tifJobUpdateService the tif job update service + * @param lockService the lock service + */ + @Inject + public TransportPutTIFJobAction( + final TransportService transportService, + final ActionFilters actionFilters, + final ThreadPool threadPool, + final TIFJobParameterService tifJobParameterService, + final TIFJobUpdateService tifJobUpdateService, + final TIFLockService lockService + ) { + super(PutTIFJobAction.NAME, transportService, actionFilters, PutTIFJobRequest::new); + this.threadPool = threadPool; + this.tifJobParameterService = tifJobParameterService; + this.tifJobUpdateService = tifJobUpdateService; + this.lockService = lockService; + } + + @Override + protected void doExecute(final Task task, final PutTIFJobRequest request, final ActionListener listener) { + lockService.acquireLock(request.getName(), LOCK_DURATION_IN_SECONDS, ActionListener.wrap(lock -> { + if (lock == null) { + listener.onFailure( + new ConcurrentModificationException("another processor is holding a lock on the resource. Try again later") + ); + log.error("another processor is a lock, BAD_REQUEST error", RestStatus.BAD_REQUEST); + return; + } + try { + internalDoExecute(request, lock, listener); + } catch (Exception e) { + lockService.releaseLock(lock); + listener.onFailure(e); + log.error("listener failed when executing", e); + } + }, exception -> { + listener.onFailure(exception); + log.error("execution failed", exception); + })); + } + + /** + * This method takes lock as a parameter and is responsible for releasing lock + * unless exception is thrown + */ + protected void internalDoExecute( + final PutTIFJobRequest request, + final LockModel lock, + final ActionListener listener + ) { + StepListener createIndexStep = new StepListener<>(); + tifJobParameterService.createJobIndexIfNotExists(createIndexStep); + createIndexStep.whenComplete(v -> { + TIFJobParameter tifJobParameter = TIFJobParameter.Builder.build(request); + tifJobParameterService.saveTIFJobParameter(tifJobParameter, postIndexingTifJobParameter(tifJobParameter, lock, listener)); + }, exception -> { + lockService.releaseLock(lock); + log.error("failed to release lock", exception); + listener.onFailure(exception); + }); + } + + /** + * This method takes lock as a parameter and is responsible for releasing lock + * unless exception is thrown + */ + protected ActionListener postIndexingTifJobParameter( + final TIFJobParameter tifJobParameter, + final LockModel lock, + final ActionListener listener + ) { + return new ActionListener<>() { + @Override + public void onResponse(final IndexResponse indexResponse) { + AtomicReference lockReference = new AtomicReference<>(lock); + try { + createThreatIntelFeedData(tifJobParameter, lockService.getRenewLockRunnable(lockReference)); + } finally { + lockService.releaseLock(lockReference.get()); + } + listener.onResponse(new AcknowledgedResponse(true)); + } + + @Override + public void onFailure(final Exception e) { + lockService.releaseLock(lock); + if (e instanceof VersionConflictEngineException) { + log.error("tifJobParameter already exists"); + listener.onFailure(new ResourceAlreadyExistsException("tifJobParameter [{}] already exists", tifJobParameter.getName())); + } else { + log.error("Internal server error"); + listener.onFailure(e); + } + } + }; + } + + protected void createThreatIntelFeedData(final TIFJobParameter tifJobParameter, final Runnable renewLock) { + if (TIFJobState.CREATING.equals(tifJobParameter.getState()) == false) { + log.error("Invalid tifJobParameter state. Expecting {} but received {}", TIFJobState.CREATING, tifJobParameter.getState()); + markTIFJobAsCreateFailed(tifJobParameter); + return; + } + + try { + tifJobUpdateService.createThreatIntelFeedData(tifJobParameter, renewLock); + } catch (Exception e) { + log.error("Failed to create tifJobParameter for {}", tifJobParameter.getName(), e); + markTIFJobAsCreateFailed(tifJobParameter); + } + } + + private void markTIFJobAsCreateFailed(final TIFJobParameter tifJobParameter) { + tifJobParameter.getUpdateStats().setLastFailedAt(Instant.now()); + tifJobParameter.setState(TIFJobState.CREATE_FAILED); + try { + tifJobParameterService.updateJobSchedulerParameter(tifJobParameter); + } catch (Exception e) { + log.error("Failed to mark tifJobParameter state as CREATE_FAILED for {}", tifJobParameter.getName(), e); + } + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/Constants.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/Constants.java new file mode 100644 index 000000000..808c0a3da --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/Constants.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel.common; + +import org.opensearch.Version; + +import java.util.Locale; +public class Constants { + public static final String USER_AGENT_KEY = "User-Agent"; + public static final String USER_AGENT_VALUE = String.format(Locale.ROOT, "OpenSearch/%s vanilla", Version.CURRENT.toString()); +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/ParameterValidator.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/ParameterValidator.java new file mode 100644 index 000000000..4658557df --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/ParameterValidator.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.common; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.common.Strings; + +/** + * Parameter validator for TIF APIs + */ +public class ParameterValidator { + private static final int MAX_TIFJOB_NAME_BYTES = 127; + + /** + * Validate TIF Job name and return list of error messages + * + * @param tifJobName tifJobName name + * @return Error messages. Empty list if there is no violation. + */ + public List validateTIFJobName(final String tifJobName) { + List errorMsgs = new ArrayList<>(); + if (StringUtils.isBlank(tifJobName)) { + errorMsgs.add("threat intel feed job name must not be empty"); + return errorMsgs; + } + + if (!Strings.validFileName(tifJobName)) { + errorMsgs.add( + String.format(Locale.ROOT, "threat intel feed job name must not contain the following characters %s", Strings.INVALID_FILENAME_CHARS) + ); + } + if (tifJobName.contains("#") || tifJobName.contains(":") ) { + errorMsgs.add("threat intel feed job name must not contain '#'"); + } + if (tifJobName.charAt(0) == '_' || tifJobName.charAt(0) == '-' || tifJobName.charAt(0) == '+') { + errorMsgs.add("threat intel feed job name must not start with '_', '-', or '+'"); + } + int byteCount = tifJobName.getBytes(StandardCharsets.UTF_8).length; + if (byteCount > MAX_TIFJOB_NAME_BYTES) { + errorMsgs.add(String.format(Locale.ROOT, "threat intel feed job name is too long, (%d > %d)", byteCount, MAX_TIFJOB_NAME_BYTES)); + } + if (tifJobName.equals(".") || tifJobName.equals("..")) { + errorMsgs.add("threat intel feed job name must not be '.' or '..'"); + } + return errorMsgs; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFJobState.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFJobState.java new file mode 100644 index 000000000..22ffee3e9 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFJobState.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.common; + +/** + * Threat intel tif job state + * + * When tif job is created, it starts with CREATING state. Once the first threat intel feed is generated, the state changes to AVAILABLE. + * Only when the first threat intel feed generation failed, the state changes to CREATE_FAILED. + * Subsequent threat intel feed failure won't change tif job state from AVAILABLE to CREATE_FAILED. + * When delete request is received, the tif job state changes to DELETING. + * + * State changed from left to right for the entire lifecycle of a datasource + * (CREATING) to (CREATE_FAILED or AVAILABLE) to (DELETING) + * + */ +public enum TIFJobState { + /** + * tif job is being created + */ + CREATING, + /** + * tif job is ready to be used + */ + AVAILABLE, + /** + * tif job creation failed + */ + CREATE_FAILED, + /** + * tif job is being deleted + */ + DELETING +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java new file mode 100644 index 000000000..7ec4e94f3 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.common; + +import static org.opensearch.securityanalytics.SecurityAnalyticsPlugin.JOB_INDEX_NAME; + + +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import org.opensearch.OpenSearchException; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; + +/** + * A wrapper of job scheduler's lock service + */ +public class TIFLockService { + private static final Logger log = LogManager.getLogger(TIFLockService.class); + + public static final long LOCK_DURATION_IN_SECONDS = 300l; + public static final long RENEW_AFTER_IN_SECONDS = 120l; + private final ClusterService clusterService; + private final LockService lockService; + + + /** + * Constructor + * + * @param clusterService the cluster service + * @param client the client + */ + public TIFLockService(final ClusterService clusterService, final Client client) { + this.clusterService = clusterService; + this.lockService = new LockService(client, clusterService); + } + + /** + * Wrapper method of LockService#acquireLockWithId + * + * tif job uses its name as doc id in job scheduler. Therefore, we can use tif job name to acquire + * a lock on a tif job. + * + * @param tifJobName tifJobName to acquire lock on + * @param lockDurationSeconds the lock duration in seconds + * @param listener the listener + */ + public void acquireLock(final String tifJobName, final Long lockDurationSeconds, final ActionListener listener) { + lockService.acquireLockWithId(JOB_INDEX_NAME, lockDurationSeconds, tifJobName, listener); + } + + /** + * Synchronous method of #acquireLock + * + * @param tifJobName tifJobName to acquire lock on + * @param lockDurationSeconds the lock duration in seconds + * @return lock model + */ + public Optional acquireLock(final String tifJobName, final Long lockDurationSeconds) { + AtomicReference lockReference = new AtomicReference(); + CountDownLatch countDownLatch = new CountDownLatch(1); + lockService.acquireLockWithId(JOB_INDEX_NAME, lockDurationSeconds, tifJobName, new ActionListener<>() { + @Override + public void onResponse(final LockModel lockModel) { + lockReference.set(lockModel); + countDownLatch.countDown(); + } + + @Override + public void onFailure(final Exception e) { + lockReference.set(null); + countDownLatch.countDown(); + log.error("aquiring lock failed", e); + } + }); + + try { + countDownLatch.await(clusterService.getClusterSettings().get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT).getSeconds(), TimeUnit.SECONDS); + return Optional.ofNullable(lockReference.get()); + } catch (InterruptedException e) { + log.error("Waiting for the count down latch failed", e); + return Optional.empty(); + } + } + + /** + * Wrapper method of LockService#release + * + * @param lockModel the lock model + */ + public void releaseLock(final LockModel lockModel) { + lockService.release( + lockModel, + ActionListener.wrap(released -> {}, exception -> log.error("Failed to release the lock", exception)) + ); + } + + /** + * Synchronous method of LockService#renewLock + * + * @param lockModel lock to renew + * @return renewed lock if renew succeed and null otherwise + */ + public LockModel renewLock(final LockModel lockModel) { + AtomicReference lockReference = new AtomicReference(); + CountDownLatch countDownLatch = new CountDownLatch(1); + lockService.renewLock(lockModel, new ActionListener<>() { + @Override + public void onResponse(final LockModel lockModel) { + lockReference.set(lockModel); + countDownLatch.countDown(); + } + + @Override + public void onFailure(final Exception e) { + log.error("failed to renew lock", e); + lockReference.set(null); + countDownLatch.countDown(); + } + }); + + try { + countDownLatch.await(clusterService.getClusterSettings().get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT).getSeconds(), TimeUnit.SECONDS); + return lockReference.get(); + } catch (InterruptedException e) { + log.error("Interrupted exception", e); + return null; + } + } + + /** + * Return a runnable which can renew the given lock model + * + * The runnable renews the lock and store the renewed lock in the AtomicReference. + * It only renews the lock when it passed {@code RENEW_AFTER_IN_SECONDS} since + * the last time the lock was renewed to avoid resource abuse. + * + * @param lockModel lock model to renew + * @return runnable which can renew the given lock for every call + */ + public Runnable getRenewLockRunnable(final AtomicReference lockModel) { + return () -> { + LockModel preLock = lockModel.get(); + if (Instant.now().isBefore(preLock.getLockTime().plusSeconds(RENEW_AFTER_IN_SECONDS))) { + return; + } + lockModel.set(renewLock(lockModel.get())); + if (lockModel.get() == null) { + log.error("Exception: failed to renew a lock"); + new OpenSearchException("failed to renew a lock [{}]", preLock); + } + }; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFMetadata.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFMetadata.java new file mode 100644 index 000000000..04486fb7a --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFMetadata.java @@ -0,0 +1,220 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel.common; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.*; + +/** + * POJO containing Threat Intel Feed Metadata + * Contains all the data necessary to fetch and parse threat intel IoC feeds. + */ +public class TIFMetadata implements Writeable, ToXContent { + + private static final ParseField FEED_ID_FIELD = new ParseField("id"); + private static final ParseField URL_FIELD = new ParseField("url"); + private static final ParseField NAME_FIELD = new ParseField("name"); + private static final ParseField ORGANIZATION_FIELD = new ParseField("organization"); + private static final ParseField DESCRIPTION_FIELD = new ParseField("description"); + private static final ParseField FEED_FORMAT = new ParseField("feed_format"); + private static final ParseField IOC_TYPE_FIELD = new ParseField("ioc_type"); + private static final ParseField IOC_COL_FIELD = new ParseField("ioc_col"); + private static final ParseField HAS_HEADER_FIELD = new ParseField("has_header"); + + + /** + * @param feedId ID of the threat intel feed data + * @return ID of the threat intel feed data + */ + private String feedId; + + /** + * @param url URL of the threat intel feed data + * @return URL of the threat intel feed data + */ + private String url; + + /** + * @param name Name of the threat intel feed + * @return Name of the threat intel feed + */ + private String name; + + /** + * @param organization A threat intel feed organization name + * @return A threat intel feed organization name + */ + private String organization; + + /** + * @param description A description of the database + * @return A description of a database + */ + private String description; + + /** + * @param feedType The type of the data feed (csv, json...) + * @return The type of the data feed (csv, json...) + */ + private String feedType; + + /** + * @param iocCol the column of the ioc data if feedType is csv + * @return the column of the ioc data if feedType is csv + */ + private Integer iocCol; + + /** + * @param containedIocs ioc type in feed + * @return ioc type in feed + */ + private String iocType; + + /** + * @param hasHeader boolean if feed has a header + * @return boolean if feed has a header + */ + private Boolean hasHeader; + + public TIFMetadata(Map input) { + this( + input.get(FEED_ID_FIELD.getPreferredName()).toString(), + input.get(URL_FIELD.getPreferredName()).toString(), + input.get(NAME_FIELD.getPreferredName()).toString(), + input.get(ORGANIZATION_FIELD.getPreferredName()).toString(), + input.get(DESCRIPTION_FIELD.getPreferredName()).toString(), + input.get(FEED_FORMAT.getPreferredName()).toString(), + input.get(IOC_TYPE_FIELD.getPreferredName()).toString(), + Integer.parseInt(input.get(IOC_COL_FIELD.getPreferredName()).toString()), + (Boolean)input.get(HAS_HEADER_FIELD.getPreferredName()) + ); + } + + public String getUrl() { + return url; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public String getFeedId() { + return feedId; + } + + public String getFeedType() { + return feedType; + } + + public Integer getIocCol() { + return iocCol; + } + + public String getIocType() { + return iocType; + } + + public Boolean hasHeader() { + return hasHeader; + } + + + public TIFMetadata(final String feedId, final String url, final String name, final String organization, final String description, + final String feedType, final String iocType, final Integer iocCol, final Boolean hasHeader) { + this.feedId = feedId; + this.url = url; + this.name = name; + this.organization = organization; + this.description = description; + this.feedType = feedType; + this.iocType = iocType; + this.iocCol = iocCol; + this.hasHeader = hasHeader; + } + + + /** + * tif job metadata parser + */ + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tif_metadata", + true, + args -> { + String feedId = (String) args[0]; + String url = (String) args[1]; + String name = (String) args[2]; + String organization = (String) args[3]; + String description = (String) args[4]; + String feedType = (String) args[5]; + String containedIocs = (String) args[6]; + Integer iocCol = Integer.parseInt((String) args[7]); + Boolean hasHeader = (Boolean) args[8]; + return new TIFMetadata(feedId, url, name, organization, description, feedType, containedIocs, iocCol, hasHeader); + } + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), FEED_ID_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), URL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), ORGANIZATION_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), DESCRIPTION_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), FEED_FORMAT); + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), IOC_TYPE_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), IOC_COL_FIELD); + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), HAS_HEADER_FIELD); + } + + public TIFMetadata(final StreamInput in) throws IOException { + feedId = in.readString(); + url = in.readString(); + name = in.readString(); + organization = in.readString(); + description = in.readString(); + feedType = in.readString(); + iocType = in.readString(); + iocCol = in.readInt(); + hasHeader = in.readBoolean(); + } + + public void writeTo(final StreamOutput out) throws IOException { + out.writeString(feedId); + out.writeString(url); + out.writeString(name); + out.writeString(organization); + out.writeString(description); + out.writeString(feedType); + out.writeString(iocType); + out.writeInt(iocCol); + out.writeBoolean(hasHeader); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + builder.field(FEED_ID_FIELD.getPreferredName(), feedId); + builder.field(URL_FIELD.getPreferredName(), url); + builder.field(NAME_FIELD.getPreferredName(), name); + builder.field(ORGANIZATION_FIELD.getPreferredName(), organization); + builder.field(DESCRIPTION_FIELD.getPreferredName(), description); + builder.field(FEED_FORMAT.getPreferredName(), feedType); + builder.field(IOC_TYPE_FIELD.getPreferredName(), iocType); + builder.field(IOC_COL_FIELD.getPreferredName(), iocCol); + builder.field(HAS_HEADER_FIELD.getPreferredName(), hasHeader); + builder.endObject(); + return builder; + } + +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/feedMetadata/BuiltInTIFMetadataLoader.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/feedMetadata/BuiltInTIFMetadataLoader.java new file mode 100644 index 000000000..6b84e9fe9 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/feedMetadata/BuiltInTIFMetadataLoader.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel.feedMetadata; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.lifecycle.AbstractLifecycleComponent; +import org.opensearch.common.settings.SettingsException; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; +import org.opensearch.securityanalytics.util.FileUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class BuiltInTIFMetadataLoader extends AbstractLifecycleComponent { + + private static final Logger logger = LogManager.getLogger(BuiltInTIFMetadataLoader.class); + + private static final String BASE_PATH = "threatIntelFeed/"; + + + private List tifMetadataList = null; + private Map tifMetadataByName; + + public List getTifMetadataList() { + ensureTifMetadataLoaded(); + return tifMetadataList; + } + + public TIFMetadata getTifMetadataByName(String name) { + ensureTifMetadataLoaded(); + return tifMetadataByName.get(name); + } + + public boolean tifMetadataExists(String name) { + ensureTifMetadataLoaded(); + return tifMetadataByName.containsKey(name); + } + + public void ensureTifMetadataLoaded() { + try { + if (tifMetadataList != null) { + return; + } + loadBuiltInTifMetadata(); + tifMetadataByName = tifMetadataList.stream() + .collect(Collectors.toMap(TIFMetadata::getName, Function.identity())); + } catch (Exception e) { + logger.error("Failed loading builtin log types from disk!", e); + } + } + + @SuppressWarnings("unchecked") + protected void loadBuiltInTifMetadata() throws URISyntaxException, IOException { + final String url = Objects.requireNonNull(BuiltInTIFMetadataLoader.class.getClassLoader().getResource(BASE_PATH), + "Built-in threat intel feed metadata file not found").toURI().toString(); + Path dirPath = null; + if (url.contains("!")) { + final String[] paths = url.split("!"); + dirPath = FileUtils.getFs().getPath(paths[1]); + } else { + dirPath = Path.of(url); + } + + Stream folder = Files.list(dirPath); + Path tifMetadataPath = folder.filter(e -> e.toString().endsWith("feedMetadata.json")).collect(Collectors.toList()).get(0); + try ( + InputStream is = BuiltInTIFMetadataLoader.class.getResourceAsStream(tifMetadataPath.toString()) + ) { + String tifMetadataFilePayload = new String(Objects.requireNonNull(is).readAllBytes(), StandardCharsets.UTF_8); + + if (tifMetadataFilePayload != null) { + if(tifMetadataList == null) + tifMetadataList = new ArrayList<>(); + Map tifMetadataFileAsMap = + XContentHelper.convertToMap(JsonXContent.jsonXContent, tifMetadataFilePayload, false); + + for (Map.Entry mapEntry : tifMetadataFileAsMap.entrySet()) { + Map tifMetadataMap = (Map) mapEntry.getValue(); + tifMetadataList.add(new TIFMetadata(tifMetadataMap)); + } + } + } catch (Exception e) { + throw new SettingsException("Failed to load builtin threat intel feed metadata" + + "", e); + } + } + + @Override + protected void doStart() { + ensureTifMetadataLoaded(); + } + + @Override + protected void doStop() { + + } + + @Override + protected void doClose() throws IOException { + + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameter.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameter.java new file mode 100644 index 000000000..bcbb84c1c --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameter.java @@ -0,0 +1,569 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ConstructingObjectParser; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.securityanalytics.threatIntel.action.PutTIFJobRequest; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.common.TIFLockService; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import static org.opensearch.common.time.DateUtils.toInstant; + +public class TIFJobParameter implements Writeable, ScheduledJobParameter { + /** + * Prefix of indices having threatIntel data + */ + public static final String THREAT_INTEL_DATA_INDEX_NAME_PREFIX = ".opensearch-sap-threat-intel"; + + + /** + * String fields for job scheduling parameters used for ParseField + */ + private static final String NAME_FIELD = "name"; + private static final String ENABLED_FIELD = "update_enabled"; + private static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + private static final String LAST_UPDATE_TIME_FIELD_READABLE = "last_update_time_field"; + private static final String SCHEDULE_FIELD = "schedule"; + private static final String ENABLED_TIME_FIELD = "enabled_time"; + private static final String ENABLED_TIME_FIELD_READABLE = "enabled_time_field"; + private static final String state_field = "state"; + private static final String INDICES_FIELD = "indices"; + private static final String update_stats_field = "update_stats"; + + + /** + * Default fields for job scheduling + */ + public static final ParseField NAME_PARSER_FIELD = new ParseField(NAME_FIELD); + public static final ParseField ENABLED_PARSER_FIELD = new ParseField(ENABLED_FIELD); + public static final ParseField LAST_UPDATE_TIME_PARSER_FIELD = new ParseField(LAST_UPDATE_TIME_FIELD); + public static final ParseField LAST_UPDATE_TIME_FIELD_READABLE_PARSER_FIELD = new ParseField(LAST_UPDATE_TIME_FIELD_READABLE); + public static final ParseField SCHEDULE_PARSER_FIELD = new ParseField(SCHEDULE_FIELD); + public static final ParseField ENABLED_TIME_PARSER_FIELD = new ParseField(ENABLED_TIME_FIELD); + public static final ParseField ENABLED_TIME_FIELD_READABLE_PARSER_FIELD = new ParseField(ENABLED_TIME_FIELD_READABLE); + + /** + * Additional fields for tif job + */ + public static final ParseField STATE_PARSER_FIELD = new ParseField(state_field); + public static final ParseField INDICES_PARSER_FIELD = new ParseField(INDICES_FIELD); + public static final ParseField UPDATE_STATS_PARSER_FIELD = new ParseField(update_stats_field); + + /** + * Default variables for job scheduling + */ + + /** + * @param name name of a tif job + * @return name of a tif job + */ + private String name; + + /** + * @param lastUpdateTime Last update time of a tif job + * @return Last update time of a tif job + */ + private Instant lastUpdateTime; + /** + * @param enabledTime Last time when a scheduling is enabled for a threat intel feed data update + * @return Last time when a scheduling is enabled for the job scheduler + */ + private Instant enabledTime; + /** + * @param isEnabled Indicate if threat intel feed data update is scheduled or not + * @return Indicate if scheduling is enabled or not + */ + private boolean isEnabled; + /** + * @param schedule Schedule that system uses + * @return Schedule that system uses + */ + private IntervalSchedule schedule; + + + /** + * Additional variables for tif job + */ + + /** + * @param state State of a tif job + * @return State of a tif job + */ + private TIFJobState state; + + /** + * @param indices A list of indices having threat intel feed data + * @return A list of indices having threat intel feed data including + */ + private List indices; + + /** + * @param updateStats threat intel feed database update statistics + * @return threat intel feed database update statistics + */ + private UpdateStats updateStats; + + public static TIFJobParameter parse(XContentParser xcp, String id, Long version) throws IOException { + String name = null; + Instant lastUpdateTime = null; + Boolean isEnabled = null; + TIFJobState state = null; + + xcp.nextToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = xcp.text(); + break; + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli(xcp.longValue()); + break; + case ENABLED_FIELD: + isEnabled = xcp.booleanValue(); + break; + case state_field: + state = toState(xcp.text()); + break; + default: + xcp.skipChildren(); + } + } + return new TIFJobParameter(name, lastUpdateTime, isEnabled, state); + } + + public static TIFJobState toState(String stateName) { + if (stateName.equals("CREATING")) { + return TIFJobState.CREATING; + } + if (stateName.equals("AVAILABLE")) { + return TIFJobState.AVAILABLE; + } + if (stateName.equals("CREATE_FAILED")) { + return TIFJobState.CREATE_FAILED; + } + if (stateName.equals("DELETING")) { + return TIFJobState.DELETING; + } + return null; + } + + public TIFJobParameter(final String name, final Instant lastUpdateTime, final Boolean isEnabled, TIFJobState state) { + this.name = name; + this.lastUpdateTime = lastUpdateTime; + this.isEnabled = isEnabled; + this.state = state; + } + + /** + * tif job parser + */ + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tifjob_metadata", + true, + args -> { + String name = (String) args[0]; + Instant lastUpdateTime = Instant.ofEpochMilli((long) args[1]); + Instant enabledTime = args[2] == null ? null : Instant.ofEpochMilli((long) args[2]); + boolean isEnabled = (boolean) args[3]; + IntervalSchedule schedule = (IntervalSchedule) args[4]; + TIFJobState state = TIFJobState.valueOf((String) args[5]); + List indices = (List) args[6]; + UpdateStats updateStats = (UpdateStats) args[7]; + TIFJobParameter parameter = new TIFJobParameter( + name, + lastUpdateTime, + enabledTime, + isEnabled, + schedule, + state, + indices, + updateStats + ); + return parameter; + } + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME_PARSER_FIELD); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), LAST_UPDATE_TIME_PARSER_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ENABLED_TIME_PARSER_FIELD); + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED_PARSER_FIELD); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> ScheduleParser.parse(p), SCHEDULE_PARSER_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), STATE_PARSER_FIELD); + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDICES_PARSER_FIELD); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), UpdateStats.PARSER, UPDATE_STATS_PARSER_FIELD); + } + + public TIFJobParameter() { + this(null, null); + } + + public TIFJobParameter(final String name, final Instant lastUpdateTime, final Instant enabledTime, final Boolean isEnabled, + final IntervalSchedule schedule, final TIFJobState state, + final List indices, final UpdateStats updateStats) { + this.name = name; + this.lastUpdateTime = lastUpdateTime; + this.enabledTime = enabledTime; + this.isEnabled = isEnabled; + this.schedule = schedule; + this.state = state; + this.indices = indices; + this.updateStats = updateStats; + } + + public TIFJobParameter(final String name, final IntervalSchedule schedule) { + this( + name, + Instant.now().truncatedTo(ChronoUnit.MILLIS), + null, + false, + schedule, + TIFJobState.CREATING, + new ArrayList<>(), + new UpdateStats() + ); + } + + public TIFJobParameter(final StreamInput in) throws IOException { + name = in.readString(); + lastUpdateTime = toInstant(in.readVLong()); + enabledTime = toInstant(in.readOptionalVLong()); + isEnabled = in.readBoolean(); + schedule = new IntervalSchedule(in); + state = TIFJobState.valueOf(in.readString()); + indices = in.readStringList(); + updateStats = new UpdateStats(in); + } + + public void writeTo(final StreamOutput out) throws IOException { + out.writeString(name); + out.writeVLong(lastUpdateTime.toEpochMilli()); + out.writeOptionalVLong(enabledTime == null ? null : enabledTime.toEpochMilli()); + out.writeBoolean(isEnabled); + schedule.writeTo(out); + out.writeString(state.name()); + out.writeStringCollection(indices); + updateStats.writeTo(out); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + builder.field(NAME_PARSER_FIELD.getPreferredName(), name); + builder.timeField( + LAST_UPDATE_TIME_PARSER_FIELD.getPreferredName(), + LAST_UPDATE_TIME_FIELD_READABLE_PARSER_FIELD.getPreferredName(), + lastUpdateTime.toEpochMilli() + ); + if (enabledTime != null) { + builder.timeField( + ENABLED_TIME_PARSER_FIELD.getPreferredName(), + ENABLED_TIME_FIELD_READABLE_PARSER_FIELD.getPreferredName(), + enabledTime.toEpochMilli() + ); + } + builder.field(ENABLED_PARSER_FIELD.getPreferredName(), isEnabled); + builder.field(SCHEDULE_PARSER_FIELD.getPreferredName(), schedule); + builder.field(STATE_PARSER_FIELD.getPreferredName(), state.name()); + builder.field(INDICES_PARSER_FIELD.getPreferredName(), indices); + builder.field(UPDATE_STATS_PARSER_FIELD.getPreferredName(), updateStats); + builder.endObject(); + return builder; + } + + // getters and setters + public void setName(String name) { + this.name = name; + } + + public void setEnabledTime(Instant enabledTime) { + this.enabledTime = enabledTime; + } + + public void setEnabled(boolean enabled) { + isEnabled = enabled; + } + + public void setIndices(List indices) { + this.indices = indices; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public Instant getLastUpdateTime() { + return this.lastUpdateTime; + } + + @Override + public Instant getEnabledTime() { + return this.enabledTime; + } + + @Override + public IntervalSchedule getSchedule() { + return this.schedule; + } + + @Override + public boolean isEnabled() { + return this.isEnabled; + } + + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + + @Override + public Long getLockDurationSeconds() { + return TIFLockService.LOCK_DURATION_IN_SECONDS; + } + + /** + * Enable auto update of threat intel feed data + */ + public void enable() { + if (isEnabled == true) { + return; + } + enabledTime = Instant.now().truncatedTo(ChronoUnit.MILLIS); + isEnabled = true; + } + + /** + * Disable auto update of threat intel feed data + */ + public void disable() { + enabledTime = null; + isEnabled = false; + } + + public void setSchedule(IntervalSchedule schedule) { + this.schedule = schedule; + } + + /** + * Index name for a tif job + * + * @return index name for a tif job + */ + public String newIndexName(final TIFJobParameter jobSchedulerParameter, TIFMetadata tifMetadata) { + List indices = jobSchedulerParameter.getIndices(); + Optional nameOptional = indices.stream().filter(name -> name.contains(tifMetadata.getFeedId())).findAny(); + String suffix = "1"; + if (nameOptional.isPresent()) { + String lastChar = "" + nameOptional.get().charAt(nameOptional.get().length() - 1); + suffix = (lastChar.equals("1")) ? "2" : suffix; + } + return String.format(Locale.ROOT, "%s-%s%s", THREAT_INTEL_DATA_INDEX_NAME_PREFIX, tifMetadata.getFeedId(), suffix); + } + + public TIFJobState getState() { + return state; + } + + public List getIndices() { + return indices; + } + + public void setState(TIFJobState previousState) { + this.state = previousState; + } + + public UpdateStats getUpdateStats() { + return this.updateStats; + } + + + /** + * Update stats of a tif job + */ + public static class UpdateStats implements Writeable, ToXContent { + private static final ParseField LAST_SUCCEEDED_AT_FIELD = new ParseField("last_succeeded_at_in_epoch_millis"); + private static final ParseField LAST_SUCCEEDED_AT_FIELD_READABLE = new ParseField("last_succeeded_at"); + private static final ParseField LAST_PROCESSING_TIME_IN_MILLIS_FIELD = new ParseField("last_processing_time_in_millis"); + private static final ParseField LAST_FAILED_AT_FIELD = new ParseField("last_failed_at_in_epoch_millis"); + private static final ParseField LAST_FAILED_AT_FIELD_READABLE = new ParseField("last_failed_at"); + private static final ParseField LAST_SKIPPED_AT = new ParseField("last_skipped_at_in_epoch_millis"); + private static final ParseField LAST_SKIPPED_AT_READABLE = new ParseField("last_skipped_at"); + + public Instant getLastSucceededAt() { + return lastSucceededAt; + } + + public Long getLastProcessingTimeInMillis() { + return lastProcessingTimeInMillis; + } + + public Instant getLastFailedAt() { + return lastFailedAt; + } + + public Instant getLastSkippedAt() { + return lastSkippedAt; + } + + /** + * @param lastSucceededAt The last time when threat intel feed data update was succeeded + * @return The last time when threat intel feed data update was succeeded + */ + private Instant lastSucceededAt; + /** + * @param lastProcessingTimeInMillis The last processing time when threat intel feed data update was succeeded + * @return The last processing time when threat intel feed data update was succeeded + */ + private Long lastProcessingTimeInMillis; + /** + * @param lastFailedAt The last time when threat intel feed data update was failed + * @return The last time when threat intel feed data update was failed + */ + private Instant lastFailedAt; + + /** + * @param lastSkippedAt The last time when threat intel feed data update was skipped as there was no new update from an endpoint + * @return The last time when threat intel feed data update was skipped as there was no new update from an endpoint + */ + private Instant lastSkippedAt; + + private UpdateStats() { + } + + public void setLastSkippedAt(Instant lastSkippedAt) { + this.lastSkippedAt = lastSkippedAt; + } + + public void setLastSucceededAt(Instant lastSucceededAt) { + this.lastSucceededAt = lastSucceededAt; + } + + public void setLastProcessingTimeInMillis(Long lastProcessingTimeInMillis) { + this.lastProcessingTimeInMillis = lastProcessingTimeInMillis; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tifjob_metadata_update_stats", + true, + args -> { + Instant lastSucceededAt = args[0] == null ? null : Instant.ofEpochMilli((long) args[0]); + Long lastProcessingTimeInMillis = (Long) args[1]; + Instant lastFailedAt = args[2] == null ? null : Instant.ofEpochMilli((long) args[2]); + Instant lastSkippedAt = args[3] == null ? null : Instant.ofEpochMilli((long) args[3]); + return new UpdateStats(lastSucceededAt, lastProcessingTimeInMillis, lastFailedAt, lastSkippedAt); + } + ); + + static { + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_SUCCEEDED_AT_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_PROCESSING_TIME_IN_MILLIS_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_FAILED_AT_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_SKIPPED_AT); + } + + public UpdateStats(final StreamInput in) throws IOException { + lastSucceededAt = toInstant(in.readOptionalVLong()); + lastProcessingTimeInMillis = in.readOptionalVLong(); + lastFailedAt = toInstant(in.readOptionalVLong()); + lastSkippedAt = toInstant(in.readOptionalVLong()); + } + + public UpdateStats(Instant lastSucceededAt, Long lastProcessingTimeInMillis, Instant lastFailedAt, Instant lastSkippedAt) { + this.lastSucceededAt = lastSucceededAt; + this.lastProcessingTimeInMillis = lastProcessingTimeInMillis; + this.lastFailedAt = lastFailedAt; + this.lastSkippedAt = lastSkippedAt; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + out.writeOptionalVLong(lastSucceededAt == null ? null : lastSucceededAt.toEpochMilli()); + out.writeOptionalVLong(lastProcessingTimeInMillis); + out.writeOptionalVLong(lastFailedAt == null ? null : lastFailedAt.toEpochMilli()); + out.writeOptionalVLong(lastSkippedAt == null ? null : lastSkippedAt.toEpochMilli()); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + if (lastSucceededAt != null) { + builder.timeField( + LAST_SUCCEEDED_AT_FIELD.getPreferredName(), + LAST_SUCCEEDED_AT_FIELD_READABLE.getPreferredName(), + lastSucceededAt.toEpochMilli() + ); + } + if (lastProcessingTimeInMillis != null) { + builder.field(LAST_PROCESSING_TIME_IN_MILLIS_FIELD.getPreferredName(), lastProcessingTimeInMillis); + } + if (lastFailedAt != null) { + builder.timeField( + LAST_FAILED_AT_FIELD.getPreferredName(), + LAST_FAILED_AT_FIELD_READABLE.getPreferredName(), + lastFailedAt.toEpochMilli() + ); + } + if (lastSkippedAt != null) { + builder.timeField( + LAST_SKIPPED_AT.getPreferredName(), + LAST_SKIPPED_AT_READABLE.getPreferredName(), + lastSkippedAt.toEpochMilli() + ); + } + builder.endObject(); + return builder; + } + + public void setLastFailedAt(Instant now) { + this.lastFailedAt = now; + } + } + + /** + * Builder class for tif job + */ + public static class Builder { + public static TIFJobParameter build(final PutTIFJobRequest request) { + long minutes = request.getUpdateInterval().minutes(); + String name = request.getName(); + IntervalSchedule schedule = new IntervalSchedule( + Instant.now().truncatedTo(ChronoUnit.MILLIS), + (int) minutes, + ChronoUnit.MINUTES + ); + return new TIFJobParameter(name, schedule); + + } + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java new file mode 100644 index 000000000..640b3874b --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java @@ -0,0 +1,176 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.StepListener; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.common.StashedThreadContext; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.stream.Collectors; + +/** + * Data access object for tif job parameter + */ +public class TIFJobParameterService { + private static final Logger log = LogManager.getLogger(TIFJobParameterService.class); + private final Client client; + private final ClusterService clusterService; + private final ClusterSettings clusterSettings; + + public TIFJobParameterService(final Client client, final ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + this.clusterSettings = clusterService.getClusterSettings(); + } + + /** + * Create tif job index + * + * @param stepListener setup listener + */ + public void createJobIndexIfNotExists(final StepListener stepListener) { + if (clusterService.state().metadata().hasIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME) == true) { + stepListener.onResponse(null); + return; + } + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(SecurityAnalyticsPlugin.JOB_INDEX_NAME).mapping(getIndexMapping()) + .settings(SecurityAnalyticsPlugin.TIF_JOB_INDEX_SETTING); + StashedThreadContext.run(client, () -> client.admin().indices().create(createIndexRequest, new ActionListener<>() { + @Override + public void onResponse(final CreateIndexResponse createIndexResponse) { + stepListener.onResponse(null); + } + + @Override + public void onFailure(final Exception e) { + if (e instanceof ResourceAlreadyExistsException) { + log.info("index[{}] already exist", SecurityAnalyticsPlugin.JOB_INDEX_NAME); + stepListener.onResponse(null); + return; + } + stepListener.onFailure(e); + } + })); + } + + private String getIndexMapping() { + try { + try (InputStream is = TIFJobParameterService.class.getResourceAsStream("/mappings/threat_intel_job_mapping.json")) { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) { + return reader.lines().map(String::trim).collect(Collectors.joining()); + } + } + } catch (IOException e) { + log.error("Runtime exception", e); + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); + } + } + + /** + * Update jobSchedulerParameter in an index {@code TIFJobExtension.JOB_INDEX_NAME} + * @param jobSchedulerParameter the jobSchedulerParameter + * @return index response + */ + public IndexResponse updateJobSchedulerParameter(final TIFJobParameter jobSchedulerParameter) { + jobSchedulerParameter.setLastUpdateTime(Instant.now()); + return StashedThreadContext.run(client, () -> { + try { + return client.prepareIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME) + .setId(jobSchedulerParameter.getName()) + .setOpType(DocWriteRequest.OpType.INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(jobSchedulerParameter.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute() + .actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)); + } catch (IOException e) { + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + } + + /** + * Get tif job from an index {@code TIFJobExtension.JOB_INDEX_NAME} + * @param name the name of a tif job + * @return tif job + * @throws IOException exception + */ + public TIFJobParameter getJobParameter(final String name) throws IOException { + GetRequest request = new GetRequest(SecurityAnalyticsPlugin.JOB_INDEX_NAME, name); + GetResponse response; + try { + response = StashedThreadContext.run(client, () -> client.get(request).actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT))); + if (response.isExists() == false) { + log.error("TIF job[{}] does not exist in an index[{}]", name, SecurityAnalyticsPlugin.JOB_INDEX_NAME); + return null; + } + } catch (IndexNotFoundException e) { + log.error("Index[{}] is not found", SecurityAnalyticsPlugin.JOB_INDEX_NAME); + return null; + } + + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getSourceAsBytesRef() + ); + return TIFJobParameter.PARSER.parse(parser, null); + } + + /** + * Put tifJobParameter in an index {@code TIFJobExtension.JOB_INDEX_NAME} + * + * @param tifJobParameter the tifJobParameter + * @param listener the listener + */ + public void saveTIFJobParameter(final TIFJobParameter tifJobParameter, final ActionListener listener) { + tifJobParameter.setLastUpdateTime(Instant.now()); + StashedThreadContext.run(client, () -> { + try { + client.prepareIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME) + .setId(tifJobParameter.getName()) + .setOpType(DocWriteRequest.OpType.CREATE) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(tifJobParameter.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute(listener); + } catch (IOException e) { + throw new SecurityAnalyticsException("Exception saving the threat intel feed job parameter in index", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java new file mode 100644 index 000000000..e3500064f --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.time.Instant; + +import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.common.TIFLockService; +import org.opensearch.threadpool.ThreadPool; + +/** + * Job Parameter update task + * + * This is a background task which is responsible for updating threat intel feed data + */ +public class TIFJobRunner implements ScheduledJobRunner { + private static final Logger log = LogManager.getLogger(TIFJobRunner.class); + private static TIFJobRunner INSTANCE; + + public static TIFJobRunner getJobRunnerInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (TIFJobRunner.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new TIFJobRunner(); + return INSTANCE; + } + } + + private ClusterService clusterService; + + // threat intel specific variables + private TIFJobUpdateService jobSchedulerUpdateService; + private TIFJobParameterService jobSchedulerParameterService; + private TIFLockService lockService; + private boolean initialized; + private ThreadPool threadPool; + private DetectorThreatIntelService detectorThreatIntelService; + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + private TIFJobRunner() { + // Singleton class, use getJobRunner method instead of constructor + } + + public void initialize( + final ClusterService clusterService, + final TIFJobUpdateService jobSchedulerUpdateService, + final TIFJobParameterService jobSchedulerParameterService, + final TIFLockService threatIntelLockService, + final ThreadPool threadPool, + DetectorThreatIntelService detectorThreatIntelService + ) { + this.clusterService = clusterService; + this.jobSchedulerUpdateService = jobSchedulerUpdateService; + this.jobSchedulerParameterService = jobSchedulerParameterService; + this.lockService = threatIntelLockService; + this.threadPool = threadPool; + this.initialized = true; + this.detectorThreatIntelService = detectorThreatIntelService; + } + + @Override + public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionContext context) { + if (initialized == false) { + throw new AssertionError("This instance is not initialized"); + } + + log.info("Update job started for a job parameter[{}]", jobParameter.getName()); + if (jobParameter instanceof TIFJobParameter == false) { + log.error("Illegal state exception: job parameter is not instance of Job Scheduler Parameter"); + throw new IllegalStateException( + "job parameter is not instance of Job Scheduler Parameter, type: " + jobParameter.getClass().getCanonicalName() + ); + } + threadPool.generic().submit(updateJobRunner(jobParameter)); + } + + /** + * Update threat intel feed data + * + * Lock is used so that only one of nodes run this task. + * + * @param jobParameter job parameter + */ + protected Runnable updateJobRunner(final ScheduledJobParameter jobParameter) { + return () -> { + Optional lockModel = lockService.acquireLock( + jobParameter.getName(), + TIFLockService.LOCK_DURATION_IN_SECONDS + ); + if (lockModel.isEmpty()) { + log.error("Failed to update. Another processor is holding a lock for job parameter[{}]", jobParameter.getName()); + return; + } + + LockModel lock = lockModel.get(); + try { + updateJobParameter(jobParameter, lockService.getRenewLockRunnable(new AtomicReference<>(lock))); + } catch (Exception e) { + log.error("Failed to update job parameter[{}]", jobParameter.getName(), e); + } finally { + lockService.releaseLock(lock); + } + }; + } + + protected void updateJobParameter(final ScheduledJobParameter jobParameter, final Runnable renewLock) throws IOException { + TIFJobParameter jobSchedulerParameter = jobSchedulerParameterService.getJobParameter(jobParameter.getName()); + /** + * If delete request comes while update task is waiting on a queue for other update tasks to complete, + * because update task for this jobSchedulerParameter didn't acquire a lock yet, delete request is processed. + * When it is this jobSchedulerParameter's turn to run, it will find that the jobSchedulerParameter is deleted already. + * Therefore, we stop the update process when data source does not exist. + */ + if (jobSchedulerParameter == null) { + log.info("Job parameter[{}] does not exist", jobParameter.getName()); + return; + } + + if (TIFJobState.AVAILABLE.equals(jobSchedulerParameter.getState()) == false) { + log.error("Invalid jobSchedulerParameter state. Expecting {} but received {}", TIFJobState.AVAILABLE, jobSchedulerParameter.getState()); + jobSchedulerParameter.disable(); + jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter); + return; + } + try { + // create new TIF data and delete old ones + List oldIndices = new ArrayList<>(jobSchedulerParameter.getIndices()); + List newFeedIndices = jobSchedulerUpdateService.createThreatIntelFeedData(jobSchedulerParameter, renewLock); + jobSchedulerUpdateService.deleteAllTifdIndices(oldIndices, newFeedIndices); + if(false == newFeedIndices.isEmpty()) { + detectorThreatIntelService.updateDetectorsWithLatestThreatIntelRules(); + } + } catch (Exception e) { + log.error("Failed to update jobSchedulerParameter for {}", jobSchedulerParameter.getName(), e); + jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter); + } + } + +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobUpdateService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobUpdateService.java new file mode 100644 index 000000000..3006285ad --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobUpdateService.java @@ -0,0 +1,215 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.OpenSearchException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedDataService; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedParser; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; +import org.opensearch.securityanalytics.threatIntel.feedMetadata.BuiltInTIFMetadataLoader; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +public class TIFJobUpdateService { + private static final Logger log = LogManager.getLogger(TIFJobUpdateService.class); + + private static final int SLEEP_TIME_IN_MILLIS = 5000; // 5 seconds + private static final int MAX_WAIT_TIME_FOR_REPLICATION_TO_COMPLETE_IN_MILLIS = 10 * 60 * 60 * 1000; // 10 hours + private final ClusterService clusterService; + private final ClusterSettings clusterSettings; + private final TIFJobParameterService jobSchedulerParameterService; + private final ThreatIntelFeedDataService threatIntelFeedDataService; + private final BuiltInTIFMetadataLoader builtInTIFMetadataLoader; + + public TIFJobUpdateService( + final ClusterService clusterService, + final TIFJobParameterService jobSchedulerParameterService, + final ThreatIntelFeedDataService threatIntelFeedDataService, + BuiltInTIFMetadataLoader builtInTIFMetadataLoader) { + this.clusterService = clusterService; + this.clusterSettings = clusterService.getClusterSettings(); + this.jobSchedulerParameterService = jobSchedulerParameterService; + this.threatIntelFeedDataService = threatIntelFeedDataService; + this.builtInTIFMetadataLoader = builtInTIFMetadataLoader; + } + + // functions used in job Runner + + /** + * Delete old feed indices except the one which is being used + */ + public void deleteAllTifdIndices(List oldIndices, List newIndices) { + try { + oldIndices.removeAll(newIndices); + if (false == oldIndices.isEmpty()) { + deleteIndices(oldIndices); + } + } catch (Exception e) { + log.error( + () -> new ParameterizedMessage("Failed to delete old threat intel feed indices {}", StringUtils.join(oldIndices)), e + ); + } + } + + private List deleteIndices(final List indicesToDelete) { + List deletedIndices = new ArrayList<>(indicesToDelete.size()); + for (String index : indicesToDelete) { + if (false == clusterService.state().metadata().hasIndex(index)) { + deletedIndices.add(index); + } + } + indicesToDelete.removeAll(deletedIndices); + try { + threatIntelFeedDataService.deleteThreatIntelDataIndex(indicesToDelete); + } catch (Exception e) { + log.error( + () -> new ParameterizedMessage("Failed to delete old threat intel feed index [{}]", indicesToDelete), e + ); + } + return indicesToDelete; + } + + + /** + * Update threat intel feed data + *

+ * The first column is ip range field regardless its header name. + * Therefore, we don't store the first column's header name. + * + * @param jobSchedulerParameter the jobSchedulerParameter + * @param renewLock runnable to renew lock + * @throws IOException + */ + public List createThreatIntelFeedData(final TIFJobParameter jobSchedulerParameter, final Runnable renewLock) throws IOException { + Instant startTime = Instant.now(); + + List freshIndices = new ArrayList<>(); + for (TIFMetadata tifMetadata : builtInTIFMetadataLoader.getTifMetadataList()) { + String indexName = setupIndex(jobSchedulerParameter, tifMetadata); + + Boolean succeeded; + switch (tifMetadata.getFeedType()) { + case "csv": + try (CSVParser reader = ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(tifMetadata)) { + CSVParser noHeaderReader = ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(tifMetadata); + boolean notFound = true; + + while (notFound) { + CSVRecord hasHeaderRecord = reader.iterator().next(); + + //if we want to skip this line and keep iterating + if ((hasHeaderRecord.values().length ==1 && "".equals(hasHeaderRecord.values()[0])) || hasHeaderRecord.get(0).charAt(0) == '#' || hasHeaderRecord.get(0).charAt(0) == ' '){ + noHeaderReader.iterator().next(); + } else { // we found the first line that contains information + notFound = false; + } + } + if (tifMetadata.hasHeader()){ + threatIntelFeedDataService.parseAndSaveThreatIntelFeedDataCSV(indexName, reader.iterator(), renewLock, tifMetadata); + } else { + threatIntelFeedDataService.parseAndSaveThreatIntelFeedDataCSV(indexName, noHeaderReader.iterator(), renewLock, tifMetadata); + } + succeeded = true; + } + break; + default: + // if the feed type doesn't match any of the supporting feed types, throw an exception + succeeded = false; + } + waitUntilAllShardsStarted(indexName, MAX_WAIT_TIME_FOR_REPLICATION_TO_COMPLETE_IN_MILLIS); + + if (!succeeded) { + log.error("Exception: failed to parse correct feed type"); + throw new OpenSearchException("Exception: failed to parse correct feed type"); + } + freshIndices.add(indexName); + } + Instant endTime = Instant.now(); + updateJobSchedulerParameterAsSucceeded(freshIndices, jobSchedulerParameter, startTime, endTime); + return freshIndices; + } + + // helper functions + + /*** + * Update jobSchedulerParameter as succeeded + * + * @param jobSchedulerParameter the jobSchedulerParameter + */ + public void updateJobSchedulerParameterAsSucceeded( + List indices, + final TIFJobParameter jobSchedulerParameter, + final Instant startTime, + final Instant endTime + ) { + jobSchedulerParameter.setIndices(indices); + jobSchedulerParameter.getUpdateStats().setLastSucceededAt(endTime); + jobSchedulerParameter.getUpdateStats().setLastProcessingTimeInMillis(endTime.toEpochMilli() - startTime.toEpochMilli()); + jobSchedulerParameter.enable(); + jobSchedulerParameter.setState(TIFJobState.AVAILABLE); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter); + log.info( + "threat intel feed data creation succeeded for {} and took {} seconds", + jobSchedulerParameter.getName(), + Duration.between(startTime, endTime) + ); + } + + /*** + * Create index to add a new threat intel feed data + * + * @param jobSchedulerParameter the jobSchedulerParameter + * @param tifMetadata + * @return new index name + */ + private String setupIndex(final TIFJobParameter jobSchedulerParameter, TIFMetadata tifMetadata) { + String indexName = jobSchedulerParameter.newIndexName(jobSchedulerParameter, tifMetadata); + jobSchedulerParameter.getIndices().add(indexName); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter); + threatIntelFeedDataService.createIndexIfNotExists(indexName); + return indexName; + } + + /** + * We wait until all shards are ready to serve search requests before updating job scheduler parameter to + * point to a new index so that there won't be latency degradation during threat intel feed data update + * + * @param indexName the indexName + */ + protected void waitUntilAllShardsStarted(final String indexName, final int timeout) { + Instant start = Instant.now(); + try { + while (Instant.now().toEpochMilli() - start.toEpochMilli() < timeout) { + if (clusterService.state().routingTable().allShards(indexName).stream().allMatch(shard -> shard.started())) { + return; + } + Thread.sleep(SLEEP_TIME_IN_MILLIS); + } + throw new OpenSearchException( + "index[{}] replication did not complete after {} millis", + MAX_WAIT_TIME_FOR_REPLICATION_TO_COMPLETE_IN_MILLIS + ); + } catch (InterruptedException e) { + log.error("runtime exception", e); + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); + } + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/common/StashedThreadContext.java b/src/main/java/org/opensearch/securityanalytics/threatintel/common/StashedThreadContext.java new file mode 100644 index 000000000..32f4e6d40 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatintel/common/StashedThreadContext.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.common; + +import java.util.function.Supplier; + +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; + +/** + * Helper class to run code with stashed thread context + * + * Code need to be run with stashed thread context if it interacts with system index + * when security plugin is enabled. + */ +public class StashedThreadContext { + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing + */ + public static void run(final Client client, final Runnable function) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + function.run(); + } + } + + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function supplier function that needs to be executed after thread context has been stashed, return object + */ + public static T run(final Client client, final Supplier function) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + return function.get(); + } + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetMappingsViewAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetMappingsViewAction.java index 38c761261..327990b2d 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetMappingsViewAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetMappingsViewAction.java @@ -4,22 +4,15 @@ */ package org.opensearch.securityanalytics.transport; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.securityanalytics.action.GetIndexMappingsAction; -import org.opensearch.securityanalytics.action.GetIndexMappingsRequest; -import org.opensearch.securityanalytics.action.GetIndexMappingsResponse; +import org.opensearch.core.action.ActionListener; import org.opensearch.securityanalytics.action.GetMappingsViewAction; import org.opensearch.securityanalytics.action.GetMappingsViewRequest; import org.opensearch.securityanalytics.action.GetMappingsViewResponse; import org.opensearch.securityanalytics.mapper.MapperService; -import org.opensearch.securityanalytics.util.SecurityAnalyticsException; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index ae2afc1f3..480ed0152 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -88,6 +88,7 @@ import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.securityanalytics.model.LogType; import org.opensearch.securityanalytics.model.Rule; import org.opensearch.securityanalytics.model.Value; import org.opensearch.securityanalytics.rules.aggregation.AggregationItem; @@ -96,6 +97,7 @@ import org.opensearch.securityanalytics.rules.backend.QueryBackend; import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; import org.opensearch.securityanalytics.util.DetectorIndices; import org.opensearch.securityanalytics.util.DetectorUtils; import org.opensearch.securityanalytics.util.IndexUtils; @@ -108,6 +110,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -115,6 +118,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -155,6 +159,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction> rulesById, Detect List monitorRequests = new ArrayList<>(); - if (!docLevelRules.isEmpty()) { + if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST)); } @@ -318,7 +325,9 @@ private void createMonitorFromQueries(List> rulesById, Detect monitorResponses.add(addedFirstMonitorResponse); saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener); }, - listener::onFailure + e -> { + listener.onFailure(e); + } ); } } @@ -441,7 +450,7 @@ public void onResponse(Map> ruleFieldMappings) { Collectors.toList()); // Process doc level monitors - if (!docLevelRules.isEmpty()) { + if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { if (detector.getDocLevelMonitorId() == null) { monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST)); } else { @@ -469,7 +478,7 @@ public void onFailure(Exception e) { Collectors.toList()); // Process doc level monitors - if (!docLevelRules.isEmpty()) { + if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { if (detector.getDocLevelMonitorId() == null) { monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST)); } else { @@ -648,6 +657,7 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, Collections.emptyList(), actualQuery, tags); docLevelQueries.add(docLevelQuery); } + addThreatIntelBasedDocLevelQueries(detector, docLevelQueries); DocLevelMonitorInput docLevelMonitorInput = new DocLevelMonitorInput(detector.getName(), detector.getInputs().get(0).getIndices(), docLevelQueries); docLevelMonitorInputs.add(docLevelMonitorInput); @@ -678,6 +688,40 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null); } + private void addThreatIntelBasedDocLevelQueries(Detector detector, List docLevelQueries) { + try { + + if (detector.getThreatIntelEnabled()) { + log.debug("threat intel enabled for detector {} . adding threat intel based doc level queries.", detector.getName()); + List iocFieldsList = logTypeService.getIocFieldsList(detector.getDetectorType()); + if (iocFieldsList == null || iocFieldsList.isEmpty()) { + + } else { + CountDownLatch countDownLatch = new CountDownLatch(1); + detectorThreatIntelService.createDocLevelQueryFromThreatIntel(iocFieldsList, detector, new ActionListener<>() { + @Override + public void onResponse(List dlqs) { + if (dlqs != null) + docLevelQueries.addAll(dlqs); + countDownLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + // not failing detector creation if any fatal exception occurs during doc level query creation from threat intel feed data + log.error("Failed to convert threat intel feed to. Proceeding with detector creation", e); + countDownLatch.countDown(); + } + }); + countDownLatch.await(); + } + } + } catch (Exception e) { + // not failing detector creation if any fatal exception occurs during doc level query creation from threat intel feed data + log.error("Failed to convert threat intel feed to doc level query. Proceeding with detector creation", e); + } + } + /** * Creates doc level monitor which generates per document alerts for the findings of the bucket level delegate monitors in a workflow. * This monitor has match all query applied to generate the alerts per each finding doc. @@ -1409,6 +1453,7 @@ public void indexDetector() throws Exception { .source(request.getDetector().toXContentWithUser(XContentFactory.jsonBuilder(), new ToXContent.MapParams(Map.of("with_type", "true")))) .timeout(indexTimeout); } else { + request.getDetector().setLastUpdateTime(Instant.now()); indexRequest = new IndexRequest(Detector.DETECTORS_INDEX) .setRefreshPolicy(request.getRefreshPolicy()) .source(request.getDetector().toXContentWithUser(XContentFactory.jsonBuilder(), new ToXContent.MapParams(Map.of("with_type", "true")))) diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java index 53ef22a76..0643b34d7 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java @@ -21,12 +21,15 @@ import org.opensearch.securityanalytics.action.SearchDetectorAction; import org.opensearch.securityanalytics.action.SearchDetectorRequest; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.action.TransportPutTIFJobAction; import org.opensearch.securityanalytics.util.DetectorIndices; import org.opensearch.threadpool.ThreadPool; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.util.concurrent.CountDownLatch; + import static org.opensearch.securityanalytics.util.DetectorUtils.getEmptySearchResponse; public class TransportSearchDetectorAction extends HandledTransportAction implements SecureTransportAction { @@ -45,11 +48,13 @@ public class TransportSearchDetectorAction extends HandledTransportAction() { @Override public void onResponse(SearchResponse response) { - actionListener.onResponse(response); + actionListener.onResponse(response); } @Override diff --git a/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java b/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java index a2c2be7ee..65762c57f 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java +++ b/src/main/java/org/opensearch/securityanalytics/util/RuleIndices.java @@ -4,18 +4,11 @@ */ package org.opensearch.securityanalytics.util; -import java.util.Set; - -import com.google.common.collect.ImmutableMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; -import org.opensearch.cluster.routing.Preference; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; -import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -28,26 +21,23 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.health.ClusterIndexHealth; import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.cluster.routing.IndexRoutingTable; +import org.opensearch.cluster.routing.Preference; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequestBuilder; import org.opensearch.search.SearchHit; -import org.opensearch.core.rest.RestStatus; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.logtype.LogTypeService; -import org.opensearch.securityanalytics.mapper.MapperUtils; -import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.Rule; import org.opensearch.securityanalytics.rules.backend.OSQueryBackend; import org.opensearch.securityanalytics.rules.backend.QueryBackend; @@ -56,24 +46,14 @@ import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.Charset; -import java.nio.file.FileSystem; -import java.nio.file.FileSystems; import java.nio.file.Files; import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.opensearch.securityanalytics.model.Detector.NO_ID; import static org.opensearch.securityanalytics.model.Detector.NO_VERSION; public class RuleIndices { diff --git a/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension b/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension new file mode 100644 index 000000000..0ffeb24aa --- /dev/null +++ b/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension @@ -0,0 +1 @@ +org.opensearch.securityanalytics.SecurityAnalyticsPlugin \ No newline at end of file diff --git a/src/main/resources/OSMapping/ad_ldap_logtype.json b/src/main/resources/OSMapping/ad_ldap_logtype.json index e3434bca5..be2dd5488 100644 --- a/src/main/resources/OSMapping/ad_ldap_logtype.json +++ b/src/main/resources/OSMapping/ad_ldap_logtype.json @@ -2,7 +2,8 @@ "name": "ad_ldap", "description": "AD/LDAP", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"TargetUserName", "ecs":"azure.signinlogs.properties.user_id" diff --git a/src/main/resources/OSMapping/apache_access_logtype.json b/src/main/resources/OSMapping/apache_access_logtype.json index 7753c8440..714fa2acb 100644 --- a/src/main/resources/OSMapping/apache_access_logtype.json +++ b/src/main/resources/OSMapping/apache_access_logtype.json @@ -2,5 +2,6 @@ "name": "apache_access", "description": "Apache Access Log type", "is_builtin": true, - "mappings": [] + "ioc_fields" : [], + "mappings":[] } diff --git a/src/main/resources/OSMapping/azure_logtype.json b/src/main/resources/OSMapping/azure_logtype.json index ec9ae0502..bb55dbe5f 100644 --- a/src/main/resources/OSMapping/azure_logtype.json +++ b/src/main/resources/OSMapping/azure_logtype.json @@ -2,7 +2,8 @@ "name": "azure", "description": "Azure Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"Resultdescription", "ecs":"azure.signinlogs.result_description" diff --git a/src/main/resources/OSMapping/cloudtrail_logtype.json b/src/main/resources/OSMapping/cloudtrail_logtype.json index 389652373..8c2ea3b3a 100644 --- a/src/main/resources/OSMapping/cloudtrail_logtype.json +++ b/src/main/resources/OSMapping/cloudtrail_logtype.json @@ -2,7 +2,15 @@ "name": "cloudtrail", "description": "Cloudtrail Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields": [ + { + "ioc": "ip", + "fields": [ + "src_endpoint.ip" + ] + } + ], + "mappings":[ { "raw_field":"eventName", "ecs":"aws.cloudtrail.event_name", diff --git a/src/main/resources/OSMapping/dns_logtype.json b/src/main/resources/OSMapping/dns_logtype.json index ca2f5451a..ef012407f 100644 --- a/src/main/resources/OSMapping/dns_logtype.json +++ b/src/main/resources/OSMapping/dns_logtype.json @@ -2,7 +2,15 @@ "name": "dns", "description": "DNS Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields": [ + { + "ioc": "ip", + "fields": [ + "src_endpoint.ip" + ] + } + ], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type", diff --git a/src/main/resources/OSMapping/github_logtype.json b/src/main/resources/OSMapping/github_logtype.json index 6369e2949..31ec6ee59 100644 --- a/src/main/resources/OSMapping/github_logtype.json +++ b/src/main/resources/OSMapping/github_logtype.json @@ -2,7 +2,8 @@ "name": "github", "description": "Github Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"action", "ecs":"github.action" diff --git a/src/main/resources/OSMapping/gworkspace_logtype.json b/src/main/resources/OSMapping/gworkspace_logtype.json index b0006b6a3..7c5766895 100644 --- a/src/main/resources/OSMapping/gworkspace_logtype.json +++ b/src/main/resources/OSMapping/gworkspace_logtype.json @@ -2,7 +2,8 @@ "name": "gworkspace", "description": "GWorkspace Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"eventSource", "ecs":"google_workspace.admin.service.name" diff --git a/src/main/resources/OSMapping/linux_logtype.json b/src/main/resources/OSMapping/linux_logtype.json index f719913c0..5b77de6b3 100644 --- a/src/main/resources/OSMapping/linux_logtype.json +++ b/src/main/resources/OSMapping/linux_logtype.json @@ -2,7 +2,8 @@ "name": "linux", "description": "Linux Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"name", "ecs":"user.filesystem.name" diff --git a/src/main/resources/OSMapping/m365_logtype.json b/src/main/resources/OSMapping/m365_logtype.json index 6547d3d63..e19c2418e 100644 --- a/src/main/resources/OSMapping/m365_logtype.json +++ b/src/main/resources/OSMapping/m365_logtype.json @@ -2,7 +2,8 @@ "name": "m365", "description": "Microsoft 365 Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"eventSource", "ecs":"rsa.misc.event_source" diff --git a/src/main/resources/OSMapping/netflow_logtype.json b/src/main/resources/OSMapping/netflow_logtype.json index d8ec32632..9dc015198 100644 --- a/src/main/resources/OSMapping/netflow_logtype.json +++ b/src/main/resources/OSMapping/netflow_logtype.json @@ -2,7 +2,16 @@ "name": "netflow", "description": "Netflow Log Type used only in Integration Tests", "is_builtin": true, - "mappings": [ + "ioc_fields": [ + { + "ioc": "ip", + "fields": [ + "destination.ip", + "source.ip" + ] + } + ], + "mappings":[ { "raw_field":"netflow.source_ipv4_address", "ecs":"source.ip" diff --git a/src/main/resources/OSMapping/network_logtype.json b/src/main/resources/OSMapping/network_logtype.json index 90f0b2ee6..2ca92a1ad 100644 --- a/src/main/resources/OSMapping/network_logtype.json +++ b/src/main/resources/OSMapping/network_logtype.json @@ -2,7 +2,16 @@ "name": "network", "description": "Network Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields": [ + { + "ioc": "ip", + "fields": [ + "destination.ip", + "source.ip" + ] + } + ], + "mappings":[ { "raw_field":"action", "ecs":"netflow.firewall_event" diff --git a/src/main/resources/OSMapping/okta_logtype.json b/src/main/resources/OSMapping/okta_logtype.json index 8038b7f01..e73a0c273 100644 --- a/src/main/resources/OSMapping/okta_logtype.json +++ b/src/main/resources/OSMapping/okta_logtype.json @@ -2,7 +2,8 @@ "name": "okta", "description": "Okta Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"eventtype", "ecs":"okta.event_type" diff --git a/src/main/resources/OSMapping/others_application_logtype.json b/src/main/resources/OSMapping/others_application_logtype.json index d7faf8c94..4008602d4 100644 --- a/src/main/resources/OSMapping/others_application_logtype.json +++ b/src/main/resources/OSMapping/others_application_logtype.json @@ -2,7 +2,8 @@ "name": "others_application", "description": "others_application", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/others_apt_logtype.json b/src/main/resources/OSMapping/others_apt_logtype.json index ace55cbc3..1a4ca711f 100644 --- a/src/main/resources/OSMapping/others_apt_logtype.json +++ b/src/main/resources/OSMapping/others_apt_logtype.json @@ -2,7 +2,8 @@ "name": "others_apt", "description": "others_apt", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/others_cloud_logtype.json b/src/main/resources/OSMapping/others_cloud_logtype.json index b5da3e005..64cbc7935 100644 --- a/src/main/resources/OSMapping/others_cloud_logtype.json +++ b/src/main/resources/OSMapping/others_cloud_logtype.json @@ -2,7 +2,8 @@ "name": "others_cloud", "description": "others_cloud", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/others_compliance_logtype.json b/src/main/resources/OSMapping/others_compliance_logtype.json index 6f362d589..6e065795a 100644 --- a/src/main/resources/OSMapping/others_compliance_logtype.json +++ b/src/main/resources/OSMapping/others_compliance_logtype.json @@ -2,7 +2,8 @@ "name": "others_compliance", "description": "others_compliance", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/others_macos_logtype.json b/src/main/resources/OSMapping/others_macos_logtype.json index 50d1c2160..6b6452100 100644 --- a/src/main/resources/OSMapping/others_macos_logtype.json +++ b/src/main/resources/OSMapping/others_macos_logtype.json @@ -2,7 +2,8 @@ "name": "others_macos", "description": "others_macos", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/others_proxy_logtype.json b/src/main/resources/OSMapping/others_proxy_logtype.json index aca4529d1..a2b0794a4 100644 --- a/src/main/resources/OSMapping/others_proxy_logtype.json +++ b/src/main/resources/OSMapping/others_proxy_logtype.json @@ -2,7 +2,8 @@ "name": "others_proxy", "description": "others_proxy", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/others_web_logtype.json b/src/main/resources/OSMapping/others_web_logtype.json index ae8262d52..b46adc6a4 100644 --- a/src/main/resources/OSMapping/others_web_logtype.json +++ b/src/main/resources/OSMapping/others_web_logtype.json @@ -2,7 +2,8 @@ "name": "others_web", "description": "others_web", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"record_type", "ecs":"dns.answers.type" diff --git a/src/main/resources/OSMapping/s3_logtype.json b/src/main/resources/OSMapping/s3_logtype.json index 58c546258..20c896df6 100644 --- a/src/main/resources/OSMapping/s3_logtype.json +++ b/src/main/resources/OSMapping/s3_logtype.json @@ -2,7 +2,8 @@ "name": "s3", "description": "S3 Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"eventName", "ecs":"aws.cloudtrail.event_name" diff --git a/src/main/resources/OSMapping/test_windows_logtype.json b/src/main/resources/OSMapping/test_windows_logtype.json index 7491a954c..cc619c5a1 100644 --- a/src/main/resources/OSMapping/test_windows_logtype.json +++ b/src/main/resources/OSMapping/test_windows_logtype.json @@ -2,6 +2,12 @@ "name": "test_windows", "description": "Test Log Type used by tests. It is created as a lightweight log type for integration tests", "is_builtin": true, + "ioc_fields": [ + { + "ioc": "ip", + "fields": ["HostName"] + } + ], "mappings": [ { "raw_field":"EventID", diff --git a/src/main/resources/OSMapping/vpcflow_logtype.json b/src/main/resources/OSMapping/vpcflow_logtype.json index c55305b6d..29d9f38c2 100644 --- a/src/main/resources/OSMapping/vpcflow_logtype.json +++ b/src/main/resources/OSMapping/vpcflow_logtype.json @@ -2,7 +2,16 @@ "name": "vpcflow", "description": "VPC Flow Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields": [ + { + "ioc": "ip", + "fields": [ + "dst_endpoint.ip", + "src_endpoint.ip" + ] + } + ], + "mappings":[ { "raw_field":"version", "ecs":"netflow.version", diff --git a/src/main/resources/OSMapping/waf_logtype.json b/src/main/resources/OSMapping/waf_logtype.json index 5eed2c2fb..3e5b1f4f1 100644 --- a/src/main/resources/OSMapping/waf_logtype.json +++ b/src/main/resources/OSMapping/waf_logtype.json @@ -2,7 +2,8 @@ "name": "waf", "description": "Web Application Firewall Log Type", "is_builtin": true, - "mappings": [ + "ioc_fields" : [], + "mappings":[ { "raw_field":"cs-method", "ecs":"waf.request.method" diff --git a/src/main/resources/OSMapping/windows_logtype.json b/src/main/resources/OSMapping/windows_logtype.json index a5fef8ea7..ec9b3ed1a 100644 --- a/src/main/resources/OSMapping/windows_logtype.json +++ b/src/main/resources/OSMapping/windows_logtype.json @@ -2,7 +2,13 @@ "name": "windows", "description": "Windows Log Type", "is_builtin": true, - "mappings":[ + "ioc_fields" : [ + { + "ioc": "ip", + "fields": ["destination.ip","source.ip"] + } + ], + "mappings": [ { "raw_field":"AccountName", "ecs":"winlog.computerObject.name" diff --git a/src/main/resources/mappings/detectors.json b/src/main/resources/mappings/detectors.json index e1e160d5f..c4a42d53a 100644 --- a/src/main/resources/mappings/detectors.json +++ b/src/main/resources/mappings/detectors.json @@ -62,6 +62,9 @@ "enabled": { "type": "boolean" }, + "threat_intel_enabled": { + "type": "boolean" + }, "enabled_time": { "type": "date", "format": "strict_date_time||epoch_millis" diff --git a/src/main/resources/mappings/threat_intel_feed_mapping.json b/src/main/resources/mappings/threat_intel_feed_mapping.json new file mode 100644 index 000000000..2e775cf8e --- /dev/null +++ b/src/main/resources/mappings/threat_intel_feed_mapping.json @@ -0,0 +1,27 @@ +{ + "dynamic": "strict", + "_meta" : { + "schema_version": 1 + }, + "properties": { + "schema_version": { + "type": "integer" + }, + "ioc_type": { + "type": "keyword" + }, + "ioc_value": { + "type": "keyword" + }, + "feed_id": { + "type": "keyword" + }, + "timestamp": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "type": { + "type": "keyword" + } + } +} diff --git a/src/main/resources/mappings/threat_intel_job_mapping.json b/src/main/resources/mappings/threat_intel_job_mapping.json new file mode 100644 index 000000000..ffd165ae5 --- /dev/null +++ b/src/main/resources/mappings/threat_intel_job_mapping.json @@ -0,0 +1,62 @@ +{ + "dynamic": "strict", + "_meta" : { + "schema_version": 1 + }, + "properties": { + "schema_version": { + "type": "integer" + }, + "enabled_time": { + "type": "long" + }, + "indices": { + "type": "text" + }, + "last_update_time": { + "type": "long" + }, + "name": { + "type": "text" + }, + "schedule": { + "properties": { + "interval": { + "properties": { + "period": { + "type": "long" + }, + "start_time": { + "type": "long" + }, + "unit": { + "type": "text" + } + } + } + } + }, + "state": { + "type": "text" + }, + "update_enabled": { + "type": "boolean" + }, + "update_stats": { + "properties": { + "last_failed_at_in_epoch_millis": { + "type": "long" + }, + "last_processing_time_in_millis": { + "type": "long" + }, + "last_skipped_at_in_epoch_millis": { + "type": "long" + }, + "last_succeeded_at_in_epoch_millis": { + "type": "long" + } + } + } + } +} diff --git a/src/main/resources/threatIntelFeed/feedMetadata.json b/src/main/resources/threatIntelFeed/feedMetadata.json new file mode 100644 index 000000000..e0f448012 --- /dev/null +++ b/src/main/resources/threatIntelFeed/feedMetadata.json @@ -0,0 +1,13 @@ +{ + "alienvault_reputation_ip_database": { + "id": "alienvault_reputation_ip_database", + "url": "https://reputation.alienvault.com/reputation.generic", + "name": "Alienvault IP Reputation", + "organization": "Alienvault", + "description": "Alienvault IP Reputation threat intelligence feed managed by AlienVault", + "feed_format": "csv", + "ioc_type": "ip", + "ioc_col": 0, + "has_header": false + } +} diff --git a/src/main/resources/threatIntelFeedInfo/feodo.yml b/src/main/resources/threatIntelFeedInfo/feodo.yml new file mode 100644 index 000000000..8205e47ca --- /dev/null +++ b/src/main/resources/threatIntelFeedInfo/feodo.yml @@ -0,0 +1,6 @@ +url: "https://feodotracker.abuse.ch/downloads/ipblocklist_aggressive.csv" +name: "ipblocklist_aggressive.csv" +feedFormat: "csv" +org: "Feodo" +iocTypes: ["ip"] +description: "" diff --git a/src/test/java/org/opensearch/securityanalytics/DetectorThreatIntelIT.java b/src/test/java/org/opensearch/securityanalytics/DetectorThreatIntelIT.java new file mode 100644 index 000000000..9d83b3ed3 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/DetectorThreatIntelIT.java @@ -0,0 +1,714 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics; + +import org.apache.hc.core5.http.HttpStatus; +import org.junit.Assert; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.DetectorRule; +import org.opensearch.securityanalytics.model.DetectorTrigger; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.emptyList; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntelAndTriggers; +import static org.opensearch.securityanalytics.TestHelpers.randomDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomDocWithIpIoc; +import static org.opensearch.securityanalytics.TestHelpers.randomIndex; +import static org.opensearch.securityanalytics.TestHelpers.randomRule; +import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; + +public class DetectorThreatIntelIT extends SecurityAnalyticsRestTestCase { + + public void testCreateDetectorWithThreatIntelEnabled_updateDetectorWithThreatIntelDisabled() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + DetectorTrigger trigger = new DetectorTrigger("all", "all", "high", List.of(randomDetectorType()), emptyList(), emptyList(), List.of(), emptyList(), List.of(DetectorTrigger.RULES_DETECTION_TYPE, DetectorTrigger.THREAT_INTEL_DETECTION_TYPE)); + Detector detector = randomDetectorWithInputsAndThreatIntelAndTriggers(List.of(input), true, List.of(trigger)); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(2, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + List iocs = getThreatIntelFeedIocs(3); + int i = 1; + for (String ioc : iocs) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, ioc)); + i++; + } + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(2, noOfSigmaRuleMatches); + String threatIntelDocLevelQueryId = docLevelQueryResults.keySet().stream().filter(id -> id.startsWith("threat_intel")).findAny().get(); + ArrayList docs = (ArrayList) docLevelQueryResults.get(threatIntelDocLevelQueryId); + assertEquals(docs.size(), 3); + //verify alerts + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + + Assert.assertEquals(3, getAlertsBody.get("total_alerts")); + + // update detector + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(randomDetectorWithInputsAndThreatIntel(List.of(input), false))); + + assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); + + Map updateResponseBody = asMap(updateResponse); + for (String ioc : iocs) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, ioc)); + i++; + } + + executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + } + + public void testCreateDetectorWithThreatIntelDisabled_updateDetectorWithThreatIntelEnabled() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + Detector detector = randomDetectorWithInputsAndThreatIntel(List.of(input), false); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + indexDoc(index, "1", randomDoc(2, 4, "test")); + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(randomDetectorWithInputsAndThreatIntel(List.of(input), true))); + + assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); + + Map updateResponseBody = asMap(updateResponse); + List iocs = getThreatIntelFeedIocs(3); + int i = 2; + for (String ioc : iocs) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, ioc)); + i++; + } + executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(2, noOfSigmaRuleMatches); + } + + public void testCreateDetectorWithThreatIntelEnabledAndNoRules_triggerDetectionTypeOnlyRules_noAlertsForFindings() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + + List detectorRules = emptyList(); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + DetectorTrigger trigger = new DetectorTrigger("all", "all", "high", List.of(randomDetectorType()), emptyList(), emptyList(), List.of(), emptyList(), List.of(DetectorTrigger.RULES_DETECTION_TYPE)); + Detector detector = randomDetectorWithInputsAndThreatIntelAndTriggers(List.of(input), true, List.of(trigger)); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + List iocs = getThreatIntelFeedIocs(3); + int i = 1; + for (String ioc : iocs) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, ioc)); + i++; + } + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + String threatIntelDocLevelQueryId = docLevelQueryResults.keySet().stream().filter(id -> id.startsWith("threat_intel")).findAny().get(); + ArrayList docs = (ArrayList) docLevelQueryResults.get(threatIntelDocLevelQueryId); + assertEquals(docs.size(), 3); + //verify alerts + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + /** findings are present but alerts should not be generated as detection type mentioned in trigger is rules only */ + Assert.assertEquals(0, getAlertsBody.get("total_alerts")); + } + + public void testCreateDetectorWithThreatIntelEnabled_triggerDetectionTypeOnlyThreatIntel_allAlertsForFindings() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + + List detectorRules = emptyList(); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + DetectorTrigger trigger = new DetectorTrigger("all", "all", "high", + List.of(randomDetectorType()), emptyList(), emptyList(), List.of(), emptyList(), List.of(DetectorTrigger.THREAT_INTEL_DETECTION_TYPE)); + Detector detector = randomDetectorWithInputsAndThreatIntelAndTriggers(List.of(input), true, List.of(trigger)); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + List iocs = getThreatIntelFeedIocs(3); + int i = 1; + for (String ioc : iocs) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, ioc)); + i++; + } + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + String threatIntelDocLevelQueryId = docLevelQueryResults.keySet().stream().filter(id -> id.startsWith("threat_intel")).findAny().get(); + ArrayList docs = (ArrayList) docLevelQueryResults.get(threatIntelDocLevelQueryId); + assertEquals(docs.size(), 3); + //verify alerts + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + /** findings are present and alerts are generated as detection type mentioned in trigger is threat_intel only */ + Assert.assertEquals(3, getAlertsBody.get("total_alerts")); + } + + public void testCreateDetectorWithThreatIntelEnabled_triggerWithBothDetectionType_allAlertsForFindings() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + + List detectorRules = emptyList(); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + DetectorTrigger trigger = new DetectorTrigger("all", "all", "high", + List.of(randomDetectorType()), emptyList(), emptyList(), List.of(), emptyList(), + List.of(DetectorTrigger.THREAT_INTEL_DETECTION_TYPE, DetectorTrigger.RULES_DETECTION_TYPE)); + Detector detector = randomDetectorWithInputsAndThreatIntelAndTriggers(List.of(input), true, List.of(trigger)); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + List iocs = getThreatIntelFeedIocs(3); + int i = 1; + for (String ioc : iocs) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, ioc)); + i++; + } + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + String threatIntelDocLevelQueryId = docLevelQueryResults.keySet().stream().filter(id -> id.startsWith("threat_intel")).findAny().get(); + ArrayList docs = (ArrayList) docLevelQueryResults.get(threatIntelDocLevelQueryId); + assertEquals(docs.size(), 3); + //verify alerts + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + /** findings are present and alerts are generated as both detection type mentioned in trigger is threat_intel only */ + Assert.assertEquals(3, getAlertsBody.get("total_alerts")); + } + + public void testCreateDetectorWithThreatIntelDisabled_triggerWithThreatIntelDetectionType_mpAlertsForFindings() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + DetectorTrigger trigger = new DetectorTrigger("all", "all", "high", + List.of(randomDetectorType()), emptyList(), emptyList(), List.of(), emptyList(), + List.of(DetectorTrigger.THREAT_INTEL_DETECTION_TYPE)); + Detector detector = randomDetectorWithInputsAndThreatIntelAndTriggers(List.of(input), false, List.of(trigger)); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + + int i = 1; + while (i<4) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, i+"")); + i++; + } + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + String ruleQueryId = docLevelQueryResults.keySet().stream().findAny().get(); + ArrayList docs = (ArrayList) docLevelQueryResults.get(ruleQueryId); + assertEquals(docs.size(), 3); + //verify alerts + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + /** findings are present but alerts are NOT generated as detection type mentioned in trigger is threat_intel only but finding is from rules*/ + Assert.assertEquals(0, getAlertsBody.get("total_alerts")); + } + + public void testCreateDetectorWithThreatIntelDisabled_triggerWithRulesDetectionType_allAlertsForFindings() throws IOException { + + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + DetectorTrigger trigger = new DetectorTrigger("all", "all", "high", + List.of(randomDetectorType()), emptyList(), emptyList(), List.of(), emptyList(), + List.of(DetectorTrigger.RULES_DETECTION_TYPE)); + Detector detector = randomDetectorWithInputsAndThreatIntelAndTriggers(List.of(input), false, List.of(trigger)); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + + int i = 1; + while (i<4) { + indexDoc(index, i + "", randomDocWithIpIoc(5, 3, i+"")); + i++; + } + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + Map docLevelQueryResults = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0); + int noOfSigmaRuleMatches = docLevelQueryResults.size(); + assertEquals(1, noOfSigmaRuleMatches); + String ruleQueryId = docLevelQueryResults.keySet().stream().findAny().get(); + ArrayList docs = (ArrayList) docLevelQueryResults.get(ruleQueryId); + assertEquals(docs.size(), 3); + //verify alerts + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + /** findings are present but alerts are NOT generated as detection type mentioned in trigger is threat_intel only but finding is from rules*/ + Assert.assertEquals(3, getAlertsBody.get("total_alerts")); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/LogTypeServiceTests.java b/src/test/java/org/opensearch/securityanalytics/LogTypeServiceTests.java index 8eb717e60..64288f669 100644 --- a/src/test/java/org/opensearch/securityanalytics/LogTypeServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/LogTypeServiceTests.java @@ -50,7 +50,8 @@ protected void beforeTest() throws Exception { new LogType.Mapping("rawFld1", "ecsFld1", "ocsfFld1"), new LogType.Mapping("rawFld2", "ecsFld2", "ocsfFld2"), new LogType.Mapping("rawFld3", "ecsFld3", "ocsfFld3") - ) + ), + List.of(new LogType.IocFields("ip", List.of("dst.ip"))) ) ); when(builtinLogTypeLoader.getAllLogTypes()).thenReturn(dummyLogTypes); diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index 2178f06d6..1c8770677 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -64,6 +64,7 @@ import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.Rule; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.test.rest.OpenSearchRestTestCase; @@ -91,6 +92,7 @@ import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.FINDING_HISTORY_MAX_DOCS; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.FINDING_HISTORY_RETENTION_PERIOD; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.FINDING_HISTORY_ROLLOVER_PERIOD; +import static org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedDataUtils.getTifdList; import static org.opensearch.securityanalytics.util.RuleTopicIndices.ruleTopicIndexSettings; public class SecurityAnalyticsRestTestCase extends OpenSearchRestTestCase { @@ -682,6 +684,11 @@ protected String toJsonString(CorrelationRule rule) throws IOException { return IndexUtilsKt.string(shuffleXContent(rule.toXContent(builder, ToXContent.EMPTY_PARAMS))); } + protected String toJsonString(ThreatIntelFeedData tifd) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + return IndexUtilsKt.string(shuffleXContent(tifd.toXContent(builder, ToXContent.EMPTY_PARAMS))); + } + private String alertingScheduledJobMappings() { return " \"_meta\" : {\n" + " \"schema_version\": 5\n" + @@ -1736,4 +1743,20 @@ protected void enableOrDisableWorkflow(String trueOrFalse) throws IOException { request.setJsonEntity(entity); client().performRequest(request); } + + public List getThreatIntelFeedIocs(int num) throws IOException { + String request = getMatchAllSearchRequestString(num); + SearchResponse res = executeSearchAndGetResponse(".opensearch-sap-threat-intel*", request, false); + return getTifdList(res, xContentRegistry()).stream().map(it -> it.getIocValue()).collect(Collectors.toList()); + } + + public String getMatchAllSearchRequestString(int num) { + return "{\n" + + "\"size\" : " + num + "," + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index dde7efbb5..477a7ecee 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -28,6 +28,7 @@ import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.OpenSearchRestTestCase; @@ -53,53 +54,61 @@ static class AccessRoles { public static Detector randomDetector(List rules) { DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), Collections.emptyList(), rules.stream().map(DetectorRule::new).collect(Collectors.toList())); - return randomDetector(null, null, null, List.of(input), List.of(), null, null, null, null); + return randomDetector(null, null, null, List.of(input), List.of(), null, null, null, null, false); } public static Detector randomDetector(List rules, String detectorType) { DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), Collections.emptyList(), rules.stream().map(DetectorRule::new).collect(Collectors.toList())); - return randomDetector(null, detectorType, null, List.of(input), List.of(), null, null, null, null); + return randomDetector(null, detectorType, null, List.of(input), List.of(), null, null, null, null, false); } public static Detector randomDetectorWithInputs(List inputs) { - return randomDetector(null, null, null, inputs, List.of(), null, null, null, null); + return randomDetector(null, null, null, inputs, List.of(), null, null, null, null, false); + } + + public static Detector randomDetectorWithInputsAndThreatIntel(List inputs, Boolean threatIntel) { + return randomDetector(null, null, null, inputs, List.of(), null, null, null, null, threatIntel); + } + + public static Detector randomDetectorWithInputsAndThreatIntelAndTriggers(List inputs, Boolean threatIntel, List triggers) { + return randomDetector(null, null, null, inputs, triggers, null, null, null, null, threatIntel); } public static Detector randomDetectorWithInputsAndTriggers(List inputs, List triggers) { - return randomDetector(null, null, null, inputs, triggers, null, null, null, null); + return randomDetector(null, null, null, inputs, triggers, null, null, null, null, false); } public static Detector randomDetectorWithInputs(List inputs, String detectorType) { - return randomDetector(null, detectorType, null, inputs, List.of(), null, null, null, null); + return randomDetector(null, detectorType, null, inputs, List.of(), null, null, null, null, false); } public static Detector randomDetectorWithTriggers(List triggers) { - return randomDetector(null, null, null, List.of(), triggers, null, null, null, null); + return randomDetector(null, null, null, List.of(), triggers, null, null, null, null, false); } public static Detector randomDetectorWithTriggers(List rules, List triggers) { DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), Collections.emptyList(), rules.stream().map(DetectorRule::new).collect(Collectors.toList())); - return randomDetector(null, null, null, List.of(input), triggers, null, null, null, null); + return randomDetector(null, null, null, List.of(input), triggers, null, null, null, null, false); } public static Detector randomDetectorWithTriggers(List rules, List triggers, List inputIndices) { DetectorInput input = new DetectorInput("windows detector for security analytics", inputIndices, Collections.emptyList(), rules.stream().map(DetectorRule::new).collect(Collectors.toList())); - return randomDetector(null, null, null, List.of(input), triggers, null, null, null, null); + return randomDetector(null, null, null, List.of(input), triggers, null, true, null, null, false); } public static Detector randomDetectorWithTriggersAndScheduleAndEnabled(List rules, List triggers, Schedule schedule, boolean enabled) { DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), Collections.emptyList(), rules.stream().map(DetectorRule::new).collect(Collectors.toList())); - return randomDetector(null, null, null, List.of(input), triggers, schedule, enabled, null, null); + return randomDetector(null, null, null, List.of(input), triggers, schedule, enabled, null, null, false); } public static Detector randomDetectorWithTriggers(List rules, List triggers, String detectorType, DetectorInput input) { - return randomDetector(null, detectorType, null, List.of(input), triggers, null, null, null, null); + return randomDetector(null, detectorType, null, List.of(input), triggers, null, null, null, null, false); } public static Detector randomDetectorWithInputsAndTriggersAndType(List inputs, List triggers, String detectorType) { - return randomDetector(null, detectorType, null, inputs, triggers, null, null, null, null); + return randomDetector(null, detectorType, null, inputs, triggers, null, null, null, null, false); } public static Detector randomDetector(String name, @@ -110,7 +119,8 @@ public static Detector randomDetector(String name, Schedule schedule, Boolean enabled, Instant enabledTime, - Instant lastUpdateTime) { + Instant lastUpdateTime, + Boolean threatIntel) { if (name == null) { name = OpenSearchRestTestCase.randomAlphaOfLength(10); } @@ -146,10 +156,10 @@ public static Detector randomDetector(String name, if (triggers.size() == 0) { triggers = new ArrayList<>(); - DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of(randomDetectorType()), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of()); + DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of(randomDetectorType()), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of(), List.of()); triggers.add(trigger); } - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList()); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList(), threatIntel); } public static CustomLogType randomCustomLogType(String name, String description, String category, String source) { @@ -168,6 +178,15 @@ public static CustomLogType randomCustomLogType(String name, String description, return new CustomLogType(null, null, name, description, category, source, null); } + public static ThreatIntelFeedData randomThreatIntelFeedData() { + return new ThreatIntelFeedData( + "IP_ADDRESS", + "ip", + "alientVault", + Instant.now() + ); + } + public static Detector randomDetectorWithNoUser() { String name = OpenSearchRestTestCase.randomAlphaOfLength(10); String detectorType = randomDetectorType(); @@ -197,7 +216,8 @@ public static Detector randomDetectorWithNoUser() { "", "", Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); } @@ -429,6 +449,12 @@ public static String toJsonStringWithUser(Detector detector) throws IOException return BytesReference.bytes(builder).utf8ToString(); } + public static String toJsonString(ThreatIntelFeedData threatIntelFeedData) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = threatIntelFeedData.toXContent(builder, ToXContent.EMPTY_PARAMS); + return BytesReference.bytes(builder).utf8ToString(); + } + public static User randomUser() { return new User( OpenSearchRestTestCase.randomAlphaOfLength(10), @@ -1351,6 +1377,46 @@ public static String randomDoc(int severity, int version, String opCode) { } + //Add IPs in HostName field. + public static String randomDocWithIpIoc(int severity, int version, String ioc) { + String doc = "{\n" + + "\"EventTime\":\"2020-02-04T14:59:39.343541+00:00\",\n" + + "\"HostName\":\"%s\",\n" + + "\"Keywords\":\"9223372036854775808\",\n" + + "\"SeverityValue\":%s,\n" + + "\"Severity\":\"INFO\",\n" + + "\"EventID\":22,\n" + + "\"SourceName\":\"Microsoft-Windows-Sysmon\",\n" + + "\"ProviderGuid\":\"{5770385F-C22A-43E0-BF4C-06F5698FFBD9}\",\n" + + "\"Version\":%s,\n" + + "\"TaskValue\":22,\n" + + "\"OpcodeValue\":0,\n" + + "\"RecordNumber\":9532,\n" + + "\"ExecutionProcessID\":1996,\n" + + "\"ExecutionThreadID\":2616,\n" + + "\"Channel\":\"Microsoft-Windows-Sysmon/Operational\",\n" + + "\"Domain\":\"NT AUTHORITY\",\n" + + "\"AccountName\":\"SYSTEM\",\n" + + "\"UserID\":\"S-1-5-18\",\n" + + "\"AccountType\":\"User\",\n" + + "\"Message\":\"Dns query:\\r\\nRuleName: \\r\\nUtcTime: 2020-02-04 14:59:38.349\\r\\nProcessGuid: {b3c285a4-3cda-5dc0-0000-001077270b00}\\r\\nProcessId: 1904\\r\\nQueryName: EC2AMAZ-EPO7HKA\\r\\nQueryStatus: 0\\r\\nQueryResults: 172.31.46.38;\\r\\nImage: C:\\\\Program Files\\\\nxlog\\\\nxlog.exe\",\n" + + "\"Category\":\"Dns query (rule: DnsQuery)\",\n" + + "\"Opcode\":\"blahblah\",\n" + + "\"UtcTime\":\"2020-02-04 14:59:38.349\",\n" + + "\"ProcessGuid\":\"{b3c285a4-3cda-5dc0-0000-001077270b00}\",\n" + + "\"ProcessId\":\"1904\",\"QueryName\":\"EC2AMAZ-EPO7HKA\",\"QueryStatus\":\"0\",\n" + + "\"QueryResults\":\"172.31.46.38;\",\n" + + "\"Image\":\"C:\\\\Program Files\\\\nxlog\\\\regsvr32.exe\",\n" + + "\"EventReceivedTime\":\"2020-02-04T14:59:40.780905+00:00\",\n" + + "\"SourceModuleName\":\"in\",\n" + + "\"SourceModuleType\":\"im_msvistalog\",\n" + + "\"CommandLine\": \"eachtest\",\n" + + "\"Initiated\": \"true\"\n" + + "}"; + return String.format(Locale.ROOT, doc, ioc, severity, version); + + } + public static String randomDoc() { return "{\n" + "\"@timestamp\":\"2020-02-04T14:59:39.343541+00:00\",\n" + @@ -1501,6 +1567,20 @@ public static String vpcFlowMappings() { " }"; } + private static String randomString() { + return OpenSearchTestCase.randomAlphaOfLengthBetween(2, 16); + } + + public static String randomLowerCaseString() { + return randomString().toLowerCase(Locale.ROOT); + } + + public static List randomLowerCaseStringList() { + List stringList = new ArrayList<>(); + stringList.add(randomLowerCaseString()); + return stringList; + } + public static XContentParser parser(String xc) throws IOException { XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, xc); parser.nextToken(); @@ -1511,7 +1591,8 @@ public static NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry( List.of( Detector.XCONTENT_REGISTRY, - DetectorInput.XCONTENT_REGISTRY + DetectorInput.XCONTENT_REGISTRY, + ThreatIntelFeedData.XCONTENT_REGISTRY ) ); } diff --git a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java index db366056b..ca98a1144 100644 --- a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java +++ b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java @@ -50,7 +50,8 @@ public void testIndexDetectorPostResponse() throws IOException { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); IndexDetectorResponse response = new IndexDetectorResponse("1234", 1L, RestStatus.OK, detector); Assert.assertNotNull(response); @@ -69,5 +70,6 @@ public void testIndexDetectorPostResponse() throws IOException { Assert.assertTrue(newResponse.getDetector().getMonitorIds().contains("1")); Assert.assertTrue(newResponse.getDetector().getMonitorIds().contains("2")); Assert.assertTrue(newResponse.getDetector().getMonitorIds().contains("3")); + Assert.assertFalse(newResponse.getDetector().getThreatIntelEnabled()); } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java index 78dacd6e1..d250d2eef 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java @@ -65,7 +65,8 @@ public void testGetAlerts_success() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -242,7 +243,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index d3665dcfc..04f17d867 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -19,6 +19,7 @@ import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Assert; import org.junit.Ignore; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -37,9 +38,11 @@ import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.randomAction; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomDocWithIpIoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; import static org.opensearch.securityanalytics.TestHelpers.randomRule; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; @@ -47,6 +50,7 @@ import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_MAX_DOCS; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_RETENTION_PERIOD; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_ROLLOVER_PERIOD; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; public class AlertsIT extends SecurityAnalyticsRestTestCase { @@ -82,7 +86,7 @@ public void testGetAlerts_success() throws IOException { Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction)))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -200,13 +204,13 @@ public void testAckAlerts_WithInvalidDetectorAlertsCombination() throws IOExcept Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction)))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Detector detector1 = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction)))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); Response createResponse1 = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector1)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -307,7 +311,7 @@ public void testAckAlertsWithInvalidDetector() throws IOException { Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction)))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -415,7 +419,7 @@ public void testGetAlerts_byDetectorType_success() throws IOException, Interrupt Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -494,7 +498,7 @@ public void testGetAlerts_byDetectorType_multipleDetectors_success() throws IOEx Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); // Detector 1 - WINDOWS - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector1)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -517,7 +521,7 @@ public void testGetAlerts_byDetectorType_multipleDetectors_success() throws IOEx getPrePackagedRules("network").stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector2 = randomDetectorWithTriggers( getPrePackagedRules("network"), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of(), List.of())), "network", inputNetflow ); @@ -610,7 +614,7 @@ public void testAlertHistoryRollover_maxAge() throws IOException, InterruptedExc Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -681,7 +685,7 @@ public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -761,7 +765,7 @@ public void testAlertHistoryRollover_maxDocs() throws IOException, InterruptedEx Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -845,7 +849,7 @@ public void testGetAlertsFromAllIndices() throws IOException, InterruptedExcepti Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/SecureAlertsRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/SecureAlertsRestApiIT.java index 4ecf3287f..20e526697 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/SecureAlertsRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/SecureAlertsRestApiIT.java @@ -97,8 +97,8 @@ public void testGetAlerts_byDetectorId_success() throws IOException { Action triggerAction = randomAction(createDestination()); Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), - getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction)))); + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of(DetectorTrigger.RULES_DETECTION_TYPE, DetectorTrigger.THREAT_INTEL_DETECTION_TYPE)))); createResponse = makeRequest(userClient, "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -235,7 +235,7 @@ public void testGetAlerts_byDetectorType_success() throws IOException, Interrupt Response response = userClient.performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(userClient, "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java index e721e1124..225cebb8c 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java @@ -220,7 +220,7 @@ private String createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOE private String createVpcFlowDetector(String indexName) throws IOException { Detector vpcFlowDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("vpc flow detector for security analytics", List.of(indexName), List.of(), getPrePackagedRules("network").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of())), "network"); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of(), List.of())), "network"); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(vpcFlowDetector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -280,7 +280,7 @@ private String createAdLdapDetector(String indexName) throws IOException { Detector adLdapDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("ad_ldap logs detector for security analytics", List.of(indexName), List.of(), getPrePackagedRules("ad_ldap").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("ad_ldap"), List.of(), List.of(), List.of(), List.of())), "ad_ldap"); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("ad_ldap"), List.of(), List.of(), List.of(), List.of(), List.of())), "ad_ldap"); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(adLdapDetector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -319,7 +319,7 @@ private String createTestWindowsDetector(String indexName) throws IOException { Detector windowsDetector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of(indexName), List.of(), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(windowsDetector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -345,7 +345,7 @@ private String createTestWindowsDetector(String indexName) throws IOException { private String createAppLogsDetector(String indexName) throws IOException { Detector appLogsDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("app logs detector for security analytics", List.of(indexName), List.of(), getPrePackagedRules("others_application").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("others_application"), List.of(), List.of(), List.of(), List.of())), "others_application"); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("others_application"), List.of(), List.of(), List.of(), List.of(), List.of())), "others_application"); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(appLogsDetector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -397,7 +397,7 @@ private String createS3Detector(String indexName) throws IOException { Detector s3AccessLogsDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("s3 access logs detector for security analytics", List.of(indexName), List.of(), getPrePackagedRules("s3").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("s3"), List.of(), List.of(), List.of(), List.of())), "s3"); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("s3"), List.of(), List.of(), List.of(), List.of(), List.of())), "s3"); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(s3AccessLogsDetector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java index ce6634f41..c69bb2e00 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java @@ -57,7 +57,7 @@ public void testGetFindings_byDetectorId_success() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -119,7 +119,7 @@ public void testGetFindings_byDetectorType_oneDetector_success() throws IOExcept Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -187,7 +187,7 @@ public void testGetFindings_byDetectorType_success() throws IOException { response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); // Detector 1 - WINDOWS - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector1)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -210,7 +210,7 @@ public void testGetFindings_byDetectorType_success() throws IOException { getPrePackagedRules("network").stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector2 = randomDetectorWithTriggers( getPrePackagedRules("network"), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of(), List.of())), "network", inputNetflow ); @@ -286,7 +286,7 @@ public void testGetFindings_rolloverByMaxAge_success() throws IOException, Inter Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -357,7 +357,7 @@ public void testGetFindings_rolloverByMaxDoc_success() throws IOException, Inter Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -422,7 +422,7 @@ public void testGetFindings_rolloverByMaxDoc_short_retention_success() throws IO Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index 5c28ba65b..6551f579c 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -5,6 +5,12 @@ package org.opensearch.securityanalytics.findings; +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URL; +import java.net.URLConnection; import java.time.Instant; import java.time.ZoneId; import java.util.ArrayDeque; @@ -65,7 +71,8 @@ public void testGetFindings_success() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -186,7 +193,8 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/SecureFindingRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/findings/SecureFindingRestApiIT.java index 64d5b7cef..ab68eabe7 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/SecureFindingRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/SecureFindingRestApiIT.java @@ -87,7 +87,7 @@ public void testGetFindings_byDetectorId_success() throws IOException { Response response = userClient.performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(userClient, "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -206,7 +206,7 @@ public void testGetFindings_byDetectorType_success() throws IOException { createUserRolesMapping(TEST_HR_ROLE, users); // Detector 1 - WINDOWS - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(userClient, "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector1)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -229,7 +229,7 @@ public void testGetFindings_byDetectorType_success() throws IOException { getPrePackagedRules("network").stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector2 = randomDetectorWithTriggers( getPrePackagedRules("network"), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of(), List.of())), "network", inputNetflow ); diff --git a/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java index 315997a47..6e63f4296 100644 --- a/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java @@ -11,6 +11,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,6 +38,7 @@ import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; import org.opensearch.securityanalytics.TestHelpers; +import org.opensearch.securityanalytics.action.GetMappingsViewResponse; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; @@ -353,6 +355,8 @@ public void testGetMappingsViewSuccess() throws IOException { // Verify unmapped field aliases List unmappedFieldAliases = (List) respMap.get("unmapped_field_aliases"); assertEquals(3, unmappedFieldAliases.size()); + List> iocFieldsList = (List>) respMap.get(GetMappingsViewResponse.THREAT_INTEL_FIELD_ALIASES); + assertEquals(iocFieldsList.size(), 1); } public void testGetMappingsViewLinuxSuccess() throws IOException { diff --git a/src/test/java/org/opensearch/securityanalytics/model/WriteableTests.java b/src/test/java/org/opensearch/securityanalytics/model/WriteableTests.java index e82911c1b..f12535b98 100644 --- a/src/test/java/org/opensearch/securityanalytics/model/WriteableTests.java +++ b/src/test/java/org/opensearch/securityanalytics/model/WriteableTests.java @@ -19,7 +19,7 @@ public class WriteableTests extends OpenSearchTestCase { - public void testDetectorAsStream() throws IOException { + public void testDetectorAsAStream() throws IOException { Detector detector = randomDetector(List.of()); detector.setInputs(List.of(new DetectorInput("", List.of(), List.of(), List.of()))); BytesStreamOutput out = new BytesStreamOutput(); @@ -50,7 +50,8 @@ public void testEmptyUserAsStream() throws IOException { public void testLogTypeAsStreamRawFieldOnly() throws IOException { LogType logType = new LogType( "1", "my_log_type", "description", false, - List.of(new LogType.Mapping("rawField", null, null)) + List.of(new LogType.Mapping("rawField", null, null)), + List.of(new LogType.IocFields("ip", List.of("dst.ip"))) ); BytesStreamOutput out = new BytesStreamOutput(); logType.writeTo(out); @@ -66,7 +67,8 @@ public void testLogTypeAsStreamRawFieldOnly() throws IOException { public void testLogTypeAsStreamFull() throws IOException { LogType logType = new LogType( "1", "my_log_type", "description", false, - List.of(new LogType.Mapping("rawField", "some_ecs_field", "some_ocsf_field")) + List.of(new LogType.Mapping("rawField", "some_ecs_field", "some_ocsf_field")), + List.of(new LogType.IocFields("ip", List.of("dst.ip"))) ); BytesStreamOutput out = new BytesStreamOutput(); logType.writeTo(out); @@ -80,7 +82,7 @@ public void testLogTypeAsStreamFull() throws IOException { } public void testLogTypeAsStreamNoMappings() throws IOException { - LogType logType = new LogType("1", "my_log_type", "description", false, null); + LogType logType = new LogType("1", "my_log_type", "description", false, null, null); BytesStreamOutput out = new BytesStreamOutput(); logType.writeTo(out); StreamInput sin = StreamInput.wrap(out.bytes().toBytesRef().bytes); diff --git a/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java b/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java index f2ec8c5cc..89f447440 100644 --- a/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java +++ b/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java @@ -17,8 +17,10 @@ import static org.opensearch.securityanalytics.TestHelpers.parser; import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithNoUser; +import static org.opensearch.securityanalytics.TestHelpers.randomThreatIntelFeedData; import static org.opensearch.securityanalytics.TestHelpers.randomUser; import static org.opensearch.securityanalytics.TestHelpers.randomUserEmpty; +import static org.opensearch.securityanalytics.TestHelpers.toJsonString; import static org.opensearch.securityanalytics.TestHelpers.toJsonStringWithUser; public class XContentTests extends OpenSearchTestCase { @@ -193,4 +195,12 @@ public void testDetectorParsingWithNoUser() throws IOException { Detector parsedDetector = Detector.parse(parser(detectorString), null, null); Assert.assertEquals("Round tripping Detector doesn't work", detector, parsedDetector); } + + public void testThreatIntelFeedParsing() throws IOException { + ThreatIntelFeedData tifd = randomThreatIntelFeedData(); + + String tifdString = toJsonString(tifd); + ThreatIntelFeedData parsedTifd = ThreatIntelFeedData.parse(parser(tifdString), null, null); + Assert.assertEquals("Round tripping Threat intel feed data model doesn't work", tifd, parsedTifd); + } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index 68d3636ae..dfea4bac8 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -32,6 +32,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static java.util.Collections.emptyList; import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; @@ -49,6 +50,7 @@ public class DetectorMonitorRestApiIT extends SecurityAnalyticsRestTestCase { * 2. Creates two aggregation rules and assigns to a detector, while removing 5 prepackaged rules * 3. Verifies that two bucket level monitor exists * 4. Verifies the findings + * * @throws IOException */ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() throws IOException { @@ -103,16 +105,16 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t assertEquals(1, monitorIds.size()); String monitorId = monitorIds.get(0); - String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); + String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); assertEquals(MonitorType.DOC_LEVEL_MONITOR.getValue(), monitorType); // Create aggregation rules - String sumRuleId = createRule(randomAggregationRule( "sum", " > 2")); - String avgTermRuleId = createRule(randomAggregationRule( "avg", " > 1")); + String sumRuleId = createRule(randomAggregationRule("sum", " > 2")); + String avgTermRuleId = createRule(randomAggregationRule("avg", " > 1")); // Update detector and empty doc level rules so detector contains only one aggregation rule DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(sumRuleId), new DetectorRule(avgTermRuleId)), - Collections.emptyList()); + emptyList()); Detector updatedDetector = randomDetectorWithInputs(List.of(input)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); @@ -133,8 +135,8 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t indexDoc(index, "2", randomDoc(3, 4, "Info")); // Execute two bucket level monitors - for(String id: monitorIds){ - monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + id))).get("monitor")).get("monitor_type"); + for (String id : monitorIds) { + monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + id))).get("monitor")).get("monitor_type"); Assert.assertEquals(MonitorType.BUCKET_LEVEL_MONITOR.getValue(), monitorType); executeAlertingMonitor(id, Collections.emptyMap()); } @@ -149,24 +151,24 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t List aggRuleIds = List.of(sumRuleId, avgTermRuleId); - List> findings = (List)getFindingsBody.get("findings"); - for(Map finding : findings) { - Set aggRulesFinding = ((List>)finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( + List> findings = (List) getFindingsBody.get("findings"); + for (Map finding : findings) { + Set aggRulesFinding = ((List>) finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( Collectors.toSet()); // Bucket monitor finding will have one rule String aggRuleId = aggRulesFinding.iterator().next(); assertTrue(aggRulesFinding.contains(aggRuleId)); - List findingDocs = (List)finding.get("related_doc_ids"); + List findingDocs = (List) finding.get("related_doc_ids"); Assert.assertEquals(2, findingDocs.size()); assertTrue(Arrays.asList("1", "2").containsAll(findingDocs)); } - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); } @@ -175,6 +177,7 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t * 2. Creates 5 prepackaged doc level rules and one custom doc level rule and removes the aggregation rule * 3. Verifies that one doc level monitor exists * 4. Verifies the findings + * * @throws IOException */ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throws IOException { @@ -194,10 +197,10 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); - String maxRuleId = createRule(randomAggregationRule( "max", " > 2")); + String maxRuleId = createRule(randomAggregationRule("max", " > 2")); List detectorRules = List.of(new DetectorRule(maxRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -228,7 +231,7 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw Map detectorAsMap = (Map) hit.getSourceAsMap().get("detector"); String monitorId = ((List) (detectorAsMap).get("monitor_id")).get(0); - String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); + String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); assertEquals(MonitorType.BUCKET_LEVEL_MONITOR.getValue(), monitorType); @@ -255,7 +258,7 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw assertEquals(1, monitorIds.size()); monitorId = monitorIds.get(0); - monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); + monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); assertEquals(MonitorType.DOC_LEVEL_MONITOR.getValue(), monitorType); @@ -292,15 +295,15 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw Set docRuleIds = new HashSet<>(prepackagedRules); docRuleIds.add(randomDocRuleId); - List> findings = (List)getFindingsBody.get("findings"); + List> findings = (List) getFindingsBody.get("findings"); List foundDocIds = new ArrayList<>(); - for(Map finding : findings) { - Set aggRulesFinding = ((List>)finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( + for (Map finding : findings) { + Set aggRulesFinding = ((List>) finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( Collectors.toSet()); assertTrue(docRuleIds.containsAll(aggRulesFinding)); - List findingDocs = (List)finding.get("related_doc_ids"); + List findingDocs = (List) finding.get("related_doc_ids"); Assert.assertEquals(1, findingDocs.size()); foundDocIds.addAll(findingDocs); } @@ -365,11 +368,11 @@ public void testRemoveAllRulesAndUpdateDetector_success() throws IOException { assertEquals(1, monitorIds.size()); String monitorId = monitorIds.get(0); - String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); + String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); assertEquals(MonitorType.DOC_LEVEL_MONITOR.getValue(), monitorType); - Detector updatedDetector = randomDetector(Collections.emptyList()); + Detector updatedDetector = randomDetector(emptyList()); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); @@ -413,7 +416,7 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio String sumRuleId = createRule(randomAggregationRule("sum", " > 1")); List detectorRules = List.of(new DetectorRule(sumRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -421,7 +424,7 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); - String request = "{\n" + + String request = "{\n" + " \"query\" : {\n" + " \"match\":{\n" + " \"_id\": \"" + detectorId + "\"\n" + @@ -431,15 +434,15 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = detectorMap.get("inputs"); assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); // Test adding the new max monitor and updating the existing sum monitor - String maxRuleId = createRule(randomAggregationRule("max", " > 3")); + String maxRuleId = createRule(randomAggregationRule("max", " > 3")); DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(maxRuleId), new DetectorRule(sumRuleId)), - Collections.emptyList()); + emptyList()); Detector updatedDetector = randomDetectorWithInputs(List.of(newInput)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); @@ -447,7 +450,7 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio hits = executeSearch(Detector.DETECTORS_INDEX, request); hit = hits.get(0); - Map updatedDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); inputArr = updatedDetectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -459,8 +462,8 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio indexDoc(index, "1", randomDoc(2, 4, "Info")); indexDoc(index, "2", randomDoc(3, 4, "Info")); - for(String monitorId: monitorIds) { - Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + for (String monitorId : monitorIds) { + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); assertEquals(MonitorType.BUCKET_LEVEL_MONITOR.getValue(), monitor.get("monitor_type")); executeAlertingMonitor(monitorId, Collections.emptyMap()); } @@ -486,10 +489,10 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio assertEquals(2, findingDocs.size()); assertTrue(Arrays.asList("1", "2").containsAll(findingDocs)); - String findingDetectorId = ((Map)((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); } @@ -525,7 +528,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio List detectorRules = aggRuleIds.stream().map(DetectorRule::new).collect(Collectors.toList()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -533,7 +536,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); - String request = "{\n" + + String request = "{\n" + " \"query\" : {\n" + " \"match\":{\n" + " \"_id\": \"" + detectorId + "\"\n" + @@ -543,14 +546,14 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); // Test deleting the aggregation rule DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(avgRuleId)), - Collections.emptyList()); + emptyList()); detector = randomDetectorWithInputs(List.of(newInput)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); @@ -558,7 +561,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio hits = executeSearch(Detector.DETECTORS_INDEX, request); hit = hits.get(0); - Map updatedDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); inputArr = updatedDetectorMap.get("inputs"); assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -572,7 +575,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio assertEquals(1, monitorIds.size()); - Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorIds.get(0))))).get("monitor"); + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorIds.get(0))))).get("monitor"); assertEquals(MonitorType.BUCKET_LEVEL_MONITOR.getValue(), monitor.get("monitor_type")); @@ -601,10 +604,10 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio assertEquals(2, findingDocs.size()); assertTrue(Arrays.asList("1", "2").containsAll(findingDocs)); - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); } @@ -614,6 +617,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio * 3. Verifies that number of rules is unchanged * 4. Verifies monitor types * 5. Verifies findings + * * @throws IOException */ public void testReplaceAggregationRule_verifyFindings_success() throws IOException { @@ -649,7 +653,7 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); - String request = "{\n" + + String request = "{\n" + " \"query\" : {\n" + " \"match\":{\n" + " \"_id\": \"" + detectorId + "\"\n" + @@ -659,7 +663,7 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -675,7 +679,7 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti hits = executeSearch(Detector.DETECTORS_INDEX, request); hit = hits.get(0); - Map updatedDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); inputArr = updatedDetectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -688,8 +692,8 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti indexDoc(index, "2", randomDoc(3, 4, "Info")); indexDoc(index, "3", randomDoc(3, 4, "Test")); Map numberOfMonitorTypes = new HashMap<>(); - for(String monitorId: monitorIds) { - Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + for (String monitorId : monitorIds) { + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); executeAlertingMonitor(monitorId, Collections.emptyMap()); } @@ -705,27 +709,27 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti assertNotNull(getFindingsBody); assertEquals(5, getFindingsBody.get("total_findings")); - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); List docLevelFinding = new ArrayList<>(); - List> findings = (List)getFindingsBody.get("findings"); + List> findings = (List) getFindingsBody.get("findings"); Set docLevelRules = new HashSet<>(prepackagedDocRules); - for(Map finding : findings) { - List> queries = (List>)finding.get("queries"); + for (Map finding : findings) { + List> queries = (List>) finding.get("queries"); Set findingRules = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); // In this test case all doc level rules are matching the finding rule ids - if(docLevelRules.containsAll(findingRules)) { - docLevelFinding.addAll((List)finding.get("related_doc_ids")); + if (docLevelRules.containsAll(findingRules)) { + docLevelFinding.addAll((List) finding.get("related_doc_ids")); } else { String aggRuleId = findingRules.iterator().next(); - List findingDocs = (List)finding.get("related_doc_ids"); + List findingDocs = (List) finding.get("related_doc_ids"); Assert.assertEquals(2, findingDocs.size()); assertTrue(Arrays.asList("1", "2").containsAll(findingDocs)); } @@ -755,7 +759,7 @@ public void testMinAggregationRule_findingSuccess() throws IOException { aggRuleIds.add(createRule(randomAggregationRule("min", " > 3", testOpCode))); List detectorRules = aggRuleIds.stream().map(id -> new DetectorRule(id)).collect(Collectors.toList()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -763,7 +767,7 @@ public void testMinAggregationRule_findingSuccess() throws IOException { Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); - String request = "{\n" + + String request = "{\n" + " \"query\" : {\n" + " \"match\":{\n" + " \"_id\": \"" + detectorId + "\"\n" + @@ -773,7 +777,7 @@ public void testMinAggregationRule_findingSuccess() throws IOException { List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List monitorIds = ((List) (detectorMap).get("monitor_id")); @@ -784,8 +788,8 @@ public void testMinAggregationRule_findingSuccess() throws IOException { indexDoc(index, "8", randomDoc(1, 1, testOpCode)); Map numberOfMonitorTypes = new HashMap<>(); - for (String monitorId: monitorIds) { - Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + for (String monitorId : monitorIds) { + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); executeAlertingMonitor(monitorId, Collections.emptyMap()); } @@ -798,17 +802,17 @@ public void testMinAggregationRule_findingSuccess() throws IOException { assertNotNull(getFindingsBody); - List> findings = (List)getFindingsBody.get("findings"); + List> findings = (List) getFindingsBody.get("findings"); for (Map finding : findings) { - List findingDocs = (List)finding.get("related_doc_ids"); + List findingDocs = (List) finding.get("related_doc_ids"); Assert.assertEquals(1, findingDocs.size()); assertTrue(Arrays.asList("7").containsAll(findingDocs)); } - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); } @@ -843,10 +847,10 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti // 5 custom aggregation rules String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); - String minRuleId = createRule(randomAggregationRule("min", " > 3", testOpCode)); - String avgRuleId = createRule(randomAggregationRule("avg", " > 3", infoOpCode)); - String cntRuleId = createRule(randomAggregationRule("count", " > 3", "randomTestCode")); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String minRuleId = createRule(randomAggregationRule("min", " > 3", testOpCode)); + String avgRuleId = createRule(randomAggregationRule("avg", " > 3", infoOpCode)); + String cntRuleId = createRule(randomAggregationRule("count", " > 3", "randomTestCode")); List aggRuleIds = List.of(sumRuleId, maxRuleId); String randomDocRuleId = createRule(randomRule()); List prepackagedRules = getRandomPrePackagedRules(); @@ -861,7 +865,6 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); - String request = "{\n" + " \"query\" : {\n" + " \"match_all\":{\n" + @@ -884,7 +887,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map updatedDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = updatedDetectorMap.get("inputs"); assertEquals(6, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -904,8 +907,8 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti Map numberOfMonitorTypes = new HashMap<>(); - for (String monitorId: monitorIds) { - Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + for (String monitorId : monitorIds) { + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); @@ -916,16 +919,15 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti // 5 prepackaged and 1 custom doc level rule assertEquals(6, noOfSigmaRuleMatches); } else if (MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type"))) { - for(String ruleId: aggRuleIds) { - Object rule = (((Map)((Map)((List)((Map)executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get(ruleId)); - if(rule != null) { - if(ruleId == sumRuleId) { - assertRuleMonitorFinding(executeResults, ruleId,3, List.of("4")); + for (String ruleId : aggRuleIds) { + Object rule = (((Map) ((Map) ((List) ((Map) executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get(ruleId)); + if (rule != null) { + if (ruleId == sumRuleId) { + assertRuleMonitorFinding(executeResults, ruleId, 3, List.of("4")); } else if (ruleId == maxRuleId) { - assertRuleMonitorFinding(executeResults, ruleId,5, List.of("2", "3")); - } - else if (ruleId == minRuleId) { - assertRuleMonitorFinding(executeResults, ruleId,1, List.of("2")); + assertRuleMonitorFinding(executeResults, ruleId, 5, List.of("2", "3")); + } else if (ruleId == minRuleId) { + assertRuleMonitorFinding(executeResults, ruleId, 1, List.of("2")); } } } @@ -945,10 +947,10 @@ else if (ruleId == minRuleId) { // 8 findings from doc level rules, and 3 findings for aggregation (sum, max and min) assertEquals(11, getFindingsBody.get("total_findings")); - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); List docLevelFinding = new ArrayList<>(); @@ -957,22 +959,22 @@ else if (ruleId == minRuleId) { Set docLevelRules = new HashSet<>(prepackagedRules); docLevelRules.add(randomDocRuleId); - for(Map finding : findings) { - List> queries = (List>)finding.get("queries"); + for (Map finding : findings) { + List> queries = (List>) finding.get("queries"); Set findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); // Doc level finding matches all doc level rules (including the custom one) in this test case - if(docLevelRules.containsAll(findingRuleIds)) { - docLevelFinding.addAll((List)finding.get("related_doc_ids")); + if (docLevelRules.containsAll(findingRuleIds)) { + docLevelFinding.addAll((List) finding.get("related_doc_ids")); } else { // In the case of bucket level monitors, queries will always contain one value String aggRuleId = findingRuleIds.iterator().next(); - List findingDocs = (List)finding.get("related_doc_ids"); + List findingDocs = (List) finding.get("related_doc_ids"); - if(aggRuleId.equals(sumRuleId)) { + if (aggRuleId.equals(sumRuleId)) { assertTrue(List.of("1", "2", "3").containsAll(findingDocs)); - } else if(aggRuleId.equals(maxRuleId)) { + } else if (aggRuleId.equals(maxRuleId)) { assertTrue(List.of("4", "5", "6", "7").containsAll(findingDocs)); - } else if(aggRuleId.equals( minRuleId)) { + } else if (aggRuleId.equals(minRuleId)) { assertTrue(List.of("7").containsAll(findingDocs)); } } @@ -1001,11 +1003,11 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule String testOpCode = "Test"; - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1033,7 +1035,7 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = (List) detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1048,8 +1050,6 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule verifyWorkflow(detectorMap, monitorIds, 2); } - - public void testCreateDetector_verifyWorkflowCreation_success_WithGroupByRulesInTrigger() throws IOException { updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); String index = createTestIndex(randomIndex(), windowsIndexMapping()); @@ -1070,12 +1070,12 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithGroupByRulesIn String testOpCode = "Test"; - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); - DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(maxRuleId), List.of(), List.of(), List.of()); + emptyList()); + DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(maxRuleId), List.of(), List.of(), List.of(), List.of()); Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1103,7 +1103,7 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithGroupByRulesIn "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = (List) detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1141,7 +1141,7 @@ public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_su List detectorRules = List.of(new DetectorRule(randomDocRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1169,7 +1169,7 @@ public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_su "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List monitorIds = ((List) (detectorMap).get("monitor_id")); assertEquals(1, monitorIds.size()); @@ -1184,7 +1184,7 @@ public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_su assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); hits = executeSearch(Detector.DETECTORS_INDEX, request); hit = hits.get(0); - detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); // Verify that the workflow for the given detector is not added assertTrue("Workflow created", ((List) detectorMap.get("workflow_ids")).size() == 0); @@ -1212,13 +1212,13 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws String testOpCode = "Test"; - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); - DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of()); + DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of(), List.of()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1246,7 +1246,7 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = (List) detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1261,14 +1261,14 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws verifyWorkflow(detectorMap, monitorIds, 3); // Update detector - remove one agg rule; Verify workflow - DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), Arrays.asList(new DetectorRule(randomDocRuleId)) , getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), Arrays.asList(new DetectorRule(randomDocRuleId)), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); detector = randomDetectorWithInputs(List.of(newInput)); createResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); assertEquals("Update detector failed", RestStatus.OK, restStatus(createResponse)); hits = executeSearch(Detector.DETECTORS_INDEX, request); hit = hits.get(0); - detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); inputArr = (List) detectorMap.get("inputs"); assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1302,13 +1302,13 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws assertNotNull(getFindingsBody); assertEquals(1, getFindingsBody.get("total_findings")); - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); - List> findings = (List)getFindingsBody.get("findings"); + List> findings = (List) getFindingsBody.get("findings"); assertEquals(1, findings.size()); List findingDocs = (List) findings.get(0).get("related_doc_ids"); @@ -1336,13 +1336,13 @@ public void testCreateDetector_workflowWithDuplicateMonitor_failure() throws IOE String testOpCode = "Test"; - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); - DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of()); + DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of(), List.of()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1370,7 +1370,7 @@ public void testCreateDetector_workflowWithDuplicateMonitor_failure() throws IOE "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = (List) detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1405,14 +1405,14 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor String testOpCode = "Test"; - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); DetectorTrigger t1, t2; - t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of()); + t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of(), List.of()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + emptyList()); Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1440,7 +1440,7 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = (List) detectorMap.get("inputs"); assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1485,21 +1485,21 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor assertNotNull(getFindingsBody); assertEquals(6, getFindingsBody.get("total_findings")); - String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); - String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); assertEquals(index, findingIndex); List docLevelFinding = new ArrayList<>(); - List> findings = (List)getFindingsBody.get("findings"); + List> findings = (List) getFindingsBody.get("findings"); Set docLevelRules = new HashSet<>(List.of(randomDocRuleId)); - for(Map finding : findings) { + for (Map finding : findings) { List> queries = (List>) finding.get("queries"); Set findingRules = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); // In this test case all doc level rules are matching the finding rule ids - if(docLevelRules.containsAll(findingRules)) { + if (docLevelRules.containsAll(findingRules)) { docLevelFinding.addAll((List) finding.get("related_doc_ids")); } else { List findingDocs = (List) finding.get("related_doc_ids"); @@ -1533,10 +1533,10 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve // 5 custom aggregation rules String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); - String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); - String minRuleId = createRule(randomAggregationRule("min", " > 3", testOpCode)); - String avgRuleId = createRule(randomAggregationRule("avg", " > 3", infoOpCode)); - String cntRuleId = createRule(randomAggregationRule("count", " > 3", "randomTestCode")); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String minRuleId = createRule(randomAggregationRule("min", " > 3", testOpCode)); + String avgRuleId = createRule(randomAggregationRule("avg", " > 3", infoOpCode)); + String cntRuleId = createRule(randomAggregationRule("count", " > 3", "randomTestCode")); String randomDocRuleId = createRule(randomRule()); List prepackagedRules = getRandomPrePackagedRules(); @@ -1546,8 +1546,8 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); DetectorTrigger t1, t2; - t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(sumRuleId, maxRuleId), List.of(), List.of(), List.of()); - t2 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(minRuleId, avgRuleId, cntRuleId), List.of(), List.of(), List.of()); + t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(sumRuleId, maxRuleId), List.of(), List.of(), List.of(), List.of()); + t2 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(minRuleId, avgRuleId, cntRuleId), List.of(), List.of(), List.of(), List.of()); Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1, t2)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -1575,7 +1575,7 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); List inputArr = (List) detectorMap.get("inputs"); assertEquals(6, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); @@ -1620,19 +1620,19 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve for (Map runResult : monitorRunResults) { String monitorName = runResult.get("monitor_name").toString(); String monitorId = monitorNameToIdMap.get(monitorName); - if(monitorId.equals(docMonitorId)){ + if (monitorId.equals(docMonitorId)) { int noOfSigmaRuleMatches = ((List>) ((Map) runResult.get("input_results")).get("results")).get(0).size(); // 5 prepackaged and 1 custom doc level rule assertEquals(6, noOfSigmaRuleMatches); - } else if(monitorId.equals(chainedFindingsMonitorId)) { + } else if (monitorId.equals(chainedFindingsMonitorId)) { } else { Map trigger_results = (Map) runResult.get("trigger_results"); if (trigger_results.containsKey(maxRuleId)) { assertRuleMonitorFinding(runResult, maxRuleId, 5, List.of("2", "3")); - } else if( trigger_results.containsKey(sumRuleId)) { + } else if (trigger_results.containsKey(sumRuleId)) { assertRuleMonitorFinding(runResult, sumRuleId, 3, List.of("4")); - } else if( trigger_results.containsKey(minRuleId)) { + } else if (trigger_results.containsKey(minRuleId)) { assertRuleMonitorFinding(runResult, minRuleId, 5, List.of("2")); } } @@ -1650,11 +1650,11 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve private static void assertRuleMonitorFinding(Map executeResults, String ruleId, int expectedDocCount, List expectedTriggerResult) { - List> buckets = ((List>)(((Map)((Map)((Map)((List)((Map) executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get("result_agg")).get("buckets"))); - Integer docCount = buckets.stream().mapToInt(it -> (Integer)it.get("doc_count")).sum(); + List> buckets = ((List>) (((Map) ((Map) ((Map) ((List) ((Map) executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get("result_agg")).get("buckets"))); + Integer docCount = buckets.stream().mapToInt(it -> (Integer) it.get("doc_count")).sum(); assertEquals(expectedDocCount, docCount.intValue()); - List triggerResultBucketKeys = ((Map)((Map) ((Map)executeResults.get("trigger_results")).get(ruleId)).get("agg_result_buckets")).keySet().stream().collect(Collectors.toList()); + List triggerResultBucketKeys = ((Map) ((Map) ((Map) executeResults.get("trigger_results")).get(ruleId)).get("agg_result_buckets")).keySet().stream().collect(Collectors.toList()); Assert.assertEquals(expectedTriggerResult, triggerResultBucketKeys); } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java index 83ff51928..2059fb191 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java @@ -63,7 +63,7 @@ public void testNewLogTypes() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("github"), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("github"), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -87,7 +87,7 @@ public void testDeletingADetector_MonitorNotExists() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); // Create detector #1 of type test_windows - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId1 = createDetector(detector1); String request = "{\n" + @@ -129,7 +129,7 @@ public void testCreatingADetector() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -187,7 +187,7 @@ public void testCreatingADetectorScheduledJobFinding() throws IOException, Inter assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); Detector detector = randomDetectorWithTriggersAndScheduleAndEnabled(getRandomPrePackagedRules(), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of())), new IntervalSchedule(1, ChronoUnit.MINUTES, null), true); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -288,7 +288,7 @@ public void testCreatingADetectorWithMultipleIndices() throws IOException { Detector detector = randomDetectorWithTriggers( getRandomPrePackagedRules(), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of())), List.of(index1, index2) ); @@ -346,7 +346,7 @@ public void testCreatingADetectorWithMultipleIndices() throws IOException { } public void testCreatingADetectorWithIndexNotExists() throws IOException { - Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); try { makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -878,7 +878,7 @@ public void testDeletingADetector_single_ruleTopicIndex() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); // Create detector #1 of type test_windows - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId1 = createDetector(detector1); String request = "{\n" + @@ -901,7 +901,7 @@ public void testDeletingADetector_single_ruleTopicIndex() throws IOException { int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); Assert.assertEquals(5, noOfSigmaRuleMatches); // Create detector #2 of type windows - Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId2 = createDetector(detector2); request = "{\n" + @@ -972,7 +972,7 @@ public void testDeletingADetector_single_Monitor() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); // Create detector #1 of type test_windows - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId1 = createDetector(detector1); String request = "{\n" + @@ -999,7 +999,7 @@ public void testDeletingADetector_single_Monitor() throws IOException { int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); Assert.assertEquals(5, noOfSigmaRuleMatches); // Create detector #2 of type windows - Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId2 = createDetector(detector2); request = "{\n" + @@ -1082,7 +1082,7 @@ public void testDeletingADetector_single_Monitor_workflow_enabled() throws IOExc Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); // Create detector #1 of type test_windows - Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId1 = createDetector(detector1); String request = "{\n" + @@ -1109,7 +1109,7 @@ public void testDeletingADetector_single_Monitor_workflow_enabled() throws IOExc int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); Assert.assertEquals(5, noOfSigmaRuleMatches); // Create detector #2 of type windows - Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); String detectorId2 = createDetector(detector2); request = "{\n" + @@ -1187,7 +1187,7 @@ public void testDeletingADetector_oneDetectorType_multiple_ruleTopicIndex() thro // Create detector #1 of type test_windows Detector detector1 = randomDetectorWithTriggers( getRandomPrePackagedRules(), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of())), List.of(index1) ); String detectorId1 = createDetector(detector1); @@ -1195,7 +1195,7 @@ public void testDeletingADetector_oneDetectorType_multiple_ruleTopicIndex() thro // Create detector #2 of type test_windows Detector detector2 = randomDetectorWithTriggers( getRandomPrePackagedRules(), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of())), List.of(index2) ); @@ -1483,7 +1483,7 @@ public void testDetector_withDatastream_withTemplateField_endToEnd_success() thr // Create detector Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of(datastream), List.of(), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of()))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -1576,7 +1576,7 @@ public void testDetector_withAlias_endToEnd_success() throws IOException { // Create detector Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of(indexAlias), List.of(), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of()))); + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(), List.of(), List.of("attack.defense_evasion"), List.of(), List.of()))); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestCase.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestCase.java new file mode 100644 index 000000000..20d36ab2d --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestCase.java @@ -0,0 +1,273 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel; + +import org.junit.After; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.RoutingTable; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Randomness; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.ingest.IngestMetadata; +import org.opensearch.ingest.IngestService; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.common.TIFLockService; +import org.opensearch.securityanalytics.threatIntel.feedMetadata.BuiltInTIFMetadataLoader; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameterService; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobUpdateService; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskListener; +import org.opensearch.test.client.NoOpNodeClient; +import org.opensearch.test.rest.RestActionTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import org.opensearch.securityanalytics.TestHelpers; + +public abstract class ThreatIntelTestCase extends RestActionTestCase { + @Mock + protected ClusterService clusterService; + @Mock + protected TIFJobUpdateService tifJobUpdateService; + @Mock + protected TIFJobParameterService tifJobParameterService; + @Mock + protected BuiltInTIFMetadataLoader builtInTIFMetadataLoader; + @Mock + protected ThreatIntelFeedDataService threatIntelFeedDataService; + @Mock + protected ClusterState clusterState; + @Mock + protected Metadata metadata; + @Mock + protected IngestService ingestService; + @Mock + protected ActionFilters actionFilters; + @Mock + protected ThreadPool threadPool; + @Mock + protected TIFLockService tifLockService; + @Mock + protected RoutingTable routingTable; + @Mock + protected TransportService transportService; + protected IngestMetadata ingestMetadata; + protected NoOpNodeClient client; + protected VerifyingClient verifyingClient; + protected LockService lockService; + protected ClusterSettings clusterSettings; + protected Settings settings; + private AutoCloseable openMocks; + @Mock + protected DetectorThreatIntelService detectorThreatIntelService; + @Mock + protected TIFJobParameter tifJobParameter; + + @Before + public void prepareThreatIntelTestCase() { + openMocks = MockitoAnnotations.openMocks(this); + settings = Settings.EMPTY; + client = new NoOpNodeClient(this.getTestName()); + verifyingClient = spy(new VerifyingClient(this.getTestName())); + clusterSettings = new ClusterSettings(settings, new HashSet<>(SecurityAnalyticsSettings.settings())); + lockService = new LockService(client, clusterService); + ingestMetadata = new IngestMetadata(Collections.emptyMap()); + when(metadata.custom(IngestMetadata.TYPE)).thenReturn(ingestMetadata); + when(clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterState.getMetadata()).thenReturn(metadata); + when(clusterState.routingTable()).thenReturn(routingTable); + when(ingestService.getClusterService()).thenReturn(clusterService); + when(threadPool.generic()).thenReturn(OpenSearchExecutors.newDirectExecutorService()); + detectorThreatIntelService = new DetectorThreatIntelService(threatIntelFeedDataService, client, xContentRegistry()); + } + + @After + public void clean() throws Exception { + openMocks.close(); + client.close(); + verifyingClient.close(); + } + + protected TIFJobState randomStateExcept(TIFJobState state) { + assertNotNull(state); + return Arrays.stream(TIFJobState.values()) + .sequential() + .filter(s -> !s.equals(state)) + .collect(Collectors.toList()) + .get(Randomness.createSecure().nextInt(TIFJobState.values().length - 2)); + } + + protected TIFJobState randomState() { + return Arrays.stream(TIFJobState.values()) + .sequential() + .collect(Collectors.toList()) + .get(Randomness.createSecure().nextInt(TIFJobState.values().length - 1)); + } + + protected long randomPositiveLong() { + long value = Randomness.get().nextLong(); + return value < 0 ? -value : value; + } + + /** + * Update interval should be > 0 and < validForInDays. + * For an update test to work, there should be at least one eligible value other than current update interval. + * Therefore, the smallest value for validForInDays is 2. + * Update interval is random value from 1 to validForInDays - 2. + * The new update value will be validForInDays - 1. + */ + protected TIFJobParameter randomTifJobParameter(final Instant updateStartTime) { + Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + TIFJobParameter tifJobParameter = new TIFJobParameter(); + tifJobParameter.setName(TestHelpers.randomLowerCaseString()); + tifJobParameter.setSchedule( + new IntervalSchedule( + updateStartTime.truncatedTo(ChronoUnit.MILLIS), + 1, + ChronoUnit.DAYS + ) + ); + tifJobParameter.setState(randomState()); + tifJobParameter.setIndices(Arrays.asList(TestHelpers.randomLowerCaseString(), TestHelpers.randomLowerCaseString())); + tifJobParameter.getUpdateStats().setLastSkippedAt(now); + tifJobParameter.getUpdateStats().setLastSucceededAt(now); + tifJobParameter.getUpdateStats().setLastFailedAt(now); + tifJobParameter.getUpdateStats().setLastProcessingTimeInMillis(randomPositiveLong()); + tifJobParameter.setLastUpdateTime(now); + if (Randomness.get().nextInt() % 2 == 0) { + tifJobParameter.enable(); + } else { + tifJobParameter.disable(); + } + return tifJobParameter; + } + + protected TIFJobParameter randomTifJobParameter() { + return randomTifJobParameter(Instant.now()); + } + + protected LockModel randomLockModel() { + LockModel lockModel = new LockModel( + TestHelpers.randomLowerCaseString(), + TestHelpers.randomLowerCaseString(), + Instant.now(), + randomPositiveLong(), + false + ); + return lockModel; + } + + /** + * Temporary class of VerifyingClient until this PR(https://github.com/opensearch-project/OpenSearch/pull/7167) + * is merged in OpenSearch core + */ + public static class VerifyingClient extends NoOpNodeClient { + AtomicReference executeVerifier = new AtomicReference<>(); + AtomicReference executeLocallyVerifier = new AtomicReference<>(); + + public VerifyingClient(String testName) { + super(testName); + reset(); + } + + /** + * Clears any previously set verifier functions set by {@link #setExecuteVerifier(BiFunction)} and/or + * {@link #setExecuteLocallyVerifier(BiFunction)}. These functions are replaced with functions which will throw an + * {@link AssertionError} if called. + */ + public void reset() { + executeVerifier.set((arg1, arg2) -> { throw new AssertionError(); }); + executeLocallyVerifier.set((arg1, arg2) -> { throw new AssertionError(); }); + } + + /** + * Sets the function that will be called when {@link #doExecute(ActionType, ActionRequest, ActionListener)} is called. The given + * function should return either a subclass of {@link ActionResponse} or {@code null}. + * @param verifier A function which is called in place of {@link #doExecute(ActionType, ActionRequest, ActionListener)} + */ + public void setExecuteVerifier( + BiFunction, Request, Response> verifier + ) { + executeVerifier.set(verifier); + } + + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + listener.onResponse((Response) executeVerifier.get().apply(action, request)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Sets the function that will be called when {@link #executeLocally(ActionType, ActionRequest, TaskListener)}is called. The given + * function should return either a subclass of {@link ActionResponse} or {@code null}. + * @param verifier A function which is called in place of {@link #executeLocally(ActionType, ActionRequest, TaskListener)} + */ + public void setExecuteLocallyVerifier( + BiFunction, Request, Response> verifier + ) { + executeLocallyVerifier.set(verifier); + } + + @Override + public Task executeLocally( + ActionType action, + Request request, + ActionListener listener + ) { + listener.onResponse((Response) executeLocallyVerifier.get().apply(action, request)); + return null; + } + + @Override + public Task executeLocally( + ActionType action, + Request request, + TaskListener listener + ) { + listener.onResponse(null, (Response) executeLocallyVerifier.get().apply(action, request)); + return null; + } + + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobRequestTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobRequestTests.java new file mode 100644 index 000000000..baa18695d --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/action/PutTIFJobRequestTests.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.action; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.TestHelpers; + + +public class PutTIFJobRequestTests extends ThreatIntelTestCase { + + public void testValidate_whenValidInput_thenSucceed() { + String tifJobParameterName = TestHelpers.randomLowerCaseString(); + PutTIFJobRequest request = new PutTIFJobRequest(tifJobParameterName, clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)); + + assertNull(request.validate()); + } + + public void testValidate_whenInvalidTIFJobParameterName_thenFails() { + String invalidName = "_" + TestHelpers.randomLowerCaseString(); + PutTIFJobRequest request = new PutTIFJobRequest(invalidName, clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)); + + // Run + ActionRequestValidationException exception = request.validate(); + + // Verify + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("must not")); + } + + public void testStreamInOut_whenValidInput_thenSucceed() throws Exception { + String tifJobParameterName = TestHelpers.randomLowerCaseString(); + PutTIFJobRequest request = new PutTIFJobRequest(tifJobParameterName, clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)); + + // Run + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + BytesStreamInput input = new BytesStreamInput(output.bytes().toBytesRef().bytes); + PutTIFJobRequest copiedRequest = new PutTIFJobRequest(input); + + // Verify + assertEquals(request.getName(), copiedRequest.getName()); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobActionTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobActionTests.java new file mode 100644 index 000000000..68dcbf527 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobActionTests.java @@ -0,0 +1,161 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.action; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.StepListener; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter; +import org.opensearch.tasks.Task; +import org.opensearch.securityanalytics.TestHelpers; + +import java.io.IOException; +import java.util.ConcurrentModificationException; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +public class TransportPutTIFJobActionTests extends ThreatIntelTestCase { + private TransportPutTIFJobAction action; + + @Before + public void init() { + action = new TransportPutTIFJobAction( + transportService, + actionFilters, + threadPool, + tifJobParameterService, + tifJobUpdateService, + tifLockService + ); + } + + public void testDoExecute_whenFailedToAcquireLock_thenError() throws IOException { + validateDoExecute(null, null, null); + } + + public void testDoExecute_whenAcquiredLock_thenSucceed() throws IOException { + validateDoExecute(randomLockModel(), null, null); + } + + public void testDoExecute_whenExceptionBeforeAcquiringLock_thenError() throws IOException { + validateDoExecute(randomLockModel(), new RuntimeException(), null); + } + + public void testDoExecute_whenExceptionAfterAcquiringLock_thenError() throws IOException { + validateDoExecute(randomLockModel(), null, new RuntimeException()); + } + + private void validateDoExecute(final LockModel lockModel, final Exception before, final Exception after) throws IOException { + Task task = mock(Task.class); + TIFJobParameter tifJobParameter = randomTifJobParameter(); + + PutTIFJobRequest request = new PutTIFJobRequest(tifJobParameter.getName(), clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)); + ActionListener listener = mock(ActionListener.class); + if (after != null) { + doThrow(after).when(tifJobParameterService).createJobIndexIfNotExists(any(StepListener.class)); + } + + // Run + action.doExecute(task, request, listener); + + // Verify + ArgumentCaptor> captor = ArgumentCaptor.forClass(ActionListener.class); + verify(tifLockService).acquireLock(eq(tifJobParameter.getName()), anyLong(), captor.capture()); + + if (before == null) { + // Run + captor.getValue().onResponse(lockModel); + + // Verify + if (lockModel == null) { + verify(listener).onFailure(any(ConcurrentModificationException.class)); + } + if (after != null) { + verify(tifLockService).releaseLock(eq(lockModel)); + verify(listener).onFailure(after); + } else { + verify(tifLockService, never()).releaseLock(eq(lockModel)); + } + } else { + // Run + captor.getValue().onFailure(before); + // Verify + verify(listener).onFailure(before); + } + } + + public void testInternalDoExecute_whenValidInput_thenSucceed() { + PutTIFJobRequest request = new PutTIFJobRequest(TestHelpers.randomLowerCaseString(), clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)); + ActionListener listener = mock(ActionListener.class); + + // Run + action.internalDoExecute(request, randomLockModel(), listener); + + // Verify + ArgumentCaptor captor = ArgumentCaptor.forClass(StepListener.class); + verify(tifJobParameterService).createJobIndexIfNotExists(captor.capture()); + + // Run + captor.getValue().onResponse(null); + // Verify + ArgumentCaptor tifJobCaptor = ArgumentCaptor.forClass(TIFJobParameter.class); + ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + verify(tifJobParameterService).saveTIFJobParameter(tifJobCaptor.capture(), actionListenerCaptor.capture()); + assertEquals(request.getName(), tifJobCaptor.getValue().getName()); + + // Run next listener.onResponse + actionListenerCaptor.getValue().onResponse(null); + // Verify + verify(listener).onResponse(new AcknowledgedResponse(true)); + } + + public void testCreateTIFJobParameter_whenInvalidState_thenUpdateStateAsFailed() throws IOException { + TIFJobParameter tifJob = new TIFJobParameter(); + tifJob.setState(randomStateExcept(TIFJobState.CREATING)); + tifJob.getUpdateStats().setLastFailedAt(null); + + // Run + action.createThreatIntelFeedData(tifJob, mock(Runnable.class)); + + // Verify + assertEquals(TIFJobState.CREATE_FAILED, tifJob.getState()); + assertNotNull(tifJob.getUpdateStats().getLastFailedAt()); + verify(tifJobParameterService).updateJobSchedulerParameter(tifJob); + verify(tifJobUpdateService, never()).createThreatIntelFeedData(any(TIFJobParameter.class), any(Runnable.class)); + } + + public void testCreateTIFJobParameter_whenExceptionHappens_thenUpdateStateAsFailed() throws IOException { + TIFJobParameter tifJob = new TIFJobParameter(); + doThrow(new RuntimeException()).when(tifJobUpdateService).createThreatIntelFeedData(any(TIFJobParameter.class), any(Runnable.class)); + + // Run + action.createThreatIntelFeedData(tifJob, mock(Runnable.class)); + + // Verify + assertEquals(TIFJobState.CREATE_FAILED, tifJob.getState()); + assertNotNull(tifJob.getUpdateStats().getLastFailedAt()); + verify(tifJobParameterService).updateJobSchedulerParameter(tifJob); + } + + public void testCreateTIFJobParameter_whenValidInput_thenUpdateStateAsCreating() throws IOException { + TIFJobParameter tifJob = new TIFJobParameter(); + + Runnable renewLock = mock(Runnable.class); + // Run + action.createThreatIntelFeedData(tifJob, renewLock); + + // Verify + verify(tifJobUpdateService).createThreatIntelFeedData(tifJob, renewLock); + assertEquals(TIFJobState.CREATING, tifJob.getState()); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java new file mode 100644 index 000000000..4b6423a3e --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.common; + +import static org.mockito.Mockito.mock; +import static org.opensearch.securityanalytics.threatIntel.common.TIFLockService.LOCK_DURATION_IN_SECONDS; +import static org.opensearch.securityanalytics.threatIntel.common.TIFLockService.RENEW_AFTER_IN_SECONDS; + +import java.time.Instant; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Before; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.TestHelpers; + +public class ThreatIntelLockServiceTests extends ThreatIntelTestCase { + private TIFLockService threatIntelLockService; + private TIFLockService noOpsLockService; + + @Before + public void init() { + threatIntelLockService = new TIFLockService(clusterService, verifyingClient); + noOpsLockService = new TIFLockService(clusterService, client); + } + + public void testAcquireLock_whenValidInput_thenSucceed() { + // Cannot test because LockService is final class + // Simply calling method to increase coverage + noOpsLockService.acquireLock(TestHelpers.randomLowerCaseString(), randomPositiveLong(), mock(ActionListener.class)); + } + + public void testAcquireLock_whenCalled_thenNotBlocked() { + long expectedDurationInMillis = 1000; + Instant before = Instant.now(); + assertTrue(threatIntelLockService.acquireLock(null, null).isEmpty()); + Instant after = Instant.now(); + assertTrue(after.toEpochMilli() - before.toEpochMilli() < expectedDurationInMillis); + } + + public void testReleaseLock_whenValidInput_thenSucceed() { + // Cannot test because LockService is final class + // Simply calling method to increase coverage + noOpsLockService.releaseLock(null); + } + + public void testRenewLock_whenCalled_thenNotBlocked() { + long expectedDurationInMillis = 1000; + Instant before = Instant.now(); + assertNull(threatIntelLockService.renewLock(null)); + Instant after = Instant.now(); + assertTrue(after.toEpochMilli() - before.toEpochMilli() < expectedDurationInMillis); + } + + public void testGetRenewLockRunnable_whenLockIsFresh_thenDoNotRenew() { + LockModel lockModel = new LockModel( + TestHelpers.randomLowerCaseString(), + TestHelpers.randomLowerCaseString(), + Instant.now(), + LOCK_DURATION_IN_SECONDS, + false + ); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verifying + assertTrue(actionRequest instanceof UpdateRequest); + return new UpdateResponse( + mock(ShardId.class), + TestHelpers.randomLowerCaseString(), + randomPositiveLong(), + randomPositiveLong(), + randomPositiveLong(), + DocWriteResponse.Result.UPDATED + ); + }); + + AtomicReference reference = new AtomicReference<>(lockModel); + threatIntelLockService.getRenewLockRunnable(reference).run(); + assertEquals(lockModel, reference.get()); + } + + public void testGetRenewLockRunnable_whenLockIsStale_thenRenew() { + LockModel lockModel = new LockModel( + TestHelpers.randomLowerCaseString(), + TestHelpers.randomLowerCaseString(), + Instant.now().minusSeconds(RENEW_AFTER_IN_SECONDS), + LOCK_DURATION_IN_SECONDS, + false + ); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verifying + assertTrue(actionRequest instanceof UpdateRequest); + return new UpdateResponse( + mock(ShardId.class), + TestHelpers.randomLowerCaseString(), + randomPositiveLong(), + randomPositiveLong(), + randomPositiveLong(), + DocWriteResponse.Result.UPDATED + ); + }); + + AtomicReference reference = new AtomicReference<>(lockModel); + threatIntelLockService.getRenewLockRunnable(reference).run(); + assertNotEquals(lockModel, reference.get()); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/integTests/TIFJobExtensionPluginIT.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/integTests/TIFJobExtensionPluginIT.java new file mode 100644 index 000000000..ff682e6dd --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/integTests/TIFJobExtensionPluginIT.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.securityanalytics.threatIntel.integTests; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.cluster.health.ClusterHealthStatus; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobRunner; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.junit.Assert; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class TIFJobExtensionPluginIT extends OpenSearchIntegTestCase { + private static final Logger log = LogManager.getLogger(TIFJobExtensionPluginIT.class); + + public void testPluginsAreInstalled() { + ClusterHealthRequest request = new ClusterHealthRequest(); + ClusterHealthResponse response = OpenSearchIntegTestCase.client().admin().cluster().health(request).actionGet(); + Assert.assertEquals(ClusterHealthStatus.GREEN, response.getStatus()); + + NodesInfoRequest nodesInfoRequest = new NodesInfoRequest(); + nodesInfoRequest.addMetric(NodesInfoRequest.Metric.PLUGINS.metricName()); + NodesInfoResponse nodesInfoResponse = OpenSearchIntegTestCase.client().admin().cluster().nodesInfo(nodesInfoRequest).actionGet(); + List pluginInfos = nodesInfoResponse.getNodes() + .stream() + .flatMap( + (Function>) nodeInfo -> nodeInfo.getInfo(PluginsAndModules.class).getPluginInfos().stream() + ) + .collect(Collectors.toList()); + Assert.assertTrue(pluginInfos.stream().anyMatch(pluginInfo -> pluginInfo.getName().equals("opensearch-job-scheduler"))); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/integTests/ThreatIntelJobRunnerIT.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/integTests/ThreatIntelJobRunnerIT.java new file mode 100644 index 000000000..cf4cc800c --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/integTests/ThreatIntelJobRunnerIT.java @@ -0,0 +1,240 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.securityanalytics.threatIntel.integTests; + +import org.apache.hc.core5.http.HttpStatus; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHit; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; +import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.DetectorRule; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.opensearch.securityanalytics.SecurityAnalyticsPlugin.JOB_INDEX_NAME; +import static org.opensearch.securityanalytics.TestHelpers.*; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL; +import static org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedDataUtils.getTifdList; + +public class ThreatIntelJobRunnerIT extends SecurityAnalyticsRestTestCase { + private static final Logger log = LogManager.getLogger(ThreatIntelJobRunnerIT.class); + + public void testCreateDetector_threatIntelEnabled_testJobRunner() throws IOException, InterruptedException { + + // update job runner to run every minute + updateClusterSetting(TIF_UPDATE_INTERVAL.getKey(), "1m"); + + // Create a detector + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputsAndThreatIntel(List.of(input), true); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(2, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + String detectoraLstUpdateTime1 = detectorMap.get("last_update_time").toString(); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + List iocs = getThreatIntelFeedIocs(3); + assertEquals(iocs.size(), 3); + + // get job runner index and verify parameters exist + List jobMetaDataList = getJobSchedulerParameter(); + assertEquals(1, jobMetaDataList.size()); + TIFJobParameter jobMetaData = jobMetaDataList.get(0); + Instant firstUpdatedTime = jobMetaData.getLastUpdateTime(); + assertNotNull("Job runner parameter index does not have metadata set", jobMetaData.getLastUpdateTime()); + assertEquals(jobMetaData.isEnabled(), true); + + // get list of first updated time for threat intel feed data + List originalFeedTimestamp = getThreatIntelFeedsTime(); + + //verify feed index exists and each feed_id exists + List feedId = getThreatIntelFeedIds(); + assertNotNull(feedId); + + // wait for job runner to run + Thread.sleep(60000); + waitUntil(() -> { + try { + return verifyJobRan(firstUpdatedTime); + } catch (IOException e) { + throw new RuntimeException("failed to verify that job ran"); + } + }, 120, TimeUnit.SECONDS); + + // verify job's last update time is different + List newJobMetaDataList = getJobSchedulerParameter(); + assertEquals(1, newJobMetaDataList.size()); + TIFJobParameter newJobMetaData = newJobMetaDataList.get(0); + Instant lastUpdatedTime = newJobMetaData.getLastUpdateTime(); + assertNotEquals(firstUpdatedTime.toString(), lastUpdatedTime.toString()); + + // verify new threat intel feed timestamp is different + List newFeedTimestamp = getThreatIntelFeedsTime(); + for (int i = 0; i < newFeedTimestamp.size(); i++) { + assertNotEquals(newFeedTimestamp.get(i), originalFeedTimestamp.get(i)); + } + + // verify detectors updated with latest threat intel feed data + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + String detectoraLstUpdateTime2 = detectorMap.get("last_update_time").toString(); + assertFalse(detectoraLstUpdateTime2.equals(detectoraLstUpdateTime1)); + + } + + protected boolean verifyJobRan(Instant firstUpdatedTime) throws IOException { + // verify job's last update time is different + List newJobMetaDataList = getJobSchedulerParameter(); + assertEquals(1, newJobMetaDataList.size()); + + TIFJobParameter newJobMetaData = newJobMetaDataList.get(0); + Instant newUpdatedTime = newJobMetaData.getLastUpdateTime(); + if (!firstUpdatedTime.toString().equals(newUpdatedTime.toString())) { + return true; + } + return false; + } + + private List getThreatIntelFeedIds() throws IOException { + String request = getMatchAllSearchRequestString(); + SearchResponse res = executeSearchAndGetResponse(".opensearch-sap-threat-intel*", request, false); + return getTifdList(res, xContentRegistry()).stream().map(it -> it.getFeedId()).collect(Collectors.toList()); + } + + private List getThreatIntelFeedsTime() throws IOException { + String request = getMatchAllSearchRequestString(); + SearchResponse res = executeSearchAndGetResponse(".opensearch-sap-threat-intel*", request, false); + return getTifdList(res, xContentRegistry()).stream().map(it -> it.getTimestamp()).collect(Collectors.toList()); + } + + private List getJobSchedulerParameter() throws IOException { + String request = getMatchAllSearchRequestString(); + SearchResponse res = executeSearchAndGetResponse(JOB_INDEX_NAME + "*", request, false); + return getTIFJobParameterList(res, xContentRegistry()).stream().collect(Collectors.toList()); + } + + public static List getTIFJobParameterList(SearchResponse searchResponse, NamedXContentRegistry xContentRegistry) { + List list = new ArrayList<>(); + if (searchResponse.getHits().getHits().length != 0) { + Arrays.stream(searchResponse.getHits().getHits()).forEach(hit -> { + try { + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() + ); + list.add(TIFJobParameter.parse(xcp, hit.getId(), hit.getVersion())); + } catch (Exception e) { + log.error(() -> new ParameterizedMessage( + "Failed to parse TIF Job Parameter metadata from hit {}", hit), + e + ); + } + + }); + } + return list; + } + + private static String getMatchAllSearchRequestString() { + return "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + } + + private static String getMatchNumSearchRequestString(int num) { + return "{\n" + + "\"size\" : " + num + "," + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterServiceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterServiceTests.java new file mode 100644 index 000000000..6e3b83a78 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterServiceTests.java @@ -0,0 +1,207 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.junit.Before; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.StepListener; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.TestHelpers; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TIFJobParameterServiceTests extends ThreatIntelTestCase { + private TIFJobParameterService tifJobParameterService; + + @Before + public void init() { + tifJobParameterService = new TIFJobParameterService(verifyingClient, clusterService); + } + + public void testcreateJobIndexIfNotExists_whenIndexExist_thenCreateRequestIsNotCalled() { + when(metadata.hasIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME)).thenReturn(true); + + // Verify + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { throw new RuntimeException("Shouldn't get called"); }); + + // Run + StepListener stepListener = new StepListener<>(); + tifJobParameterService.createJobIndexIfNotExists(stepListener); + + // Verify stepListener is called + stepListener.result(); + } + + public void testcreateJobIndexIfNotExists_whenIndexExist_thenCreateRequestIsCalled() { + when(metadata.hasIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME)).thenReturn(false); + + // Verify + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof CreateIndexRequest); + CreateIndexRequest request = (CreateIndexRequest) actionRequest; + assertEquals(SecurityAnalyticsPlugin.JOB_INDEX_NAME, request.index()); + assertEquals("1", request.settings().get("index.number_of_shards")); + assertEquals("0-all", request.settings().get("index.auto_expand_replicas")); + assertEquals("true", request.settings().get("index.hidden")); + assertNotNull(request.mappings()); + return null; + }); + + // Run + StepListener stepListener = new StepListener<>(); + tifJobParameterService.createJobIndexIfNotExists(stepListener); + + // Verify stepListener is called + stepListener.result(); + } + + public void testcreateJobIndexIfNotExists_whenIndexCreatedAlready_thenExceptionIsIgnored() { + when(metadata.hasIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME)).thenReturn(false); + verifyingClient.setExecuteVerifier( + (actionResponse, actionRequest) -> { throw new ResourceAlreadyExistsException(SecurityAnalyticsPlugin.JOB_INDEX_NAME); } + ); + + // Run + StepListener stepListener = new StepListener<>(); + tifJobParameterService.createJobIndexIfNotExists(stepListener); + + // Verify stepListener is called + stepListener.result(); + } + + public void testcreateJobIndexIfNotExists_whenExceptionIsThrown_thenExceptionIsThrown() { + when(metadata.hasIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME)).thenReturn(false); + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { throw new RuntimeException(); }); + + // Run + StepListener stepListener = new StepListener<>(); + tifJobParameterService.createJobIndexIfNotExists(stepListener); + + // Verify stepListener is called + expectThrows(RuntimeException.class, () -> stepListener.result()); + } + + public void testUpdateTIFJobParameter_whenValidInput_thenSucceed() throws Exception { + String tifJobName = TestHelpers.randomLowerCaseString(); + TIFJobParameter tifJobParameter = new TIFJobParameter( + tifJobName, + new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS) + ); + Instant previousTime = Instant.now().minusMillis(1); + tifJobParameter.setLastUpdateTime(previousTime); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof IndexRequest); + IndexRequest request = (IndexRequest) actionRequest; + assertEquals(tifJobParameter.getName(), request.id()); + assertEquals(DocWriteRequest.OpType.INDEX, request.opType()); + assertEquals(SecurityAnalyticsPlugin.JOB_INDEX_NAME, request.index()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, request.getRefreshPolicy()); + return null; + }); + + tifJobParameterService.updateJobSchedulerParameter(tifJobParameter); + assertTrue(previousTime.isBefore(tifJobParameter.getLastUpdateTime())); + } + + public void testsaveTIFJobParameter_whenValidInput_thenSucceed() { + TIFJobParameter tifJobParameter = randomTifJobParameter(); + Instant previousTime = Instant.now().minusMillis(1); + tifJobParameter.setLastUpdateTime(previousTime); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof IndexRequest); + IndexRequest indexRequest = (IndexRequest) actionRequest; + assertEquals(SecurityAnalyticsPlugin.JOB_INDEX_NAME, indexRequest.index()); + assertEquals(tifJobParameter.getName(), indexRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, indexRequest.getRefreshPolicy()); + assertEquals(DocWriteRequest.OpType.CREATE, indexRequest.opType()); + return null; + }); + + tifJobParameterService.saveTIFJobParameter(tifJobParameter, mock(ActionListener.class)); + assertTrue(previousTime.isBefore(tifJobParameter.getLastUpdateTime())); + } + + public void testGetTifJobParameter_whenException_thenNull() throws Exception { + TIFJobParameter tifJobParameter = setupClientForGetRequest(true, new IndexNotFoundException(SecurityAnalyticsPlugin.JOB_INDEX_NAME)); + assertNull(tifJobParameterService.getJobParameter(tifJobParameter.getName())); + } + + public void testGetTifJobParameter_whenExist_thenReturnTifJobParameter() throws Exception { + TIFJobParameter tifJobParameter = setupClientForGetRequest(true, null); + TIFJobParameter anotherTIFJobParameter = tifJobParameterService.getJobParameter(tifJobParameter.getName()); + + assertTrue(tifJobParameter.getName().equals(anotherTIFJobParameter.getName())); + assertTrue(tifJobParameter.getLastUpdateTime().equals(anotherTIFJobParameter.getLastUpdateTime())); + assertTrue(tifJobParameter.getSchedule().equals(anotherTIFJobParameter.getSchedule())); + assertTrue(tifJobParameter.getState().equals(anotherTIFJobParameter.getState())); + assertTrue(tifJobParameter.getIndices().equals(anotherTIFJobParameter.getIndices())); + } + + public void testGetTifJobParameter_whenNotExist_thenNull() throws Exception { + TIFJobParameter tifJobParameter = setupClientForGetRequest(false, null); + assertNull(tifJobParameterService.getJobParameter(tifJobParameter.getName())); + } + + private TIFJobParameter setupClientForGetRequest(final boolean isExist, final RuntimeException exception) { + TIFJobParameter tifJobParameter = randomTifJobParameter(); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof GetRequest); + GetRequest request = (GetRequest) actionRequest; + assertEquals(tifJobParameter.getName(), request.id()); + assertEquals(SecurityAnalyticsPlugin.JOB_INDEX_NAME, request.index()); + GetResponse response = getMockedGetResponse(isExist ? tifJobParameter : null); + if (exception != null) { + throw exception; + } + return response; + }); + return tifJobParameter; + } + + private GetResponse getMockedGetResponse(TIFJobParameter tifJobParameter) { + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(tifJobParameter != null); + when(response.getSourceAsBytesRef()).thenReturn(toBytesReference(tifJobParameter)); + return response; + } + + private BytesReference toBytesReference(TIFJobParameter tifJobParameter) { + if (tifJobParameter == null) { + return null; + } + + try { + return BytesReference.bytes(tifJobParameter.toXContent(JsonXContent.contentBuilder(), null)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterTests.java new file mode 100644 index 000000000..f7b7ff8d1 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterTests.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.TestHelpers; +import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameter.THREAT_INTEL_DATA_INDEX_NAME_PREFIX; + +public class TIFJobParameterTests extends ThreatIntelTestCase { + private static final Logger log = LogManager.getLogger(TIFJobParameterTests.class); + + public void testParser_whenAllValueIsFilled_thenSucceed() throws IOException { + String id = TestHelpers.randomLowerCaseString(); + IntervalSchedule schedule = new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS); + TIFJobParameter tifJobParameter = new TIFJobParameter(id, schedule); + tifJobParameter.enable(); + tifJobParameter.getUpdateStats().setLastProcessingTimeInMillis(randomPositiveLong()); + tifJobParameter.getUpdateStats().setLastSucceededAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)); + tifJobParameter.getUpdateStats().setLastSkippedAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)); + tifJobParameter.getUpdateStats().setLastFailedAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)); + + TIFJobParameter anotherTIFJobParameter = TIFJobParameter.PARSER.parse( + createParser(tifJobParameter.toXContent(XContentFactory.jsonBuilder(), null)), + null + ); + + assertTrue(tifJobParameter.getName().equals(anotherTIFJobParameter.getName())); + assertTrue(tifJobParameter.getLastUpdateTime().equals(anotherTIFJobParameter.getLastUpdateTime())); + assertTrue(tifJobParameter.getEnabledTime().equals(anotherTIFJobParameter.getEnabledTime())); + assertTrue(tifJobParameter.getSchedule().equals(anotherTIFJobParameter.getSchedule())); + assertTrue(tifJobParameter.getState().equals(anotherTIFJobParameter.getState())); + assertTrue(tifJobParameter.getIndices().equals(anotherTIFJobParameter.getIndices())); + assertTrue(tifJobParameter.getUpdateStats().getLastFailedAt().equals(anotherTIFJobParameter.getUpdateStats().getLastFailedAt())); + assertTrue(tifJobParameter.getUpdateStats().getLastSkippedAt().equals(anotherTIFJobParameter.getUpdateStats().getLastSkippedAt())); + assertTrue(tifJobParameter.getUpdateStats().getLastSucceededAt().equals(anotherTIFJobParameter.getUpdateStats().getLastSucceededAt())); + assertTrue(tifJobParameter.getUpdateStats().getLastProcessingTimeInMillis().equals(anotherTIFJobParameter.getUpdateStats().getLastProcessingTimeInMillis())); + + } + + public void testParser_whenNullForOptionalFields_thenSucceed() throws IOException { + String id = TestHelpers.randomLowerCaseString(); + IntervalSchedule schedule = new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS); + TIFJobParameter tifJobParameter = new TIFJobParameter(id, schedule); + TIFJobParameter anotherTIFJobParameter = TIFJobParameter.PARSER.parse( + createParser(tifJobParameter.toXContent(XContentFactory.jsonBuilder(), null)), + null + ); + + assertTrue(tifJobParameter.getName().equals(anotherTIFJobParameter.getName())); + assertTrue(tifJobParameter.getLastUpdateTime().equals(anotherTIFJobParameter.getLastUpdateTime())); + assertTrue(tifJobParameter.getSchedule().equals(anotherTIFJobParameter.getSchedule())); + assertTrue(tifJobParameter.getState().equals(anotherTIFJobParameter.getState())); + assertTrue(tifJobParameter.getIndices().equals(anotherTIFJobParameter.getIndices())); + } + + public void testCurrentIndexName_whenNotExpired_thenReturnName() { + String id = TestHelpers.randomLowerCaseString(); + TIFJobParameter datasource = new TIFJobParameter(); + datasource.setName(id); + } + + public void testNewIndexName_whenCalled_thenReturnedExpectedValue() { + TIFMetadata tifMetadata = new TIFMetadata("mock_id", + "mock url", + "mock name", + "mock org", + "mock description", + "mock csv", + "mock ip", + 1, + false); + + String name = tifMetadata.getFeedId(); + String suffix = "1"; + TIFJobParameter tifJobParameter = new TIFJobParameter(); + tifJobParameter.setName(name); + assertEquals(String.format(Locale.ROOT, "%s-%s%s", THREAT_INTEL_DATA_INDEX_NAME_PREFIX, name, suffix), tifJobParameter.newIndexName(tifJobParameter,tifMetadata)); + tifJobParameter.getIndices().add(tifJobParameter.newIndexName(tifJobParameter,tifMetadata)); + + log.error(tifJobParameter.getIndices()); + + String anotherSuffix = "2"; + assertEquals(String.format(Locale.ROOT, "%s-%s%s", THREAT_INTEL_DATA_INDEX_NAME_PREFIX, name, anotherSuffix), tifJobParameter.newIndexName(tifJobParameter,tifMetadata)); + } + + public void testLockDurationSeconds() { + TIFJobParameter datasource = new TIFJobParameter(); + assertNotNull(datasource.getLockDurationSeconds()); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunnerTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunnerTests.java new file mode 100644 index 000000000..82038a91f --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunnerTests.java @@ -0,0 +1,168 @@ + +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.junit.Before; +import org.opensearch.jobscheduler.spi.JobDocVersion; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; +import org.opensearch.securityanalytics.threatIntel.common.TIFLockService; +import org.opensearch.securityanalytics.TestHelpers; + +import java.io.IOException; +import java.time.Instant; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public class TIFJobRunnerTests extends ThreatIntelTestCase { + @Before + public void init() { + TIFJobRunner.getJobRunnerInstance() + .initialize(clusterService, tifJobUpdateService, tifJobParameterService, tifLockService, threadPool, detectorThreatIntelService); + } + + public void testGetJobRunnerInstance_whenCalledAgain_thenReturnSameInstance() { + assertTrue(TIFJobRunner.getJobRunnerInstance() == TIFJobRunner.getJobRunnerInstance()); + } + + public void testRunJob_whenInvalidClass_thenThrowException() { + JobDocVersion jobDocVersion = new JobDocVersion(randomInt(), randomInt(), randomInt()); + String jobIndexName = TestHelpers.randomLowerCaseString(); + String jobId = TestHelpers.randomLowerCaseString(); + JobExecutionContext jobExecutionContext = new JobExecutionContext(Instant.now(), jobDocVersion, lockService, jobIndexName, jobId); + ScheduledJobParameter jobParameter = mock(ScheduledJobParameter.class); + + // Run + expectThrows(IllegalStateException.class, () -> TIFJobRunner.getJobRunnerInstance().runJob(jobParameter, jobExecutionContext)); + } + + public void testRunJob_whenValidInput_thenSucceed() throws IOException { + JobDocVersion jobDocVersion = new JobDocVersion(randomInt(), randomInt(), randomInt()); + String jobIndexName = TestHelpers.randomLowerCaseString(); + String jobId = TestHelpers.randomLowerCaseString(); + JobExecutionContext jobExecutionContext = new JobExecutionContext(Instant.now(), jobDocVersion, lockService, jobIndexName, jobId); + TIFJobParameter tifJobParameter = randomTifJobParameter(); + + LockModel lockModel = randomLockModel(); + when(tifLockService.acquireLock(tifJobParameter.getName(), TIFLockService.LOCK_DURATION_IN_SECONDS)).thenReturn( + Optional.of(lockModel) + ); + + // Run + TIFJobRunner.getJobRunnerInstance().runJob(tifJobParameter, jobExecutionContext); + + // Verify + verify(tifLockService).acquireLock(tifJobParameter.getName(), tifLockService.LOCK_DURATION_IN_SECONDS); + verify(tifJobParameterService).getJobParameter(tifJobParameter.getName()); + verify(tifLockService).releaseLock(lockModel); + } + + public void testUpdateTIFJobRunner_whenExceptionBeforeAcquiringLock_thenNoReleaseLock() { + ScheduledJobParameter jobParameter = mock(ScheduledJobParameter.class); + when(jobParameter.getName()).thenReturn(TestHelpers.randomLowerCaseString()); + when(tifLockService.acquireLock(jobParameter.getName(), TIFLockService.LOCK_DURATION_IN_SECONDS)).thenThrow( + new RuntimeException() + ); + + // Run + expectThrows(Exception.class, () -> TIFJobRunner.getJobRunnerInstance().updateJobRunner(jobParameter).run()); + + // Verify + verify(tifLockService, never()).releaseLock(any()); + } + + public void testUpdateTIFJobRunner_whenExceptionAfterAcquiringLock_thenReleaseLock() throws IOException { + ScheduledJobParameter jobParameter = mock(ScheduledJobParameter.class); + when(jobParameter.getName()).thenReturn(TestHelpers.randomLowerCaseString()); + LockModel lockModel = randomLockModel(); + when(tifLockService.acquireLock(jobParameter.getName(), TIFLockService.LOCK_DURATION_IN_SECONDS)).thenReturn( + Optional.of(lockModel) + ); + when(tifJobParameterService.getJobParameter(jobParameter.getName())).thenThrow(new RuntimeException()); + + // Run + TIFJobRunner.getJobRunnerInstance().updateJobRunner(jobParameter).run(); + + // Verify + verify(tifLockService).releaseLock(any()); + } + + public void testUpdateTIFJob_whenTIFJobDoesNotExist_thenDoNothing() throws IOException { + TIFJobParameter tifJob = new TIFJobParameter(); + + // Run + TIFJobRunner.getJobRunnerInstance().updateJobParameter(tifJob, mock(Runnable.class)); + + // Verify + verify(tifJobUpdateService, never()).deleteAllTifdIndices(TestHelpers.randomLowerCaseStringList(),TestHelpers.randomLowerCaseStringList()); + } + + public void testUpdateTIFJob_whenInvalidState_thenUpdateLastFailedAt() throws IOException { + TIFJobParameter tifJob = new TIFJobParameter(); + tifJob.enable(); + tifJob.getUpdateStats().setLastFailedAt(null); + tifJob.setState(randomStateExcept(TIFJobState.AVAILABLE)); + when(tifJobParameterService.getJobParameter(tifJob.getName())).thenReturn(tifJob); + + // Run + TIFJobRunner.getJobRunnerInstance().updateJobParameter(tifJob, mock(Runnable.class)); + + // Verify + assertFalse(tifJob.isEnabled()); + assertNotNull(tifJob.getUpdateStats().getLastFailedAt()); + verify(tifJobParameterService).updateJobSchedulerParameter(tifJob); + } + + public void testUpdateTIFJob_whenValidInput_thenSucceed() throws IOException { + TIFJobParameter tifJob = randomTifJobParameter(); + tifJob.setState(TIFJobState.AVAILABLE); + when(tifJobParameterService.getJobParameter(tifJob.getName())).thenReturn(tifJob); + Runnable renewLock = mock(Runnable.class); + + // Run + TIFJobRunner.getJobRunnerInstance().updateJobParameter(tifJob, renewLock); + + // Verify + verify(tifJobUpdateService, times(0)).deleteAllTifdIndices(TestHelpers.randomLowerCaseStringList(),TestHelpers.randomLowerCaseStringList()); + verify(tifJobUpdateService).createThreatIntelFeedData(tifJob, renewLock); + } + + public void testUpdateTIFJob_whenDeleteTask_thenDeleteOnly() throws IOException { + TIFJobParameter tifJob = randomTifJobParameter(); + tifJob.setState(TIFJobState.AVAILABLE); + when(tifJobParameterService.getJobParameter(tifJob.getName())).thenReturn(tifJob); + Runnable renewLock = mock(Runnable.class); + + // Run + TIFJobRunner.getJobRunnerInstance().updateJobParameter(tifJob, renewLock); + + // Verify + verify(tifJobUpdateService, times(0)).deleteAllTifdIndices(TestHelpers.randomLowerCaseStringList(),TestHelpers.randomLowerCaseStringList()); + } + + public void testUpdateTIFJobExceptionHandling() throws IOException { + TIFJobParameter tifJob = new TIFJobParameter(); + tifJob.setName(TestHelpers.randomLowerCaseString()); + tifJob.getUpdateStats().setLastFailedAt(null); + when(tifJobParameterService.getJobParameter(tifJob.getName())).thenReturn(tifJob); + doThrow(new RuntimeException("test failure")).when(tifJobUpdateService).deleteAllTifdIndices(TestHelpers.randomLowerCaseStringList(),TestHelpers.randomLowerCaseStringList()); + + // Run + TIFJobRunner.getJobRunnerInstance().updateJobParameter(tifJob, mock(Runnable.class)); + + // Verify + assertNotNull(tifJob.getUpdateStats().getLastFailedAt()); + verify(tifJobParameterService).updateJobSchedulerParameter(tifJob); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobUpdateServiceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobUpdateServiceTests.java new file mode 100644 index 000000000..76b0f8fe4 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobUpdateServiceTests.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.Before; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.common.TIFJobState; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +@SuppressForbidden(reason = "unit test") +public class TIFJobUpdateServiceTests extends ThreatIntelTestCase { + + private TIFJobUpdateService tifJobUpdateService1; + + @Before + public void init() { + tifJobUpdateService1 = new TIFJobUpdateService(clusterService, tifJobParameterService, threatIntelFeedDataService, builtInTIFMetadataLoader); + } + + public void testUpdateOrCreateThreatIntelFeedData_whenValidInput_thenSucceed() throws IOException { + + ShardRouting shardRouting = mock(ShardRouting.class); + when(shardRouting.started()).thenReturn(true); + when(routingTable.allShards(anyString())).thenReturn(Arrays.asList(shardRouting)); + + TIFJobParameter tifJobParameter = new TIFJobParameter(); + tifJobParameter.setState(TIFJobState.AVAILABLE); + + tifJobParameter.getUpdateStats().setLastSucceededAt(null); + tifJobParameter.getUpdateStats().setLastProcessingTimeInMillis(null); + + // Run + List newFeeds = tifJobUpdateService1.createThreatIntelFeedData(tifJobParameter, mock(Runnable.class)); + + // Verify feeds + assertNotNull(newFeeds); + } + +} diff --git a/src/test/java/org/opensearch/securityanalytics/writable/LogTypeTests.java b/src/test/java/org/opensearch/securityanalytics/writable/LogTypeTests.java index 4ede7891b..d9d592641 100644 --- a/src/test/java/org/opensearch/securityanalytics/writable/LogTypeTests.java +++ b/src/test/java/org/opensearch/securityanalytics/writable/LogTypeTests.java @@ -21,7 +21,8 @@ public class LogTypeTests { public void testLogTypeAsStreamRawFieldOnly() throws IOException { LogType logType = new LogType( "1", "my_log_type", "description", false, - List.of(new LogType.Mapping("rawField", null, null)) + List.of(new LogType.Mapping("rawField", null, null)), + List.of(new LogType.IocFields("ip", List.of("dst.ip"))) ); BytesStreamOutput out = new BytesStreamOutput(); logType.writeTo(out); @@ -32,13 +33,16 @@ public void testLogTypeAsStreamRawFieldOnly() throws IOException { assertEquals(logType.getIsBuiltIn(), newLogType.getIsBuiltIn()); assertEquals(logType.getMappings().size(), newLogType.getMappings().size()); assertEquals(logType.getMappings().get(0).getRawField(), newLogType.getMappings().get(0).getRawField()); + assertEquals(logType.getIocFieldsList().get(0).getFields().get(0), newLogType.getIocFieldsList().get(0).getFields().get(0)); + assertEquals(logType.getIocFieldsList().get(0).getIoc(), newLogType.getIocFieldsList().get(0).getIoc()); } @Test public void testLogTypeAsStreamFull() throws IOException { LogType logType = new LogType( "1", "my_log_type", "description", false, - List.of(new LogType.Mapping("rawField", "some_ecs_field", "some_ocsf_field")) + List.of(new LogType.Mapping("rawField", "some_ecs_field", "some_ocsf_field")), + List.of(new LogType.IocFields("ip", List.of("dst.ip"))) ); BytesStreamOutput out = new BytesStreamOutput(); logType.writeTo(out); @@ -49,11 +53,14 @@ public void testLogTypeAsStreamFull() throws IOException { assertEquals(logType.getIsBuiltIn(), newLogType.getIsBuiltIn()); assertEquals(logType.getMappings().size(), newLogType.getMappings().size()); assertEquals(logType.getMappings().get(0).getRawField(), newLogType.getMappings().get(0).getRawField()); + assertEquals(logType.getIocFieldsList().get(0).getFields().get(0), newLogType.getIocFieldsList().get(0).getFields().get(0)); + assertEquals(logType.getIocFieldsList().get(0).getIoc(), newLogType.getIocFieldsList().get(0).getIoc()); + } @Test public void testLogTypeAsStreamNoMappings() throws IOException { - LogType logType = new LogType("1", "my_log_type", "description", false, null); + LogType logType = new LogType("1", "my_log_type", "description", false, null, null); BytesStreamOutput out = new BytesStreamOutput(); logType.writeTo(out); StreamInput sin = StreamInput.wrap(out.bytes().toBytesRef().bytes); diff --git a/src/test/resources/threatIntel/sample_csv_with_description_and_header.csv b/src/test/resources/threatIntel/sample_csv_with_description_and_header.csv new file mode 100644 index 000000000..750377fd6 --- /dev/null +++ b/src/test/resources/threatIntel/sample_csv_with_description_and_header.csv @@ -0,0 +1,4 @@ +# description + +ip +1.0.0.0/24 \ No newline at end of file diff --git a/src/test/resources/threatIntel/sample_invalid_less_than_two_fields.csv b/src/test/resources/threatIntel/sample_invalid_less_than_two_fields.csv new file mode 100644 index 000000000..08670061c --- /dev/null +++ b/src/test/resources/threatIntel/sample_invalid_less_than_two_fields.csv @@ -0,0 +1,2 @@ +network +1.0.0.0/24 \ No newline at end of file diff --git a/src/test/resources/threatIntel/sample_valid.csv b/src/test/resources/threatIntel/sample_valid.csv new file mode 100644 index 000000000..c599b6888 --- /dev/null +++ b/src/test/resources/threatIntel/sample_valid.csv @@ -0,0 +1,2 @@ +1.0.0.0/24,Australia +10.0.0.0/24,USA \ No newline at end of file diff --git a/src/test/resources/threatIntelFeed/feedMetadata.json b/src/test/resources/threatIntelFeed/feedMetadata.json new file mode 100644 index 000000000..0e5583797 --- /dev/null +++ b/src/test/resources/threatIntelFeed/feedMetadata.json @@ -0,0 +1,12 @@ +{ + "alienvault_reputation_ip_database": { + "id": "alienvault_reputation_ip_database", + "url": "https://reputation.alienvault.com/reputation.generic", + "name": "Alienvault IP Reputation", + "organization": "Alienvault", + "description": "Alienvault IP Reputation threat intelligence feed managed by AlienVault", + "feed_format": "csv", + "ioc_type": "ip", + "ioc_col": 0 + } +}