diff --git a/tif/build.gradle b/tif/build.gradle index 294d6992c..1de8ee548 100644 --- a/tif/build.gradle +++ b/tif/build.gradle @@ -1,3 +1,5 @@ +import org.opensearch.gradle.test.RestIntegTestTask + /* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 @@ -21,6 +23,10 @@ dependencies { implementation 'com.fasterxml.jackson.core:jackson-databind:2.17.0' implementation 'com.fasterxml.jackson.module:jackson-module-afterburner:2.17.0' implementation 'com.google.guava:guava:33.1.0-jre' + implementation 'org.opensearch.client:opensearch-rest-high-level-client:2.13.0' + implementation platform('org.apache.logging.log4j:log4j-bom:2.22.1') + implementation 'org.apache.logging.log4j:log4j-core' + implementation 'org.apache.logging.log4j:log4j-slf4j2-impl' testImplementation "org.mockito:mockito-inline:5.2.0" testImplementation "org.mockito:mockito-core:5.11.0" @@ -31,4 +37,54 @@ dependencies { test { useJUnitPlatform() -} \ No newline at end of file +} + +sourceSets { + integrationTest { + java { + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + srcDir file('src/test/java') + } + resources.srcDir file('src/test/resources') + } +} + +configurations { + integrationTestImplementation.extendsFrom testImplementation + integrationTestRuntime.extendsFrom testRuntime +} + +task s3ConnectorIT(type: Test) { + group = 'verification' + testClassesDirs = sourceSets.integrationTest.output.classesDirs + + useJUnitPlatform() + + classpath = sourceSets.integrationTest.runtimeClasspath + systemProperty 'tests.s3connector.bucket', System.getProperty('tests.s3connector.bucket') + systemProperty 'tests.s3connector.region', System.getProperty('tests.s3connector.region') + systemProperty 'tests.s3connector.roleArn', System.getProperty('tests.s3connector.roleArn') + + filter { + includeTestsMatching 'S3ConnectorIT' + } +} + +task systemIndexFeedStoreIT(type: Test) { + group = 'verification' + testClassesDirs = sourceSets.integrationTest.output.classesDirs + + useJUnitPlatform() + + classpath = sourceSets.integrationTest.runtimeClasspath + systemProperty 'tests.opensearch.host', System.getProperty('tests.opensearch.host') + systemProperty 'tests.opensearch.user', System.getProperty('tests.opensearch.user') + systemProperty 'tests.opensearch.password', System.getProperty('tests.opensearch.password') + + filter { + includeTestsMatching 'SystemIndexFeedStoreIT' + } +} + + diff --git a/tif/src/main/java/org/opensearch/securityanalytics/connector/S3Connector.java b/tif/src/main/java/org/opensearch/securityanalytics/connector/S3Connector.java index f8cafbd5f..59cc3f92d 100644 --- a/tif/src/main/java/org/opensearch/securityanalytics/connector/S3Connector.java +++ b/tif/src/main/java/org/opensearch/securityanalytics/connector/S3Connector.java @@ -32,7 +32,10 @@ public List loadIOCs() { final GetObjectRequest getObjectRequest = getObjectRequest(); final ResponseInputStream response = s3Client.getObject(getObjectRequest); - return inputCodec.parse(response); + final List iocs = inputCodec.parse(response); + setFeedId(iocs); + + return iocs; } private GetObjectRequest getObjectRequest() { @@ -41,4 +44,8 @@ private GetObjectRequest getObjectRequest() { .key(s3ConnectorConfig.getObjectKey()) .build(); } + + private void setFeedId(final List iocs) { + iocs.forEach(ioc -> ioc.setFeedId(s3ConnectorConfig.getFeedId())); + } } diff --git a/tif/src/main/java/org/opensearch/securityanalytics/connector/model/S3ConnectorConfig.java b/tif/src/main/java/org/opensearch/securityanalytics/connector/model/S3ConnectorConfig.java index 1bce20a48..f9fff8950 100644 --- a/tif/src/main/java/org/opensearch/securityanalytics/connector/model/S3ConnectorConfig.java +++ b/tif/src/main/java/org/opensearch/securityanalytics/connector/model/S3ConnectorConfig.java @@ -13,15 +13,18 @@ public class S3ConnectorConfig { private final String roleArn; private final IOCSchema iocSchema; private final InputCodecSchema inputCodecSchema; + private final String feedId; public S3ConnectorConfig(final String bucketName, final String objectKey, final String region, - final String roleArn, final IOCSchema iocSchema, final InputCodecSchema inputCodecSchema) { + final String roleArn, final IOCSchema iocSchema, final InputCodecSchema inputCodecSchema, + final String feedId) { this.bucketName = bucketName; this.objectKey = objectKey; this.region = region; this.roleArn = roleArn; this.iocSchema = iocSchema; this.inputCodecSchema = inputCodecSchema; + this.feedId = feedId; } public String getBucketName() { @@ -47,4 +50,8 @@ public IOCSchema getIocSchema() { public InputCodecSchema getInputCodecSchema() { return inputCodecSchema; } + + public String getFeedId() { + return feedId; + } } diff --git a/tif/src/main/java/org/opensearch/securityanalytics/exceptions/FeedStoreException.java b/tif/src/main/java/org/opensearch/securityanalytics/exceptions/FeedStoreException.java new file mode 100644 index 000000000..2f448b228 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/exceptions/FeedStoreException.java @@ -0,0 +1,11 @@ +package org.opensearch.securityanalytics.exceptions; + +public class FeedStoreException extends RuntimeException { + public FeedStoreException(final String message) { + super(message); + } + + public FeedStoreException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/exceptions/IndexAccessorException.java b/tif/src/main/java/org/opensearch/securityanalytics/exceptions/IndexAccessorException.java new file mode 100644 index 000000000..9c9599248 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/exceptions/IndexAccessorException.java @@ -0,0 +1,11 @@ +package org.opensearch.securityanalytics.exceptions; + +public class IndexAccessorException extends RuntimeException { + public IndexAccessorException(final String message) { + super(message); + } + + public IndexAccessorException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/exceptions/ResourceReaderException.java b/tif/src/main/java/org/opensearch/securityanalytics/exceptions/ResourceReaderException.java new file mode 100644 index 000000000..773feffd1 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/exceptions/ResourceReaderException.java @@ -0,0 +1,7 @@ +package org.opensearch.securityanalytics.exceptions; + +public class ResourceReaderException extends RuntimeException { + public ResourceReaderException(final String message) { + super(message); + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/feed/FeedManager.java b/tif/src/main/java/org/opensearch/securityanalytics/feed/FeedManager.java new file mode 100644 index 000000000..1838c61c7 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/feed/FeedManager.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.feed; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +public class FeedManager { + private static final Logger log = LoggerFactory.getLogger(FeedManager.class); + + private final ScheduledExecutorService executorService; + private final Map> registeredTasks; + + public FeedManager() { + final ScheduledThreadPoolExecutor scheduledThreadPoolExecutor = new ScheduledThreadPoolExecutor(1); + scheduledThreadPoolExecutor.setRemoveOnCancelPolicy(true); + + executorService = Executors.unconfigurableScheduledExecutorService(scheduledThreadPoolExecutor); + registeredTasks = new HashMap<>(); + } + + @VisibleForTesting + FeedManager(final ScheduledExecutorService scheduledExecutorService, final Map> registeredTasks) { + this.executorService = scheduledExecutorService; + this.registeredTasks = registeredTasks; + } + + public void registerFeedRetriever(final String feedId, final Runnable feedRetriever, final Duration refreshInterval) { + if (registeredTasks.containsKey(feedId)) { + log.warn("Field with ID {} already has a retriever registered. Will replace existing feed retriever with new definition.", feedId); + deregisterFeedRetriever(feedId); + } + + final ScheduledFuture retrieverFuture = executorService.scheduleAtFixedRate(feedRetriever, 0, refreshInterval.toMillis(), TimeUnit.MILLISECONDS); + registeredTasks.put(feedId, retrieverFuture); + } + + public void deregisterFeedRetriever(final String feedId) { + if (registeredTasks.containsKey(feedId)) { + final ScheduledFuture retrieverFuture = registeredTasks.remove(feedId); + retrieverFuture.cancel(true); + } + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/feed/retriever/FeedRetriever.java b/tif/src/main/java/org/opensearch/securityanalytics/feed/retriever/FeedRetriever.java new file mode 100644 index 000000000..a6d992216 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/feed/retriever/FeedRetriever.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.feed.retriever; + +import org.opensearch.securityanalytics.connector.IOCConnector; +import org.opensearch.securityanalytics.feed.store.FeedStore; +import org.opensearch.securityanalytics.feed.store.model.UpdateType; +import org.opensearch.securityanalytics.model.IOC; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class FeedRetriever implements Runnable { + private static final Logger log = LoggerFactory.getLogger(FeedRetriever.class); + + private final IOCConnector iocConnector; + private final FeedStore feedStore; + private final UpdateType updateType; + private final String feedId; + + public FeedRetriever(final IOCConnector iocConnector, final FeedStore feedStore, final UpdateType updateType, final String feedId) { + this.iocConnector = iocConnector; + this.feedStore = feedStore; + this.updateType = updateType; + this.feedId = feedId; + } + + @Override + public void run() { + try { + final List iocs = iocConnector.loadIOCs(); + feedStore.storeIOCs(iocs, updateType); + } catch (final Exception e) { + log.error("Unable to fetch feed with ID {}", feedId, e); + } + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/feed/store/FeedStore.java b/tif/src/main/java/org/opensearch/securityanalytics/feed/store/FeedStore.java new file mode 100644 index 000000000..92df9e032 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/feed/store/FeedStore.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.feed.store; + +import org.opensearch.securityanalytics.feed.store.model.UpdateType; +import org.opensearch.securityanalytics.model.IOC; + +import java.util.List; + +public interface FeedStore { + /** + * Accepts a list of IOCs and stores them locally for use in feed processing + * + * @param iocs - A list of the IOCs to store + * @param updateType - The type of update to make to the underlying store + */ + void storeIOCs(List iocs, UpdateType updateType); +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStore.java b/tif/src/main/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStore.java new file mode 100644 index 000000000..82af52dd5 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStore.java @@ -0,0 +1,127 @@ +package org.opensearch.securityanalytics.feed.store; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.securityanalytics.exceptions.FeedStoreException; +import org.opensearch.securityanalytics.feed.store.model.UpdateType; +import org.opensearch.securityanalytics.index.IndexAccessor; +import org.opensearch.securityanalytics.model.IOC; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class SystemIndexFeedStore implements FeedStore { + private static final Logger log = LoggerFactory.getLogger(SystemIndexFeedStore.class); + + // TODO - alias with rollover + static final String ALIAS_NAME = ".opensearch-sap-tif-store"; + static final int PRIMARY_SHARD_COUNT = 1; + static final boolean HIDDEN_INDEX = true; + static final String IOC_DOC_ID_FORMAT = "%s-%s"; + + private final IndexAccessor indexAccessor; + private final ObjectMapper objectMapper; + + public SystemIndexFeedStore(final IndexAccessor indexAccessor) { + this.indexAccessor = indexAccessor; + this.objectMapper = new ObjectMapper(); + } + + @Override + public void storeIOCs(final List iocs, final UpdateType updateType) { + if (iocs.isEmpty()) { + log.info("No IOCs found, skipping update"); + return; + } + + try { + validateIOCs(iocs); + setupFeedIndex(updateType, iocs.get(0).getFeedId()); + updateFeedIndex(iocs); + } catch (final Exception e) { + throw new FeedStoreException("Exception updating feed store for feed ID: " + iocs.get(0).getFeedId(), e); + } + } + + private void validateIOCs(final List iocs) { + final Set feedIds = iocs.stream() + .map(IOC::getFeedId) + .collect(Collectors.toSet()); + + if (feedIds.size() != 1) { + throw new IllegalArgumentException("Exactly one feed should be updated at a time. Found feed IDs: " + feedIds); + } + } + + private void setupFeedIndex(final UpdateType updateType, final String feedId) { + indexAccessor.createRolloverAlias(ALIAS_NAME, createRolloverAliasSettings(), createISMPolicyRolloverConfiguration()); + + /* TODO - this probably needs locking for consistency. The IOC scan done by the TIF platform could read a partial state + The tradeoff is that it may take seconds if not minutes to update the feed system index. Do we want to block scanning + during that time or live with a potentially inconsistent state? */ + if (UpdateType.REPLACE.equals(updateType)) { + indexAccessor.deleteByQuery(ALIAS_NAME, deleteByQueryBuilder(feedId)); + } + } + + private QueryBuilder deleteByQueryBuilder(final String feedId) { + return QueryBuilders.matchQuery(IOC.FEED_ID_FIELD_NAME, feedId); + } + + private Settings createRolloverAliasSettings() { + return Settings.builder() + .put(IndexAccessor.SHARD_COUNT_SETTING_NAME, PRIMARY_SHARD_COUNT) + .put(IndexAccessor.AUTO_EXPAND_REPLICA_COUNT_SETTING_NAME, IndexAccessor.EXPAND_ALL_REPLICA_COUNT_SETTING_VALUE) + .put(IndexAccessor.HIDDEN_INDEX_SETTING_NAME, HIDDEN_INDEX) + .put(IndexAccessor.INDEX_ROLLOVER_ALIAS_SETTING_NAME, ALIAS_NAME) + .build(); + } + + private Map createISMPolicyRolloverConfiguration() { + return Map.of( + IndexAccessor.ROLLOVER_INDEX_SIZE_SETTING_NAME, IndexAccessor.DEFAULT_ROLLOVER_INDEX_SIZE_SETTING_VALUE + ); + } + + private void updateFeedIndex(final List iocs) { + // TODO - paginate. can be GBs of IOCs + final BulkRequest bulkRequest = bulkRequest(iocs); + final BulkResponse bulkResponse = indexAccessor.bulk(bulkRequest); + + if (bulkResponse.hasFailures()) { + throw new FeedStoreException(bulkResponse.buildFailureMessage()); + } + } + + private BulkRequest bulkRequest(final List iocs) { + final List> bulkActions = iocs.stream() + .map(this::indexRequest) + .collect(Collectors.toList()); + + return new BulkRequest().add(bulkActions); + } + + private IndexRequest indexRequest(final IOC ioc) { + // TODO - nearly positive you can just index a POJO. Need to investigate how to do that instead of expensively converting each to a Map + final Map iocAsMap = objectMapper.convertValue(ioc, new TypeReference<>() {}); + return new IndexRequest(ALIAS_NAME) + .source(iocAsMap, XContentType.JSON) + .id(docId(ioc)); + } + + private String docId(final IOC ioc) { + return String.format(IOC_DOC_ID_FORMAT, ioc.getFeedId(), ioc.getId()); + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/feed/store/model/UpdateType.java b/tif/src/main/java/org/opensearch/securityanalytics/feed/store/model/UpdateType.java new file mode 100644 index 000000000..13813de33 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/feed/store/model/UpdateType.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.feed.store.model; + +public enum UpdateType { + /** + * The provided IOCs should be considered the entire set of IOCs in the feed. FeedStores are expected to purge any + * existing IOCs for this feed in favor of the received set + */ + REPLACE, + /** + * The provided IOCs should be considered a delta from the current state of the feed store. Any conflicts should be resolved + * by updating the feed store with the new definition of a given IOC + */ + DELTA +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/index/IndexAccessor.java b/tif/src/main/java/org/opensearch/securityanalytics/index/IndexAccessor.java new file mode 100644 index 000000000..8a38d81fc --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/index/IndexAccessor.java @@ -0,0 +1,62 @@ +package org.opensearch.securityanalytics.index; + +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.unit.ByteSizeUnit; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; + +import java.util.Map; + +public interface IndexAccessor { + String SHARD_COUNT_SETTING_NAME = "index.number_of_shards"; + String AUTO_EXPAND_REPLICA_COUNT_SETTING_NAME = "index.auto_expand_replicas"; + String EXPAND_ALL_REPLICA_COUNT_SETTING_VALUE = "0-all"; + String HIDDEN_INDEX_SETTING_NAME = "index.hidden"; + String ROLLOVER_INDEX_FORMAT = "%s-000001"; + String INDEX_PATTERN_FORMAT = "%s*"; + //String ROLLOVER_INDEX_SIZE_SETTING_NAME = "min_primary_shard_size"; + //String DEFAULT_ROLLOVER_INDEX_SIZE_SETTING_VALUE = new ByteSizeValue(30, ByteSizeUnit.GB).getStringRep(); + String ROLLOVER_INDEX_SIZE_SETTING_NAME = "min_doc_count"; + String DEFAULT_ROLLOVER_INDEX_SIZE_SETTING_VALUE = "1"; + String INDEX_ROLLOVER_ALIAS_SETTING_NAME = "index.plugins.index_state_management.rollover_alias"; + + /** + * Creates a rollover alias if it is not already present. This consists of 3 steps: + * 1. Create an index template based on the provided settings + * 2. Create an ISM policy with the provided rollover conditions + * 3. Create the initial write index with the rollover alias attached + * + * @param aliasName - the name of the rollover alias + * @param settings - the settings to apply to the index + * @param rolloverConfiguration - a map of the rollover setting name to its value + */ + void createRolloverAlias(String aliasName, Settings settings, Map rolloverConfiguration); + + /** + * Deletes a rollover alias by name. Also deletes the ISM policy, indices that match the alias pattern, and the index template + * associated with the alias. + * + * @param aliasName - the name of the alias + */ + void deleteRolloverAlias(String aliasName); + + /** + * Deletes a set of documents that match the provided query + * + * @param indexName - the name of the index to delete from + * @param queryBuilder - the filter conditions for the delete-by-query + * @return BulkByScrollResponse - the results of the delete-by-query execution + */ + BulkByScrollResponse deleteByQuery(String indexName, QueryBuilder queryBuilder); + + /** + * Executes a bulk request + * + * @param bulkRequest - the request to execute + * @return BulkResponse - the results of the bulk execution + */ + BulkResponse bulk(BulkRequest bulkRequest); +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/index/InternalClientIndexAccessor.java b/tif/src/main/java/org/opensearch/securityanalytics/index/InternalClientIndexAccessor.java new file mode 100644 index 000000000..d556471cf --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/index/InternalClientIndexAccessor.java @@ -0,0 +1,147 @@ +//package org.opensearch.securityanalytics.index; +// +//import org.opensearch.action.admin.indices.alias.Alias; +//import org.opensearch.action.admin.indices.alias.get.GetAliasesRequest; +//import org.opensearch.action.admin.indices.alias.get.GetAliasesResponse; +//import org.opensearch.action.admin.indices.create.CreateIndexRequest; +//import org.opensearch.action.admin.indices.create.CreateIndexResponse; +//import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +//import org.opensearch.action.admin.indices.exists.indices.IndicesExistsRequest; +//import org.opensearch.action.admin.indices.exists.indices.IndicesExistsResponse; +//import org.opensearch.action.bulk.BulkRequest; +//import org.opensearch.action.bulk.BulkResponse; +//import org.opensearch.action.support.master.AcknowledgedResponse; +//import org.opensearch.client.Client; +//import org.opensearch.common.action.ActionFuture; +//import org.opensearch.common.settings.Settings; +//import org.opensearch.core.common.unit.ByteSizeValue; +//import org.opensearch.index.query.QueryBuilder; +//import org.opensearch.index.reindex.BulkByScrollResponse; +//import org.opensearch.index.reindex.DeleteByQueryAction; +//import org.opensearch.index.reindex.DeleteByQueryRequestBuilder; +//import org.opensearch.securityanalytics.exceptions.IndexAccessorException; +//import org.slf4j.Logger; +//import org.slf4j.LoggerFactory; +// +//import java.util.concurrent.TimeUnit; +// +//public class InternalClientIndexAccessor implements IndexAccessor { +// private static final Logger log = LoggerFactory.getLogger(InternalClientIndexAccessor.class); +// +// private static final long REQUEST_TIMEOUT_SECONDS = 30L; +// private static final TimeUnit REQUEST_TIMEOUT_TIME_UNIT = TimeUnit.SECONDS; +// +// private final Client client; +// +// public InternalClientIndexAccessor(final Client client) { +// this.client = client; +// } +// +// @Override +// public void createIndex(final String aliasName, final Settings settings) { +// final boolean doesAliasExist = doesAliasExist(aliasName); +// if (doesAliasExist) { +// log.debug("Alias with name {} already exists. Skipping index creation", aliasName); +// return; +// } +// +// final boolean doesIndexExist = doesIndexExist(indexName); +// if (doesIndexExist) { +// log.debug("Index with name {} already exists. Skipping creation", indexName); +// return; +// } +// +// doCreateIndex(indexName, aliasName, settings); +// } +// +// private boolean doesAliasExist(final String aliasName) { +// if (aliasName == null) { +// return false; +// } +// +// final GetAliasesRequest getAliasesRequest = new GetAliasesRequest(aliasName); +// try { +// final ActionFuture getAliasesResponseFuture = client.admin().indices().getAliases(getAliasesRequest); +// final GetAliasesResponse getAliasesResponse = getAliasesResponseFuture.actionGet(REQUEST_TIMEOUT_SECONDS, REQUEST_TIMEOUT_TIME_UNIT); +// +// return getAliasesResponse.getAliases().containsKey(aliasName); +// } catch (final Exception e) { +// throw new IndexAccessorException("Failed to get aliases for " + aliasName, e); +// } +// } +// +// private boolean doesIndexExist(final String indexName) { +// final IndicesExistsRequest indicesExistsRequest = new IndicesExistsRequest(indexName); +// try { +// final ActionFuture indicesExistsResponseFuture = client.admin().indices().exists(indicesExistsRequest); +// final IndicesExistsResponse indicesExistsResponse = indicesExistsResponseFuture.actionGet(REQUEST_TIMEOUT_SECONDS, REQUEST_TIMEOUT_TIME_UNIT); +// +// return indicesExistsResponse.isExists(); +// } catch (final Exception e) { +// throw new IndexAccessorException("Failed to check if index exists with name: " + indexName, e); +// } +// } +// +// private void doCreateIndex(final String indexName, final String aliasName, final Settings settings) { +// final CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); +// createIndexRequest.settings(settings); +// createIndexRequest.alias(new Alias(aliasName).writeIndex(true)); +// +// try { +// final ActionFuture createIndexResponseFuture = client.admin().indices().create(createIndexRequest); +// createIndexResponseFuture.actionGet(REQUEST_TIMEOUT_SECONDS, REQUEST_TIMEOUT_TIME_UNIT); +// } catch (final Exception e) { +// throw new IndexAccessorException("Failed to create index with name: " + indexName, e); +// } +// } +// +// @Override +// public void deleteIndex(final String indexName) { +// final DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(indexName); +// try { +// final ActionFuture deleteIndexResponseFuture = client.admin().indices().delete(deleteIndexRequest); +// final AcknowledgedResponse deleteIndexResponse = deleteIndexResponseFuture.actionGet(REQUEST_TIMEOUT_SECONDS, REQUEST_TIMEOUT_TIME_UNIT); +// +// if (!deleteIndexResponse.isAcknowledged()) { +// throw new IndexAccessorException("Delete index request was not acknowledged for index with name: " + indexName); +// } +// } catch (final Exception e) { +// throw new IndexAccessorException("Failed to delete index with name: " + indexName, e); +// } +// } +// +// @Override +// public void createRolloverAlias(final String aliasName, final ByteSizeValue indexSizeRolloverValue, final Settings settings) { +// +// } +// +// @Override +// public void deleteAlias(final String aliasName) { +// +// } +// +// @Override +// public BulkByScrollResponse deleteByQuery(final String indexName, final QueryBuilder queryBuilder) { +// final DeleteByQueryRequestBuilder deleteByQueryRequestBuilder = new DeleteByQueryRequestBuilder(client, DeleteByQueryAction.INSTANCE) +// .source(indexName) +// .filter(queryBuilder) +// .refresh(true); +// +// try { +// final ActionFuture deleteByQueryResponseFuture = deleteByQueryRequestBuilder.execute(); +// return deleteByQueryResponseFuture.actionGet(REQUEST_TIMEOUT_SECONDS, REQUEST_TIMEOUT_TIME_UNIT); +// } catch (final Exception e) { +// throw new IndexAccessorException("Failed to delete by query", e); +// } +// } +// +// @Override +// public BulkResponse bulk(final BulkRequest bulkRequest) { +// try { +// final ActionFuture bulkResponseFuture = client.bulk(bulkRequest); +// return bulkResponseFuture.actionGet(REQUEST_TIMEOUT_SECONDS, REQUEST_TIMEOUT_TIME_UNIT); +// } catch (final Exception e) { +// throw new IndexAccessorException("Failed to execute bulk request", e); +// } +// } +//} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/index/RHLCIndexAccessor.java b/tif/src/main/java/org/opensearch/securityanalytics/index/RHLCIndexAccessor.java new file mode 100644 index 000000000..280d23cbe --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/index/RHLCIndexAccessor.java @@ -0,0 +1,322 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.index; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.admin.indices.alias.Alias; +import org.opensearch.action.admin.indices.alias.IndicesAliasesRequest; +import org.opensearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.admin.indices.template.delete.DeleteIndexTemplateRequest; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.GetAliasesResponse; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.client.indices.CreateIndexRequest; +import org.opensearch.client.indices.DeleteAliasRequest; +import org.opensearch.client.indices.GetIndexRequest; +import org.opensearch.client.indices.PutComposableIndexTemplateRequest; +import org.opensearch.client.indices.PutIndexTemplateRequest; +import org.opensearch.cluster.metadata.AliasMetadata; +import org.opensearch.cluster.metadata.ComposableIndexTemplate; +import org.opensearch.cluster.metadata.Template; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.rest.RestRequest; +import org.opensearch.securityanalytics.exceptions.IndexAccessorException; +import org.opensearch.securityanalytics.util.ResourceReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +public class RHLCIndexAccessor implements IndexAccessor { + private static final Logger log = LoggerFactory.getLogger(RHLCIndexAccessor.class); + + private static final String ROLLOVER_TEMPLATE_PATH = "ISM-policies/rollover-template.txt"; + private static final String ISM_POLICY_PATH_FORMAT = "_plugins/_ism/policies/%s"; + private static final String ISM_ATTACH_POLICY_PATH_FORMAT = "_plugins/_ism/add/%s"; + private static final String POLICY_ID_FORMAT = "{\"policy_id\":\"%s\"}"; + + private final RestHighLevelClient client; + private final ResourceReader resourceReader; + private final ObjectMapper objectMapper; + + public RHLCIndexAccessor(final RestHighLevelClient client, final ResourceReader resourceReader, final ObjectMapper objectMapper) { + this.client = client; + this.resourceReader = resourceReader; + this.objectMapper = objectMapper; + } + + @Override + public void createRolloverAlias(final String aliasName, final Settings settings, final Map rolloverConfiguration) { + final boolean doesAliasExist = doesAliasExist(aliasName); + if (doesAliasExist) { + log.debug("Alias with name {} already exists. Skipping rollover alias creation", aliasName); + return; + } + + final String initialWriteIndex = String.format(IndexAccessor.ROLLOVER_INDEX_FORMAT, aliasName); + final boolean doesIndexExist = doesIndexExist(initialWriteIndex); + if (doesIndexExist) { + log.debug("Index with name {} already exists. Skipping rollover alias creation", initialWriteIndex); + return; + } + + putIndexTemplate(aliasName, settings); + createRolloverPolicyIfNotPresent(aliasName, rolloverConfiguration); + doCreateIndex(initialWriteIndex, aliasName); + attachPolicyToWriteIndex(aliasName); + } + + private boolean doesAliasExist(final String aliasName) { + if (aliasName == null) { + return false; + } + + final GetAliasesRequest getAliasesRequest = new GetAliasesRequest(aliasName); + try { + return client.indices().existsAlias(getAliasesRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + throw new IndexAccessorException("Failed to check if alias exists with name: " + aliasName, e); + } + } + + private boolean doesIndexExist(final String indexName) { + final GetIndexRequest getIndexRequest = new GetIndexRequest(indexName); + try { + return client.indices().exists(getIndexRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + throw new IndexAccessorException("Failed to check if index exists with name: " + indexName, e); + } + } + + private void putIndexTemplate(final String aliasName, final Settings settings) { + final String indexPattern = String.format(IndexAccessor.INDEX_PATTERN_FORMAT, aliasName); + + try { + final Template template = new Template(settings, null, null); + final ComposableIndexTemplate composableIndexTemplate = new ComposableIndexTemplate(List.of(indexPattern), template, null, null, null, null); + final PutComposableIndexTemplateRequest putIndexTemplateRequest = new PutComposableIndexTemplateRequest() + .name(aliasName) + .indexTemplate(composableIndexTemplate); + client.indices().putIndexTemplate(putIndexTemplateRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + throw new IndexAccessorException("Failed to create index template for alias: " + aliasName, e); + } + } + + private void createRolloverPolicyIfNotPresent(final String policyName, final Map rolloverConfiguration) { + final boolean doesISMPolicyExist = doesISMPolicyExist(policyName); + if (doesISMPolicyExist) { + log.debug("ISM policy with name {} already exists. Skipping ISM policy creation", policyName); + } else { + createRolloverPolicy(policyName, rolloverConfiguration); + } + } + + private boolean doesISMPolicyExist(final String policyName) { + final Request checkExistsRequest = new Request(RestRequest.Method.HEAD.name(), String.format(ISM_POLICY_PATH_FORMAT, policyName)); + final Response checkExistsResponse; + try { + checkExistsResponse = client.getLowLevelClient().performRequest(checkExistsRequest); + return checkExistsResponse.getStatusLine().getStatusCode() == 200; + } catch (final Exception e) { + throw new IndexAccessorException("Exception checking if ISM policy exists with name: " + policyName, e); + } + } + + private void createRolloverPolicy(final String policyName, final Map rolloverConfiguration) { + final Request createRequest = new Request(RestRequest.Method.PUT.name(), String.format(ISM_POLICY_PATH_FORMAT, policyName)); + try { + final String rolloverPolicy = getISMPolicy(policyName, rolloverConfiguration); + final StringEntity stringEntity = new StringEntity(rolloverPolicy, ContentType.APPLICATION_JSON); + createRequest.setEntity(stringEntity); + client.getLowLevelClient().performRequest(createRequest); + } catch (final Exception e) { + throw new IndexAccessorException("Exception creating rollover policy: " + policyName, e); + } + } + + private String getISMPolicy(final String policyName, final Map rolloverConfiguration) { + final String policyTemplate = resourceReader.readResourceAsString(ROLLOVER_TEMPLATE_PATH); + final String rolloverConfigurationAsString = getRolloverConfigurationAsString(rolloverConfiguration); + return String.format(policyTemplate, policyName, rolloverConfigurationAsString, policyName); + } + + private String getRolloverConfigurationAsString(final Map rolloverConfiguration) { + try { + return objectMapper.writeValueAsString(rolloverConfiguration); + } catch (final Exception e) { + throw new IndexAccessorException("Failed to serialize rollover configuration: " + rolloverConfiguration, e); + } + } + + private void doCreateIndex(final String indexName, final String aliasName) { + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName) + .alias(new Alias(aliasName).writeIndex(true)); + + try { + client.indices().create(createIndexRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + throw new IndexAccessorException("Exception creating index: " + indexName, e); + } + } + + // ISM ignores hidden and system indices so we must manually attach the policy to the alias + private void attachPolicyToWriteIndex(final String aliasName) { + System.out.println(aliasName); + + final GetAliasesRequest getAliasesRequest = new GetAliasesRequest(); + final String writeIndex; + try { + final GetAliasesResponse getAliasesResponse = client.indices().getAlias(getAliasesRequest, RequestOptions.DEFAULT); + System.out.println(getAliasesResponse.getError()); + System.out.println(getAliasesResponse.getAliases().size()); + final Optional optionalWriteIndex = getAliasesResponse.getAliases().entrySet().stream() + .peek(mapEntry -> System.out.println("eval " + mapEntry.getKey())) + .filter(mapEntry -> isWriteIndex(mapEntry.getValue(), aliasName)) + .map(Map.Entry::getKey) + .findFirst(); + + if (optionalWriteIndex.isEmpty()) { + throw new IndexAccessorException("No write index found for alias: " + aliasName); + } + + writeIndex = optionalWriteIndex.get(); + } catch (final Exception e) { + throw new IndexAccessorException("Failed to determine write index for alias: " + aliasName, e); + } + + final Request attachPolicyRequest = new Request(RestRequest.Method.POST.name(), String.format(ISM_ATTACH_POLICY_PATH_FORMAT, writeIndex)); + + try { + final String policyValue = String.format(POLICY_ID_FORMAT, aliasName); + final StringEntity stringEntity = new StringEntity(policyValue, ContentType.APPLICATION_JSON); + attachPolicyRequest.setEntity(stringEntity); + client.getLowLevelClient().performRequest(attachPolicyRequest); + } catch (final Exception e) { + throw new IndexAccessorException("Failed to attach policy to alias: " + aliasName, e); + } + } + + private boolean isWriteIndex(final Set aliasMetadata, final String aliasName) { + final Optional optionalRelevantAliasMetadata = aliasMetadata.stream() + .filter(alias -> alias.getAlias().equals(aliasName)) + .findFirst(); + + if (optionalRelevantAliasMetadata.isEmpty()) { + return false; + } + + final AliasMetadata relevantAliasMetadata = optionalRelevantAliasMetadata.get(); + return relevantAliasMetadata.writeIndex(); + } + + @Override + public void deleteRolloverAlias(final String aliasName) { + deleteAlias(aliasName); + deleteRolloverPolicy(aliasName); + deleteIndexTemplate(aliasName); + + final String indexPattern = String.format(IndexAccessor.INDEX_PATTERN_FORMAT, aliasName); + deleteIndex(indexPattern); + } + + private void deleteAlias(final String aliasName) { + final DeleteAliasRequest deleteAliasRequest = new DeleteAliasRequest(String.format(ROLLOVER_INDEX_FORMAT, aliasName), aliasName); + + try { + client.indices().deleteAlias(deleteAliasRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + if (e instanceof OpenSearchStatusException && + (e.getMessage().contains("aliases_not_found_exception") || e.getMessage().contains("index_not_found_exception"))) { + log.info("Alias with name {} not found, assuming it was already deleted", aliasName); + return; + } + + throw new IndexAccessorException("Exception deleting alias: " + aliasName, e); + } + } + + private void deleteRolloverPolicy(final String policyName) { + final Request deleteRequest = new Request(RestRequest.Method.DELETE.name(), String.format(ISM_POLICY_PATH_FORMAT, policyName)); + + try { + client.getLowLevelClient().performRequest(deleteRequest); + } catch (final Exception e) { + if (e instanceof ResponseException && ((ResponseException) e).getResponse().getStatusLine().getStatusCode() == 404) { + log.info("Policy with name {} was not found. Assuming it was already deleted", policyName); + return; + } + throw new IndexAccessorException("Exception deleting rollover policy: " + policyName, e); + } + } + + private void deleteIndexTemplate(final String templateName) { + final DeleteIndexTemplateRequest deleteIndexTemplateRequest = new DeleteIndexTemplateRequest(templateName); + + try { + client.indices().deleteTemplate(deleteIndexTemplateRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + if (e instanceof OpenSearchStatusException && e.getMessage().contains("index_template_missing_exception")) { + log.info("Template with name {} not found, assuming it was already deleted", templateName); + return; + } + throw new IndexAccessorException("Exception deleting index template: " + templateName, e); + } + } + + private void deleteIndex(final String indexName) { + final DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(indexName); + + try { + client.indices().delete(deleteIndexRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + if (e instanceof OpenSearchStatusException && e.getMessage().contains("index_not_found_exception")) { + log.info("Index with name {} not found, assuming it was already deleted", indexName); + return; + } + + throw new IndexAccessorException("Exception deleting index: " + indexName, e); + } + } + + @Override + public BulkByScrollResponse deleteByQuery(final String indexName, final QueryBuilder queryBuilder) { + final DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(indexName) + .setQuery(queryBuilder) + .setRefresh(true); + + try { + return client.deleteByQuery(deleteByQueryRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + throw new IndexAccessorException("Exception deleting by query for index: " + indexName, e); + } + } + + @Override + public BulkResponse bulk(final BulkRequest bulkRequest) { + try { + return client.bulk(bulkRequest, RequestOptions.DEFAULT); + } catch (final Exception e) { + throw new IndexAccessorException("Exception making bulk request", e); + } + } +} diff --git a/tif/src/main/java/org/opensearch/securityanalytics/model/IOC.java b/tif/src/main/java/org/opensearch/securityanalytics/model/IOC.java index c1de87609..c9382a60f 100644 --- a/tif/src/main/java/org/opensearch/securityanalytics/model/IOC.java +++ b/tif/src/main/java/org/opensearch/securityanalytics/model/IOC.java @@ -7,4 +7,24 @@ import java.io.Serializable; public abstract class IOC implements Serializable { + public static final String FEED_ID_FIELD_NAME = "feedId"; + + private String id; + private String feedId; + + public String getId() { + return id; + } + + public void setId(final String id) { + this.id = id; + } + + public String getFeedId() { + return feedId; + } + + public void setFeedId(final String feedId) { + this.feedId = feedId; + } } diff --git a/tif/src/main/java/org/opensearch/securityanalytics/model/STIX2.java b/tif/src/main/java/org/opensearch/securityanalytics/model/STIX2.java index db07f3fe7..b384f46fa 100644 --- a/tif/src/main/java/org/opensearch/securityanalytics/model/STIX2.java +++ b/tif/src/main/java/org/opensearch/securityanalytics/model/STIX2.java @@ -12,7 +12,6 @@ public class STIX2 extends IOC { private String type; @JsonProperty("spec_version") private String specVersion; - private String id; public String getType() { return type; @@ -29,12 +28,4 @@ public String getSpecVersion() { public void setSpecVersion(final String specVersion) { this.specVersion = specVersion; } - - public String getId() { - return id; - } - - public void setId(final String id) { - this.id = id; - } } diff --git a/tif/src/main/java/org/opensearch/securityanalytics/util/ResourceReader.java b/tif/src/main/java/org/opensearch/securityanalytics/util/ResourceReader.java new file mode 100644 index 000000000..917a8ef77 --- /dev/null +++ b/tif/src/main/java/org/opensearch/securityanalytics/util/ResourceReader.java @@ -0,0 +1,28 @@ +package org.opensearch.securityanalytics.util; + +import org.opensearch.securityanalytics.exceptions.ResourceReaderException; + +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; + +public class ResourceReader { + public String readResourceAsString(final String resourcePath) { + final Optional optionalResourcePath = Optional.of(getClass()) + .map(Class::getClassLoader) + .map(classLoader -> classLoader.getResource(resourcePath)) + .map(URL::getPath) + .map(Path::of); + + if (optionalResourcePath.isEmpty()) { + throw new ResourceReaderException(String.format("Unable to find resource [%s]", resourcePath)); + } + + try { + return Files.readString(optionalResourcePath.get()); + } catch (final Exception e) { + throw new ResourceReaderException(String.format("Unable to read resource [%s]", resourcePath)); + } + } +} diff --git a/tif/src/main/resources/ISM-policies/rollover-template.txt b/tif/src/main/resources/ISM-policies/rollover-template.txt new file mode 100644 index 000000000..9c045d931 --- /dev/null +++ b/tif/src/main/resources/ISM-policies/rollover-template.txt @@ -0,0 +1,24 @@ +{ + "policy": { + "description": "%s", + "default_state": "size_rollover", + "states": [ + { + "name": "size_rollover", + "actions": [ + { + "rollover": %s + } + ] + } + ], + "ism_template": [ + { + "index_patterns": [ + "%s*" + ], + "priority": 100 + } + ] + } +} \ No newline at end of file diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorIT.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorIT.java new file mode 100644 index 000000000..f6dd37478 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorIT.java @@ -0,0 +1,175 @@ +package org.opensearch.securityanalytics.connector; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; +import org.opensearch.securityanalytics.connector.factory.InputCodecFactory; +import org.opensearch.securityanalytics.connector.factory.S3ClientFactory; +import org.opensearch.securityanalytics.connector.factory.StsAssumeRoleCredentialsProviderFactory; +import org.opensearch.securityanalytics.connector.factory.StsClientFactory; +import org.opensearch.securityanalytics.connector.model.InputCodecSchema; +import org.opensearch.securityanalytics.connector.model.S3ConnectorConfig; +import org.opensearch.securityanalytics.connector.util.NewlineDelimitedIOCGenerator; +import org.opensearch.securityanalytics.connector.util.S3ObjectGenerator; +import org.opensearch.securityanalytics.model.IOC; +import org.opensearch.securityanalytics.model.IOCSchema; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.NoSuchBucketException; +import software.amazon.awssdk.services.s3.model.NoSuchKeyException; +import software.amazon.awssdk.services.sts.model.StsException; + +import java.io.IOException; +import java.util.List; +import java.util.Random; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Integration test class for the S3 connector. The following system parameters must be specified to successfully run the tests: + * + * tests.s3connector.bucket - the name of the S3 bucket to use for the tests + * tests.s3connector.region - the AWS region of the S3 bucket + * tests.s3connector.roleArn - the IAM role ARN to assume when making S3 calls + * + * The local system must have sufficient credentials to write to S3, delete from S3, and assume the provided role. + * + * The tests are disabled by default as there is no default value for the tests.s3connector.bucket system property. This is + * intentional as the tests will fail when run without the proper setup, such as during CI workflows. + * + * Example command to manually run this class's ITs: + * ./gradlew :tif:s3ConnectorIT -Dtests.s3connector.bucket= -Dtests.s3connector.region= -Dtests.s3connector.roleArn= + */ +@EnabledIfSystemProperty(named = "tests.s3connector.bucket", matches = ".+") +public class S3ConnectorIT { + private static final String FEED_ID = UUID.randomUUID().toString(); + private static final int NUMBER_OF_IOCS = new Random().nextInt(100); + + private S3Client s3Client; + private S3ObjectGenerator s3ObjectGenerator; + private String bucket; + private String region; + private String roleArn; + + @BeforeEach + public void setup() { + region = System.getProperty("tests.s3connector.region"); + roleArn = System.getProperty("tests.s3connector.roleArn"); + bucket = System.getProperty("tests.s3connector.bucket"); + + s3Client = S3Client.builder() + .region(Region.of(region)) + .build(); + s3ObjectGenerator = new S3ObjectGenerator(s3Client, bucket); + } + + private S3Connector createS3Connector(final S3ConnectorConfig s3ConnectorConfig) { + final StsClientFactory stsClientFactory = new StsClientFactory(); + final StsAssumeRoleCredentialsProviderFactory stsAssumeRoleCredentialsProviderFactory = new StsAssumeRoleCredentialsProviderFactory(stsClientFactory); + final S3ClientFactory s3ClientFactory = new S3ClientFactory(stsAssumeRoleCredentialsProviderFactory); + final InputCodecFactory inputCodecFactory = new InputCodecFactory(); + + return new S3Connector(s3ConnectorConfig, s3ClientFactory, inputCodecFactory); + } + + @Test + public void testS3Connector_Success() throws IOException { + final String objectKey = UUID.randomUUID().toString(); + s3ObjectGenerator.write(NUMBER_OF_IOCS, objectKey, new NewlineDelimitedIOCGenerator()); + + final S3ConnectorConfig s3ConnectorConfig = new S3ConnectorConfig( + bucket, + objectKey, + region, + roleArn, + IOCSchema.STIX2, + InputCodecSchema.ND_JSON, + FEED_ID + ); + final S3Connector s3Connector = createS3Connector(s3ConnectorConfig); + + final List iocs = s3Connector.loadIOCs(); + assertEquals(NUMBER_OF_IOCS, iocs.size()); + + deleteObject(objectKey); + } + + @Test + public void testS3Connector_BucketDoesNotExist() { + final String objectKey = UUID.randomUUID().toString(); + final S3ConnectorConfig s3ConnectorConfig = new S3ConnectorConfig( + UUID.randomUUID().toString(), + objectKey, + region, + roleArn, + IOCSchema.STIX2, + InputCodecSchema.ND_JSON, + FEED_ID + ); + final S3Connector s3Connector = createS3Connector(s3ConnectorConfig); + + assertThrows(NoSuchBucketException.class, s3Connector::loadIOCs); + } + + @Test + public void testS3Connector_ObjectDoesNotExist() { + final S3ConnectorConfig s3ConnectorConfig = new S3ConnectorConfig( + bucket, + UUID.randomUUID().toString(), + region, + roleArn, + IOCSchema.STIX2, + InputCodecSchema.ND_JSON, + FEED_ID + ); + final S3Connector s3Connector = createS3Connector(s3ConnectorConfig); + + assertThrows(NoSuchKeyException.class, s3Connector::loadIOCs); + } + + @Test + public void testS3Connector_InvalidRegion() { + final String objectKey = UUID.randomUUID().toString(); + final S3ConnectorConfig s3ConnectorConfig = new S3ConnectorConfig( + bucket, + objectKey, + UUID.randomUUID().toString(), + roleArn, + IOCSchema.STIX2, + InputCodecSchema.ND_JSON, + FEED_ID + ); + final S3Connector s3Connector = createS3Connector(s3ConnectorConfig); + + assertThrows(SdkClientException.class, s3Connector::loadIOCs); + } + + @Test + public void testS3Connector_FailToAssumeRule() { + final String objectKey = UUID.randomUUID().toString(); + final S3ConnectorConfig s3ConnectorConfig = new S3ConnectorConfig( + bucket, + objectKey, + region, + roleArn + UUID.randomUUID(), + IOCSchema.STIX2, + InputCodecSchema.ND_JSON, + FEED_ID + ); + final S3Connector s3Connector = createS3Connector(s3ConnectorConfig); + + assertThrows(StsException.class, s3Connector::loadIOCs); + } + + private void deleteObject(final String objectKey) { + final DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder() + .bucket(bucket) + .key(objectKey) + .build(); + s3Client.deleteObject(deleteObjectRequest); + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorTests.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorTests.java index 3bb24c842..691d436a9 100644 --- a/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorTests.java +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/S3ConnectorTests.java @@ -40,13 +40,15 @@ public class S3ConnectorTests { private static final String ROLE_ARN = UUID.randomUUID().toString(); private static final IOCSchema IOC_SCHEMA = IOCSchema.STIX2; private static final InputCodecSchema INPUT_CODEC_SCHEMA = InputCodecSchema.ND_JSON; + private static final String FEED_ID = UUID.randomUUID().toString(); private static final S3ConnectorConfig S3_CONNECTOR_CONFIG = new S3ConnectorConfig( BUCKET_NAME, OBJECT_KEY, REGION, ROLE_ARN, IOC_SCHEMA, - INPUT_CODEC_SCHEMA + INPUT_CODEC_SCHEMA, + FEED_ID ); @Mock @@ -78,7 +80,7 @@ public void tearDown() { verify(s3ClientFactory).create(eq(ROLE_ARN), eq(REGION)); verify(inputCodecFactory).create(eq(INPUT_CODEC_SCHEMA), eq(IOC_SCHEMA)); - verifyNoMoreInteractions(s3ClientFactory, inputCodecFactory, s3Client, inputCodec, responseInputStream); + verifyNoMoreInteractions(s3ClientFactory, inputCodecFactory, s3Client, inputCodec, responseInputStream, ioc); } @Test @@ -91,6 +93,7 @@ public void testLoadIOCs() { final ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(GetObjectRequest.class); verify(s3Client).getObject(argumentCaptor.capture()); verify(inputCodec).parse(eq(responseInputStream)); + verify(ioc).setFeedId(eq(FEED_ID)); assertEquals(OBJECT_KEY, argumentCaptor.getValue().key()); assertEquals(BUCKET_NAME, argumentCaptor.getValue().bucket()); diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/factory/InputCodecFactoryTests.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/factory/InputCodecFactoryTests.java index 4a95f4208..d423d719f 100644 --- a/tif/src/test/java/org/opensearch/securityanalytics/connector/factory/InputCodecFactoryTests.java +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/factory/InputCodecFactoryTests.java @@ -6,7 +6,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.opensearch.securityanalytics.connector.codec.NewlineDelimitedJsonCodecTests; +import org.opensearch.securityanalytics.connector.codec.NewlineDelimitedJsonCodec; import org.opensearch.securityanalytics.connector.model.InputCodecSchema; import org.opensearch.securityanalytics.model.IOCSchema; @@ -22,6 +22,6 @@ public void setup() { @Test public void testDoCreate_ND_JSON() { - assertInstanceOf(NewlineDelimitedJsonCodecTests.class, inputCodecFactory.doCreate(InputCodecSchema.ND_JSON, IOCSchema.STIX2)); + assertInstanceOf(NewlineDelimitedJsonCodec.class, inputCodecFactory.doCreate(InputCodecSchema.ND_JSON, IOCSchema.STIX2)); } } diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/model/InputCodecSchemaTests.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/model/InputCodecSchemaTests.java index 658bda81e..2f5423c08 100644 --- a/tif/src/test/java/org/opensearch/securityanalytics/connector/model/InputCodecSchemaTests.java +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/model/InputCodecSchemaTests.java @@ -5,7 +5,7 @@ package org.opensearch.securityanalytics.connector.model; import org.junit.jupiter.api.Test; -import org.opensearch.securityanalytics.connector.codec.NewlineDelimitedJsonCodecTests; +import org.opensearch.securityanalytics.connector.codec.NewlineDelimitedJsonCodec; import org.opensearch.securityanalytics.model.IOCSchema; import static org.junit.jupiter.api.Assertions.assertInstanceOf; @@ -13,6 +13,6 @@ public class InputCodecSchemaTests { @Test public void testGetInputCodecConstructor_ND_JSON() { - assertInstanceOf(NewlineDelimitedJsonCodecTests.class, InputCodecSchema.ND_JSON.getInputCodecConstructor().apply(IOCSchema.STIX2)); + assertInstanceOf(NewlineDelimitedJsonCodec.class, InputCodecSchema.ND_JSON.getInputCodecConstructor().apply(IOCSchema.STIX2)); } } diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/util/IOCGenerator.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/IOCGenerator.java new file mode 100644 index 000000000..8d33a0f76 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/IOCGenerator.java @@ -0,0 +1,8 @@ +package org.opensearch.securityanalytics.connector.util; + +import java.io.IOException; +import java.io.OutputStream; + +public interface IOCGenerator { + void write(int numberOfIOCs, OutputStream outputStream) throws IOException; +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/util/NewlineDelimitedIOCGenerator.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/NewlineDelimitedIOCGenerator.java new file mode 100644 index 000000000..b0e640468 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/NewlineDelimitedIOCGenerator.java @@ -0,0 +1,41 @@ +package org.opensearch.securityanalytics.connector.util; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.securityanalytics.model.IOC; +import org.opensearch.securityanalytics.model.STIX2; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.PrintWriter; +import java.util.List; + +public class NewlineDelimitedIOCGenerator implements IOCGenerator { + private final ObjectMapper objectMapper; + private final STIX2Generator stix2Generator; + + public NewlineDelimitedIOCGenerator() { + this.objectMapper = new ObjectMapper(); + this.stix2Generator = new STIX2Generator(); + } + + @Override + public void write(final int numberOfIOCs, final OutputStream outputStream) { + try (final PrintWriter printWriter = new PrintWriter(outputStream)) { + writeLines(numberOfIOCs, printWriter); + } + } + + private void writeLines(final int numberOfIOCs, final PrintWriter printWriter) { + final List iocs = stix2Generator.generateSTIX2(numberOfIOCs); + iocs.forEach(ioc -> writeLine(ioc, printWriter)); + } + + private void writeLine(final IOC ioc, final PrintWriter printWriter) { + try { + final String iocAsString = objectMapper.writeValueAsString(ioc); + printWriter.write(iocAsString + "\n"); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/util/S3ObjectGenerator.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/S3ObjectGenerator.java new file mode 100644 index 000000000..755c056cf --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/S3ObjectGenerator.java @@ -0,0 +1,39 @@ +package org.opensearch.securityanalytics.connector.util; + +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +public class S3ObjectGenerator { + private final S3Client s3Client; + private final String bucketName; + + public S3ObjectGenerator(final S3Client s3Client, final String bucketName) { + this.s3Client = s3Client; + this.bucketName = bucketName; + } + + public void write(final int numberOfIOCs, final String key, final IOCGenerator iocGenerator) throws IOException { + final File tempFile = File.createTempFile("s3-object-" + numberOfIOCs + "-", null); + + try { + try (final OutputStream outputStream = new FileOutputStream(tempFile)) { + + iocGenerator.write(numberOfIOCs, outputStream); + outputStream.flush(); + } + + final PutObjectRequest putObjectRequest = PutObjectRequest.builder() + .bucket(bucketName) + .key(key) + .build(); + s3Client.putObject(putObjectRequest, tempFile.toPath()); + } finally { + tempFile.delete(); + } + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/connector/util/STIX2Generator.java b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/STIX2Generator.java new file mode 100644 index 000000000..3d74efe3b --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/connector/util/STIX2Generator.java @@ -0,0 +1,40 @@ +package org.opensearch.securityanalytics.connector.util; + +import org.opensearch.securityanalytics.model.IOC; +import org.opensearch.securityanalytics.model.STIX2; + +import java.util.List; +import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class STIX2Generator { + public List generateSTIX2(final int count) { + return generateSTIX2(count, i -> randomSTIX2()); + } + + public List generateSTIX2(final int count, final String feedId) { + return generateSTIX2(count, i -> randomSTIX2(feedId)); + } + + public List generateSTIX2(final int count, final Function generatorFunction) { + return IntStream.range(0, count) + .mapToObj(generatorFunction::apply) + .collect(Collectors.toList()); + } + + public STIX2 randomSTIX2() { + return randomSTIX2(UUID.randomUUID().toString()); + } + + public STIX2 randomSTIX2(final String feedId) { + final STIX2 ioc = new STIX2(); + ioc.setId(UUID.randomUUID().toString()); + ioc.setFeedId(feedId); + ioc.setSpecVersion(UUID.randomUUID().toString()); + ioc.setType(UUID.randomUUID().toString()); + + return ioc; + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/feed/FeedManagerTests.java b/tif/src/test/java/org/opensearch/securityanalytics/feed/FeedManagerTests.java new file mode 100644 index 000000000..1e1210c96 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/feed/FeedManagerTests.java @@ -0,0 +1,105 @@ +package org.opensearch.securityanalytics.feed; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class FeedManagerTests { + private static final String FEED_ID = UUID.randomUUID().toString(); + + @Mock + private ScheduledExecutorService scheduledExecutorService; + @Mock + private Runnable feedRetriever; + @Mock + private ScheduledFuture feedRetrieverFuture; + + private Map> registeredTasks; + private FeedManager feedManager; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + registeredTasks = new HashMap<>(); + feedManager = new FeedManager(scheduledExecutorService, registeredTasks); + } + + @AfterEach + public void teardown() { + verifyNoMoreInteractions(scheduledExecutorService, feedRetriever, feedRetrieverFuture); + } + + @Test + public void testRegisterFeedRetriever() { + final long millisDuration = new Random().nextLong(); + when(scheduledExecutorService.scheduleAtFixedRate(eq(feedRetriever), eq(0L), eq(millisDuration), eq(TimeUnit.MILLISECONDS))) + .thenReturn(feedRetrieverFuture); + + feedManager.registerFeedRetriever(FEED_ID, feedRetriever, Duration.ofMillis(millisDuration)); + + assertEquals(1, registeredTasks.size()); + assertTrue(registeredTasks.containsKey(FEED_ID)); + assertEquals(feedRetrieverFuture, registeredTasks.get(FEED_ID)); + + verify(scheduledExecutorService).scheduleAtFixedRate(eq(feedRetriever), eq(0L), eq(millisDuration), eq(TimeUnit.MILLISECONDS)); + } + + @Test + public void testDeregisterFeedRetriever() { + registeredTasks.put(FEED_ID, feedRetrieverFuture); + assertEquals(1, registeredTasks.size()); + + feedManager.deregisterFeedRetriever(FEED_ID); + + assertTrue(registeredTasks.isEmpty()); + verify(feedRetrieverFuture).cancel(eq(true)); + } + + @Test + public void testDeregisterFeedRetriever_FeedNotRegistered() { + registeredTasks.put(UUID.randomUUID().toString(), feedRetrieverFuture); + assertEquals(1, registeredTasks.size()); + + feedManager.deregisterFeedRetriever(FEED_ID); + + assertEquals(1, registeredTasks.size()); + assertFalse(registeredTasks.containsKey(FEED_ID)); + } + + @Test + public void testRegisterFeedRetriever_FeedAlreadyRegistered() { + registeredTasks.put(FEED_ID, feedRetrieverFuture); + + final long millisDuration = new Random().nextLong(); + when(scheduledExecutorService.scheduleAtFixedRate(eq(feedRetriever), eq(0L), eq(millisDuration), eq(TimeUnit.MILLISECONDS))) + .thenReturn(feedRetrieverFuture); + + feedManager.registerFeedRetriever(FEED_ID, feedRetriever, Duration.ofMillis(millisDuration)); + + assertEquals(1, registeredTasks.size()); + assertTrue(registeredTasks.containsKey(FEED_ID)); + assertEquals(feedRetrieverFuture, registeredTasks.get(FEED_ID)); + + verify(scheduledExecutorService).scheduleAtFixedRate(eq(feedRetriever), eq(0L), eq(millisDuration), eq(TimeUnit.MILLISECONDS)); + verify(feedRetrieverFuture).cancel(eq(true)); + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/feed/retriever/FeedRetrieverTests.java b/tif/src/test/java/org/opensearch/securityanalytics/feed/retriever/FeedRetrieverTests.java new file mode 100644 index 000000000..59646595e --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/feed/retriever/FeedRetrieverTests.java @@ -0,0 +1,75 @@ +package org.opensearch.securityanalytics.feed.retriever; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.securityanalytics.connector.IOCConnector; +import org.opensearch.securityanalytics.feed.store.FeedStore; +import org.opensearch.securityanalytics.feed.store.model.UpdateType; +import org.opensearch.securityanalytics.model.IOC; + +import java.util.List; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class FeedRetrieverTests { + private static final UpdateType UPDATE_TYPE = UpdateType.REPLACE; + private static final String FEED_ID = UUID.randomUUID().toString(); + + @Mock + private IOCConnector iocConnector; + @Mock + private FeedStore feedStore; + @Mock + private IOC ioc; + + private FeedRetriever feedRetriever; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + feedRetriever = new FeedRetriever(iocConnector, feedStore, UPDATE_TYPE, FEED_ID); + } + + @AfterEach + public void teardown() { + verifyNoMoreInteractions(iocConnector, feedStore, ioc); + } + + @Test + public void testRun() { + when(iocConnector.loadIOCs()).thenReturn(List.of(ioc)); + + feedRetriever.run(); + + verify(iocConnector).loadIOCs(); + verify(feedStore).storeIOCs(eq(List.of(ioc)), eq(UPDATE_TYPE)); + } + + @Test + public void testRun_ExceptionLoadingIOCs_DoesNotThrow() { + when(iocConnector.loadIOCs()).thenThrow(new RuntimeException()); + + feedRetriever.run(); + + verify(iocConnector).loadIOCs(); + } + + @Test + public void testRun_ExceptionStoringIOCs_DoesNotThrow() { + when(iocConnector.loadIOCs()).thenReturn(List.of(ioc)); + doThrow(new RuntimeException()).when(feedStore).storeIOCs(eq(List.of(ioc)), eq(UPDATE_TYPE)); + + feedRetriever.run(); + + verify(iocConnector).loadIOCs(); + verify(feedStore).storeIOCs(eq(List.of(ioc)), eq(UPDATE_TYPE)); + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStoreIT.java b/tif/src/test/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStoreIT.java new file mode 100644 index 000000000..98e5a3361 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStoreIT.java @@ -0,0 +1,179 @@ +package org.opensearch.securityanalytics.feed.store; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.http.HttpHost; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.action.admin.indices.refresh.RefreshRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.securityanalytics.connector.util.STIX2Generator; +import org.opensearch.securityanalytics.feed.store.model.UpdateType; +import org.opensearch.securityanalytics.index.IndexAccessor; +import org.opensearch.securityanalytics.index.RHLCIndexAccessor; +import org.opensearch.securityanalytics.model.IOC; +import org.opensearch.securityanalytics.model.STIX2; +import org.opensearch.securityanalytics.util.ResourceReader; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.securityanalytics.index.IndexAccessor.ROLLOVER_INDEX_FORMAT; + +public class SystemIndexFeedStoreIT { + private static final int NUMBER_OF_IOCS = new Random().nextInt(99) + 1; + private static final String FEED_ID = UUID.randomUUID().toString(); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private RestHighLevelClient client; + private IndexAccessor indexAccessor; + private SystemIndexFeedStore systemIndexFeedStore; + private STIX2Generator stix2Generator; + + @BeforeEach + public void setup() { + final String userName = System.getProperty("tests.opensearch.user"); + final String password = System.getProperty("tests.opensearch.password"); + final String host = System.getProperty("tests.opensearch.host"); + + final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(userName, password); + credentialsProvider.setCredentials(AuthScope.ANY, credentials); + + final RestClientBuilder restClientBuilder = RestClient.builder(new HttpHost(host, 9200, "http")) + .setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider)); + client = new RestHighLevelClient(restClientBuilder); + + indexAccessor = new RHLCIndexAccessor(client, new ResourceReader(), OBJECT_MAPPER); + systemIndexFeedStore = new SystemIndexFeedStore(indexAccessor); + stix2Generator = new STIX2Generator(); + + deleteAlias(); + } + + @AfterEach + public void tearDown() { + //deleteAlias(); + } + +// @ParameterizedTest +// @MethodSource("getUpdateTypes") +// public void testStoreIOCs_Success(final UpdateType updateType) throws IOException { +// final List iocs = stix2Generator.generateSTIX2(NUMBER_OF_IOCS, FEED_ID); +// systemIndexFeedStore.storeIOCs(iocs, updateType); +// +// validateIOCs(iocs); +// } + + @Test + public void testStoreIOCs_ReplaceUpdateTypeDeletesOriginalIOCs() throws IOException { + // Load initial IOCs + final List iocs = stix2Generator.generateSTIX2(NUMBER_OF_IOCS, FEED_ID); + systemIndexFeedStore.storeIOCs(iocs, UpdateType.REPLACE); + validateIOCs(iocs); + + // Load replacement IOCs + final List replacementIOCs = stix2Generator.generateSTIX2(12, FEED_ID); + systemIndexFeedStore.storeIOCs(replacementIOCs, UpdateType.REPLACE); + validateIOCs(replacementIOCs); + } + +// @Test +// public void testStoreIOCs_DeltaUpdateTypeReplacesOriginalIOC() throws IOException { +// // Load initial IOCs +// final List iocs = stix2Generator.generateSTIX2(NUMBER_OF_IOCS, FEED_ID); +// systemIndexFeedStore.storeIOCs(iocs, UpdateType.DELTA); +// validateIOCs(iocs); +// +// final STIX2 originalIOC = (STIX2) iocs.get(0); +// final STIX2 indexedIOC = getIOCById(originalIOC.getFeedId(), originalIOC.getId()); +// assertEquals(originalIOC.getType(), indexedIOC.getType()); +// +// // Load replacement IOC and validate original IOCs still present +// final STIX2 updatedIOC = (STIX2) iocs.get(0); +// final String updatedType = UUID.randomUUID().toString(); +// updatedIOC.setType(updatedType); +// final List replacementIOCs = List.of(updatedIOC); +// systemIndexFeedStore.storeIOCs(replacementIOCs, UpdateType.DELTA); +// validateIOCs(iocs); +// +// final STIX2 replacedIOC = getIOCById(updatedIOC.getFeedId(), updatedIOC.getId()); +// assertEquals(updatedType, replacedIOC.getType()); +// assertNotEquals(indexedIOC.getType(), replacedIOC.getType()); +// } + + private static Stream getUpdateTypes() { + return Arrays.stream(UpdateType.values()) + .map(Arguments::of); + } + + private void validateIOCs(final List iocs) throws IOException { + refreshIndex(); + final SearchResponse searchResponse = searchIndex(); + assertEquals(iocs.size(), searchResponse.getHits().getHits().length); + + final Set actualDocIds = Arrays.stream(searchResponse.getHits().getHits()) + .map(SearchHit::getId) + .collect(Collectors.toSet()); + final Set expectedDocIds = iocs.stream() + .map(ioc -> String.format(SystemIndexFeedStore.IOC_DOC_ID_FORMAT, ioc.getFeedId(), ioc.getId())) + .collect(Collectors.toSet()); + + assertEquals(expectedDocIds, actualDocIds); + iocs.forEach(ioc -> assertEquals(FEED_ID, ioc.getFeedId())); + } + + private void deleteAlias() { + indexAccessor.deleteRolloverAlias(SystemIndexFeedStore.ALIAS_NAME); + } + + private void refreshIndex() throws IOException { + final RefreshRequest refreshRequest = new RefreshRequest(); + refreshRequest.indices(SystemIndexFeedStore.ALIAS_NAME); + client.indices().refresh(refreshRequest, RequestOptions.DEFAULT); + } + + private SearchResponse searchIndex() throws IOException { + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + searchSourceBuilder.size(10000); + + final SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(SystemIndexFeedStore.ALIAS_NAME); + searchRequest.source(searchSourceBuilder); + + return client.search(searchRequest, RequestOptions.DEFAULT); + } + + private STIX2 getIOCById(final String feedId, final String id) throws IOException { + final String docId = String.format(SystemIndexFeedStore.IOC_DOC_ID_FORMAT, feedId, id); + final GetRequest getRequest = new GetRequest(SystemIndexFeedStore.ALIAS_NAME, docId); + final GetResponse getResponse = client.get(getRequest, RequestOptions.DEFAULT); + + return OBJECT_MAPPER.readValue(getResponse.getSourceAsBytes(), STIX2.class); + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStoreTests.java b/tif/src/test/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStoreTests.java new file mode 100644 index 000000000..49f4b1bf9 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/feed/store/SystemIndexFeedStoreTests.java @@ -0,0 +1,216 @@ +package org.opensearch.securityanalytics.feed.store; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.securityanalytics.exceptions.FeedStoreException; +import org.opensearch.securityanalytics.feed.store.model.UpdateType; +import org.opensearch.securityanalytics.index.IndexAccessor; +import org.opensearch.securityanalytics.model.IOC; +import org.opensearch.securityanalytics.model.STIX2; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.securityanalytics.feed.store.SystemIndexFeedStore.ALIAS_NAME; +import static org.opensearch.securityanalytics.feed.store.SystemIndexFeedStore.HIDDEN_INDEX; +import static org.opensearch.securityanalytics.feed.store.SystemIndexFeedStore.IOC_DOC_ID_FORMAT; +import static org.opensearch.securityanalytics.feed.store.SystemIndexFeedStore.PRIMARY_SHARD_COUNT; + +public class SystemIndexFeedStoreTests { + private static final String FEED_ID = UUID.randomUUID().toString(); + private static final String IOC_ID = UUID.randomUUID().toString(); + + @Mock + private IndexAccessor indexAccessor; + @Mock + private BulkResponse bulkResponse; + + private SystemIndexFeedStore systemIndexFeedStore; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + systemIndexFeedStore = new SystemIndexFeedStore(indexAccessor); + } + + @AfterEach + public void tearDown() { + verifyNoMoreInteractions(indexAccessor, bulkResponse); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_Success(final UpdateType updateType) { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + when(indexAccessor.bulk(any(BulkRequest.class))).thenReturn(bulkResponse); + when(bulkResponse.hasFailures()).thenReturn(false); + + systemIndexFeedStore.storeIOCs(List.of(ioc), updateType); + + verifyCreateIndexIfPresent(); + if (updateType == UpdateType.REPLACE) { + verifyDeleteByQuery(); + } + verifyBulkIndexIOCs(List.of(ioc)); + verify(bulkResponse).hasFailures(); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_Success_MultipleIOCs(final UpdateType updateType) { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + when(indexAccessor.bulk(any(BulkRequest.class))).thenReturn(bulkResponse); + when(bulkResponse.hasFailures()).thenReturn(false); + + systemIndexFeedStore.storeIOCs(List.of(ioc, ioc, ioc, ioc), updateType); + + verifyCreateIndexIfPresent(); + if (updateType == UpdateType.REPLACE) { + verifyDeleteByQuery(); + } + verifyBulkIndexIOCs(List.of(ioc, ioc, ioc, ioc)); + verify(bulkResponse).hasFailures(); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_NoIOCs(final UpdateType updateType) { + systemIndexFeedStore.storeIOCs(Collections.emptyList(), updateType); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_TooManyFeedIds(final UpdateType updateType) { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + final IOC ioc2 = getIOC(UUID.randomUUID().toString(), IOC_ID); + + assertThrows(FeedStoreException.class, () -> systemIndexFeedStore.storeIOCs(List.of(ioc, ioc2), updateType)); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_ExceptionCreatingIndex(final UpdateType updateType) { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + doThrow(new RuntimeException()).when(indexAccessor).createRolloverAlias(eq(ALIAS_NAME), any(Settings.class), any(Map.class)); + + assertThrows(FeedStoreException.class, () -> systemIndexFeedStore.storeIOCs(List.of(ioc), updateType)); + + verifyCreateIndexIfPresent(); + } + + @Test + public void testStoreIOCs_ExceptionDeletingByQuery() { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + doThrow(new RuntimeException()).when(indexAccessor).deleteByQuery(eq(ALIAS_NAME), any(QueryBuilder.class)); + + assertThrows(FeedStoreException.class, () -> systemIndexFeedStore.storeIOCs(List.of(ioc), UpdateType.REPLACE)); + + verifyCreateIndexIfPresent(); + verifyDeleteByQuery(); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_ExceptionBulkingIOCs(final UpdateType updateType) { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + when(indexAccessor.bulk(any(BulkRequest.class))).thenThrow(new RuntimeException()); + + assertThrows(FeedStoreException.class, () -> systemIndexFeedStore.storeIOCs(List.of(ioc), updateType)); + + verifyCreateIndexIfPresent(); + if (updateType == UpdateType.REPLACE) { + verifyDeleteByQuery(); + } + verifyBulkIndexIOCs(List.of(ioc)); + } + + @ParameterizedTest + @MethodSource("getUpdateTypes") + public void testStoreIOCs_BulkResponseHasFailures(final UpdateType updateType) { + final IOC ioc = getIOC(FEED_ID, IOC_ID); + when(indexAccessor.bulk(any(BulkRequest.class))).thenReturn(bulkResponse); + when(bulkResponse.hasFailures()).thenReturn(true); + + assertThrows(FeedStoreException.class, () -> systemIndexFeedStore.storeIOCs(List.of(ioc), updateType)); + + verifyCreateIndexIfPresent(); + if (updateType == UpdateType.REPLACE) { + verifyDeleteByQuery(); + } + verifyBulkIndexIOCs(List.of(ioc)); + verify(bulkResponse).hasFailures(); + verify(bulkResponse).buildFailureMessage(); + } + + private void verifyCreateIndexIfPresent() { + final ArgumentCaptor captor = ArgumentCaptor.forClass(Settings.class); + final ArgumentCaptor> rolloverCaptor = ArgumentCaptor.forClass(Map.class); + verify(indexAccessor).createRolloverAlias(eq(ALIAS_NAME), captor.capture(), rolloverCaptor.capture()); + + assertTrue(captor.getValue().hasValue(IndexAccessor.SHARD_COUNT_SETTING_NAME)); + assertEquals("" + PRIMARY_SHARD_COUNT, captor.getValue().get(IndexAccessor.SHARD_COUNT_SETTING_NAME)); + assertTrue(captor.getValue().hasValue(IndexAccessor.AUTO_EXPAND_REPLICA_COUNT_SETTING_NAME)); + assertEquals(IndexAccessor.EXPAND_ALL_REPLICA_COUNT_SETTING_VALUE, captor.getValue().get(IndexAccessor.AUTO_EXPAND_REPLICA_COUNT_SETTING_NAME)); + assertTrue(captor.getValue().hasValue(IndexAccessor.HIDDEN_INDEX_SETTING_NAME)); + assertEquals(String.valueOf(HIDDEN_INDEX), captor.getValue().get(IndexAccessor.HIDDEN_INDEX_SETTING_NAME)); + } + + private void verifyDeleteByQuery() { + final ArgumentCaptor captor = ArgumentCaptor.forClass(QueryBuilder.class); + verify(indexAccessor).deleteByQuery(eq(ALIAS_NAME), captor.capture()); + + assertNotNull(captor.getValue()); + } + + private void verifyBulkIndexIOCs(final List iocs) { + final ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); + verify(indexAccessor).bulk(captor.capture()); + + assertEquals(iocs.size(), captor.getValue().requests().size()); + IntStream.range(0, iocs.size()).forEach(i -> { + final IndexRequest indexRequest = (IndexRequest) captor.getValue().requests().get(i); + assertEquals(ALIAS_NAME, indexRequest.index()); + assertEquals(String.format(IOC_DOC_ID_FORMAT, FEED_ID, IOC_ID), indexRequest.id()); + }); + } + + private static Stream getUpdateTypes() { + return Arrays.stream(UpdateType.values()) + .map(Arguments::of); + } + + private IOC getIOC(final String feedId, final String iocId) { + final IOC ioc = new STIX2(); + ioc.setFeedId(feedId); + ioc.setId(iocId); + + return ioc; + } +} diff --git a/tif/src/test/java/org/opensearch/securityanalytics/index/RHLCIndexAccessorTests.java b/tif/src/test/java/org/opensearch/securityanalytics/index/RHLCIndexAccessorTests.java new file mode 100644 index 000000000..b915ab3f3 --- /dev/null +++ b/tif/src/test/java/org/opensearch/securityanalytics/index/RHLCIndexAccessorTests.java @@ -0,0 +1,180 @@ +//package org.opensearch.securityanalytics.index; +// +//import org.junit.jupiter.api.AfterEach; +//import org.junit.jupiter.api.BeforeEach; +//import org.junit.jupiter.api.Test; +//import org.mockito.ArgumentCaptor; +//import org.mockito.Mock; +//import org.mockito.MockitoAnnotations; +//import org.opensearch.OpenSearchStatusException; +//import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +//import org.opensearch.action.bulk.BulkRequest; +//import org.opensearch.action.bulk.BulkResponse; +//import org.opensearch.client.IndicesClient; +//import org.opensearch.client.RequestOptions; +//import org.opensearch.client.RestHighLevelClient; +//import org.opensearch.client.indices.CreateIndexRequest; +//import org.opensearch.client.indices.CreateIndexResponse; +//import org.opensearch.common.settings.Settings; +//import org.opensearch.core.rest.RestStatus; +//import org.opensearch.index.query.QueryBuilder; +//import org.opensearch.index.reindex.BulkByScrollResponse; +//import org.opensearch.index.reindex.DeleteByQueryRequest; +//import org.opensearch.securityanalytics.exceptions.IndexAccessorException; +// +//import java.io.IOException; +//import java.util.UUID; +// +//import static org.junit.jupiter.api.Assertions.assertEquals; +//import static org.junit.jupiter.api.Assertions.assertThrows; +//import static org.mockito.ArgumentMatchers.any; +//import static org.mockito.ArgumentMatchers.eq; +//import static org.mockito.Mockito.verify; +//import static org.mockito.Mockito.verifyNoMoreInteractions; +//import static org.mockito.Mockito.when; +// +//public class RHLCIndexAccessorTests { +// private static final String INDEX = UUID.randomUUID().toString(); +// +// @Mock +// private RestHighLevelClient client; +// @Mock +// private IndicesClient indicesClient; +// @Mock +// private CreateIndexResponse createIndexResponse; +// @Mock +// private DeleteByQueryRequest deleteByQueryRequest; +// @Mock +// private BulkByScrollResponse bulkByScrollResponse; +// @Mock +// private BulkRequest bulkRequest; +// @Mock +// private BulkResponse bulkResponse; +// @Mock +// private Settings settings; +// @Mock +// private QueryBuilder queryBuilder; +// +// private RHLCIndexAccessor rhlcIndexAccessor; +// +// @BeforeEach +// public void setup() { +// MockitoAnnotations.openMocks(this); +// rhlcIndexAccessor = new RHLCIndexAccessor(client); +// } +// +// @AfterEach +// public void teardown() { +// verifyNoMoreInteractions(client, indicesClient, createIndexResponse, deleteByQueryRequest, bulkByScrollResponse, +// bulkRequest, bulkResponse, settings, queryBuilder); +// } +// +// @Test +// public void testCreateIndex() throws IOException { +// when(client.indices()).thenReturn(indicesClient); +// +// rhlcIndexAccessor.createRolloverAlias(INDEX, settings); +// +// verify(client).indices(); +// final ArgumentCaptor captor = ArgumentCaptor.forClass(CreateIndexRequest.class); +// verify(indicesClient).create(captor.capture(), eq(RequestOptions.DEFAULT)); +// assertEquals(INDEX, captor.getValue().index()); +// assertEquals(settings, captor.getValue().settings()); +// } +// +// @Test +// public void testCreateIndex_ExceptionCreatingIndex() throws IOException { +// when(client.indices()).thenReturn(indicesClient); +// when(indicesClient.create(any(CreateIndexRequest.class), eq(RequestOptions.DEFAULT))).thenThrow(new RuntimeException()); +// +// assertThrows(IndexAccessorException.class, () -> rhlcIndexAccessor.createRolloverAlias(INDEX, settings)); +// +// verify(client).indices(); +// final ArgumentCaptor captor = ArgumentCaptor.forClass(CreateIndexRequest.class); +// verify(indicesClient).create(captor.capture(), eq(RequestOptions.DEFAULT)); +// assertEquals(INDEX, captor.getValue().index()); +// assertEquals(settings, captor.getValue().settings()); +// } +// +// @Test +// public void testDeleteIndex() throws IOException { +// when(client.indices()).thenReturn(indicesClient); +// +// rhlcIndexAccessor.deleteIndex(INDEX); +// +// verify(client).indices(); +// final ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteIndexRequest.class); +// verify(indicesClient).delete(captor.capture(), eq(RequestOptions.DEFAULT)); +// assertEquals(1, captor.getValue().indices().length); +// assertEquals(INDEX, captor.getValue().indices()[0]); +// } +// +// @Test +// public void testDeleteIndex_IndexDoesNotExist() throws IOException { +// when(client.indices()).thenReturn(indicesClient); +// when(indicesClient.delete(any(DeleteIndexRequest.class), eq(RequestOptions.DEFAULT))) +// .thenThrow(new OpenSearchStatusException("index_not_found_exception", RestStatus.NOT_FOUND)); +// +// rhlcIndexAccessor.deleteIndex(INDEX); +// +// verify(client).indices(); +// final ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteIndexRequest.class); +// verify(indicesClient).delete(captor.capture(), eq(RequestOptions.DEFAULT)); +// assertEquals(1, captor.getValue().indices().length); +// assertEquals(INDEX, captor.getValue().indices()[0]); +// } +// +// @Test +// public void testDeleteIndex_ExceptionDeletingIndex() throws IOException { +// when(client.indices()).thenReturn(indicesClient); +// when(indicesClient.delete(any(DeleteIndexRequest.class), eq(RequestOptions.DEFAULT))).thenThrow(new RuntimeException()); +// +// assertThrows(IndexAccessorException.class, () -> rhlcIndexAccessor.deleteIndex(INDEX)); +// +// verify(client).indices(); +// final ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteIndexRequest.class); +// verify(indicesClient).delete(captor.capture(), eq(RequestOptions.DEFAULT)); +// assertEquals(1, captor.getValue().indices().length); +// assertEquals(INDEX, captor.getValue().indices()[0]); +// } +// +// @Test +// public void testDeleteByQuery() throws IOException { +// when(client.deleteByQuery(any(DeleteByQueryRequest.class), eq(RequestOptions.DEFAULT))).thenReturn(bulkByScrollResponse); +// +// final BulkByScrollResponse result = rhlcIndexAccessor.deleteByQuery(INDEX, queryBuilder); +// assertEquals(bulkByScrollResponse, result); +// +// verify(client).deleteByQuery(eq(deleteByQueryRequest), eq(RequestOptions.DEFAULT)); +// } +// +// @Test +// public void testDeleteByQuery_ExceptionDeletingByQuery() throws IOException { +// when(client.deleteByQuery(any(DeleteByQueryRequest.class), eq(RequestOptions.DEFAULT))).thenThrow(new RuntimeException()); +// when(deleteByQueryRequest.indices()).thenReturn(new String[] { INDEX }); +// +// assertThrows(IndexAccessorException.class, () -> rhlcIndexAccessor.deleteByQuery(INDEX, queryBuilder)); +// +// verify(client).deleteByQuery(eq(deleteByQueryRequest), eq(RequestOptions.DEFAULT)); +// verify(deleteByQueryRequest).indices(); +// } +// +// @Test +// public void testBulk() throws IOException { +// when(client.bulk(eq(bulkRequest), eq(RequestOptions.DEFAULT))).thenReturn(bulkResponse); +// +// final BulkResponse result = rhlcIndexAccessor.bulk(bulkRequest); +// assertEquals(bulkResponse, result); +// +// verify(client).bulk(eq(bulkRequest), eq(RequestOptions.DEFAULT)); +// } +// +// @Test +// public void testBulk_ExceptionBulking() throws IOException { +// when(client.bulk(eq(bulkRequest), eq(RequestOptions.DEFAULT))).thenThrow(new RuntimeException()); +// +// assertThrows(IndexAccessorException.class, () -> rhlcIndexAccessor.bulk(bulkRequest)); +// +// verify(client).bulk(eq(bulkRequest), eq(RequestOptions.DEFAULT)); +// } +//}