diff --git a/README.md b/README.md
index 03f5787..e0d3067 100644
--- a/README.md
+++ b/README.md
@@ -331,6 +331,22 @@ These are the options for `DruidSource`, to be passed with `write.options()`.
| `druid.segment_storage.hdfs.security.kerberos.principal` | Kerberos principal |
| `druid.segment_storage.hdfs.security.kerberos.keytab` | Kerberos keytab |
+4. **If Deep Storage is `azure`:**
+
+ | Property | Description |
+ | --- | --- |
+ | `druid.azure.account` | Azure storage account name |
+ | `druid.azure.key` | Azure storage account key |
+ | `druid.azure.sharedAccessStorageToken` | Azure shared access signature (SAS) token, used when no key is set |
+ | `druid.azure.useAzureCredentialsChain` | Use `DefaultAzureCredential` for authentication |
+ | `druid.azure.managedIdentityClientId` | Client ID of the managed identity to use with `DefaultAzureCredential`; requires `druid.azure.useAzureCredentialsChain` to be `true` |
+ | `druid.azure.endpointSuffix` | The endpoint suffix to use. Override the default (`core.windows.net`) to connect to a different Azure storage endpoint, e.g. Azurite in tests |
+ | `druid.azure.container` | Azure Blob Storage container name |
+ | `druid.azure.prefix` | Prefix (path) under the container where segments are written |
+ | `druid.azure.protocol` | Protocol to use (`http` or `https`) |
+ | `druid.azure.maxTries` | Maximum number of tries for Azure requests |
+ | `druid.azure.maxListingLength` | Maximum number of blobs to fetch per listing request |
+
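+ For example, the Azure options can be combined with the required `DruidSource` options (metadata storage, datasource, time column, ...) and passed via `write.options()`. The sketch below is only illustrative: `dataset` is the DataFrame to ingest, `commonDruidOptions` stands for the required properties described above, and the account, key, container, prefix and datasource names are placeholders.
+
+ ```scala
+ import org.apache.spark.sql.SaveMode
+ import com.rovio.ingest.extensions.DruidDatasetExtensions._
+
+ val azureOptions = Map(
+   "druid.segment_storage.type" -> "azure",
+   "druid.azure.account"        -> "myaccount",             // placeholder storage account
+   "druid.azure.key"            -> "<storage-account-key>", // or a SAS token / credentials chain
+   "druid.azure.container"      -> "druid-segments",        // placeholder container
+   "druid.azure.prefix"         -> "prod"                   // placeholder prefix
+ )
+
+ dataset.write
+   .mode(SaveMode.Overwrite)
+   .options(commonDruidOptions ++ azureOptions) // commonDruidOptions: the required properties above
+   .druid("target-datasource-name-in-druid", timeColumn = "date")
+ ```
+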
#### Optional properties
| Property | Description | Default |
@@ -342,7 +358,7 @@ These are the options for `DruidSource`, to be passed with `write.options()`.
| `druid.exclude_dimensions` | Comma separated list of Spark input columns that have to be excluded in Druid ingestion | |
| `druid.segment.max_rows` | Max number of rows per segment | `5000000` |
| `druid.memory.max_rows` | Max number of rows to keep in memory in spark data writer | `75000` |
-| `druid.segment_storage.type` | Type of Deep Storage to use. Allowed values: `s3`, `local`, `hdfs`. | `s3` |
+| `druid.segment_storage.type` | Type of Deep Storage to use. Allowed values: `s3`, `local`, `hdfs`, `azure`. | `s3` |
| `druid.segment_storage.s3.disableacl` | Whether to disable ACL in S3 config. | `false` |
| `druid.datasource.init` | Boolean flag for (re-)initializing Druid datasource. If `true`, any pre-existing segments for the datasource is marked as unused. | `false` |
| `druid.bitmap_factory` | Compression format for bitmap indexes. Possible values: `concise`, `roaring`. For type `roaring`, the boolean property compressRunOnSerialization is always set to `true`. `rovio-ingest` uses `concise` by default regardless of Druid library version. | `concise` |
diff --git a/pom.xml b/pom.xml
index 9446b5d..374a226 100644
--- a/pom.xml
+++ b/pom.xml
@@ -206,6 +206,11 @@
            <artifactId>druid-hdfs-storage</artifactId>
            <version>${druid.version}</version>
        </dependency>
+        <dependency>
+            <groupId>org.apache.druid.extensions</groupId>
+            <artifactId>druid-azure-extensions</artifactId>
+            <version>${druid.version}</version>
+        </dependency>
        <dependency>
            <groupId>org.apache.druid.extensions</groupId>
            <artifactId>druid-datasketches</artifactId>
diff --git a/src/main/java/com/rovio/ingest/WriterContext.java b/src/main/java/com/rovio/ingest/WriterContext.java
index 6e3f807..862252d 100644
--- a/src/main/java/com/rovio/ingest/WriterContext.java
+++ b/src/main/java/com/rovio/ingest/WriterContext.java
@@ -61,6 +61,17 @@ public class WriterContext implements Serializable {
private final String hdfsDefaultFS;
private final String hdfsSecurityKerberosPrincipal;
private final String hdfsSecurityKerberosKeytab;
+ private final String azureAccount;
+ private final String azureKey;
+ private final String azureSharedAccessStorageToken;
+ private final Boolean azureUseAzureCredentialsChain;
+ private final String azureContainer;
+ private final String azurePrefix;
+ private final String azureManagedIdentityClientId;
+ private final String azureProtocol;
+ private final int azureMaxTries;
+ private final int azureMaxListingLength;
+ private final String azureEndpointSuffix;
private final String deepStorageType;
private final boolean initDataSource;
private final String version;
@@ -108,9 +119,20 @@ private WriterContext(CaseInsensitiveStringMap options, String version) {
this.hdfsDefaultFS = options.getOrDefault(ConfKeys.DEEP_STORAGE_HDFS_DEFAULT_FS, null);
this.hdfsSecurityKerberosPrincipal = options.getOrDefault(ConfKeys.DEEP_STORAGE_HDFS_SECURITY_KERBEROS_PRINCIPAL, null);
this.hdfsSecurityKerberosKeytab = options.getOrDefault(ConfKeys.DEEP_STORAGE_HDFS_SECURITY_KERBEROS_KEYTAB, null);
+ this.azureAccount = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_ACCOUNT, null);
+ this.azureKey = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_KEY, null);
+ this.azureSharedAccessStorageToken = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_SHAREDACCESSSTORAGETOKEN, null);
+ this.azureUseAzureCredentialsChain = options.getBoolean(ConfKeys.DEEP_STORAGE_AZURE_USEAZURECREDENTIALSCHAIN, false);
+ this.azureContainer = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_CONTAINER, null);
+ this.azurePrefix = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_PREFIX, "");
+ this.azureProtocol = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_PROTOCOL, "https");
+ this.azureMaxTries = options.getInt(ConfKeys.DEEP_STORAGE_AZURE_MAXTRIES, 3);
+ this.azureMaxListingLength = options.getInt(ConfKeys.DEEP_STORAGE_AZURE_MAXLISTINGLENGTH, 1024);
+ this.azureEndpointSuffix = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_ENDPOINTSUFFIX, "core.windows.net");
+ this.azureManagedIdentityClientId = options.getOrDefault(ConfKeys.DEEP_STORAGE_AZURE_MANAGEDIDENTITYCLIENTID, null);
this.deepStorageType = options.getOrDefault(ConfKeys.DEEP_STORAGE_TYPE, DEFAULT_DRUID_DEEP_STORAGE_TYPE);
- Preconditions.checkArgument(Arrays.asList("s3", "local", "hdfs").contains(this.deepStorageType),
+ Preconditions.checkArgument(Arrays.asList("s3", "local", "hdfs", "azure").contains(this.deepStorageType),
String.format("Invalid %s: %s", ConfKeys.DEEP_STORAGE_TYPE, this.deepStorageType));
this.initDataSource = options.getBoolean(ConfKeys.DATASOURCE_INIT, false);
@@ -228,6 +250,50 @@ public String getHdfsSecurityKerberosKeytab() {
return hdfsSecurityKerberosKeytab;
}
+ public String getAzureAccount() {
+ return azureAccount;
+ }
+
+ public String getAzureKey() {
+ return azureKey;
+ }
+
+ public String getAzureSharedAccessStorageToken() {
+ return azureSharedAccessStorageToken;
+ }
+
+ public Boolean getAzureUseAzureCredentialsChain() {
+ return azureUseAzureCredentialsChain;
+ }
+
+ public String getAzureContainer() {
+ return azureContainer;
+ }
+
+ public String getAzurePrefix() {
+ return azurePrefix;
+ }
+
+ public String getAzureProtocol() {
+ return azureProtocol;
+ }
+
+ public int getAzureMaxTries() {
+ return azureMaxTries;
+ }
+
+ public int getAzureMaxListingLength() {
+ return azureMaxListingLength;
+ }
+
+ public String getAzureEndpointSuffix() {
+ return azureEndpointSuffix;
+ }
+
+ public String getAzureManagedIdentityClientId() {
+ return azureManagedIdentityClientId;
+ }
+
public boolean isInitDataSource() {
return initDataSource;
}
@@ -244,6 +310,10 @@ public boolean isHdfsDeepStorage() {
return "hdfs".equals(deepStorageType);
}
+ public boolean isAzureDeepStorage() {
+ return "azure".equals(deepStorageType);
+ }
+
public boolean isRollup() {
return rollup;
}
@@ -306,5 +376,17 @@ public static class ConfKeys {
public static final String DEEP_STORAGE_HDFS_DEFAULT_FS = "druid.segment_storage.hdfs.default.fs";
public static final String DEEP_STORAGE_HDFS_SECURITY_KERBEROS_PRINCIPAL = "druid.segment_storage.hdfs.security.kerberos.principal";
public static final String DEEP_STORAGE_HDFS_SECURITY_KERBEROS_KEYTAB = "druid.segment_storage.hdfs.security.kerberos.keytab";
+ // Azure config
+ public static final String DEEP_STORAGE_AZURE_ACCOUNT = "druid.azure.account";
+ public static final String DEEP_STORAGE_AZURE_KEY = "druid.azure.key";
+ public static final String DEEP_STORAGE_AZURE_SHAREDACCESSSTORAGETOKEN = "druid.azure.sharedAccessStorageToken";
+ public static final String DEEP_STORAGE_AZURE_USEAZURECREDENTIALSCHAIN = "druid.azure.useAzureCredentialsChain";
+ public static final String DEEP_STORAGE_AZURE_CONTAINER = "druid.azure.container";
+ public static final String DEEP_STORAGE_AZURE_PREFIX = "druid.azure.prefix";
+ public static final String DEEP_STORAGE_AZURE_PROTOCOL = "druid.azure.protocol";
+ public static final String DEEP_STORAGE_AZURE_MAXTRIES = "druid.azure.maxTries";
+ public static final String DEEP_STORAGE_AZURE_MAXLISTINGLENGTH = "druid.azure.maxListingLength";
+ public static final String DEEP_STORAGE_AZURE_ENDPOINTSUFFIX = "druid.azure.endpointSuffix";
+ public static final String DEEP_STORAGE_AZURE_MANAGEDIDENTITYCLIENTID = "druid.azure.managedIdentityClientId";
}
}
diff --git a/src/main/java/com/rovio/ingest/util/SegmentStorageUpdater.java b/src/main/java/com/rovio/ingest/util/SegmentStorageUpdater.java
index 4a42afb..4db5a2a 100644
--- a/src/main/java/com/rovio/ingest/util/SegmentStorageUpdater.java
+++ b/src/main/java/com/rovio/ingest/util/SegmentStorageUpdater.java
@@ -20,11 +20,19 @@
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.rovio.ingest.WriterContext;
+import com.rovio.ingest.util.azure.LocalAzureAccountConfig;
+import com.rovio.ingest.util.azure.LocalAzureClientFactory;
+import com.rovio.ingest.util.azure.LocalAzureCloudBlobIterableFactory;
import org.apache.druid.segment.loading.DataSegmentKiller;
import org.apache.druid.segment.loading.DataSegmentPusher;
import org.apache.druid.segment.loading.LocalDataSegmentKiller;
import org.apache.druid.segment.loading.LocalDataSegmentPusher;
import org.apache.druid.segment.loading.LocalDataSegmentPusherConfig;
+import org.apache.druid.storage.azure.AzureDataSegmentConfig;
+import org.apache.druid.storage.azure.AzureDataSegmentKiller;
+import org.apache.druid.storage.azure.AzureDataSegmentPusher;
+import org.apache.druid.storage.azure.AzureInputDataConfig;
+import org.apache.druid.storage.azure.AzureStorage;
import org.apache.druid.storage.hdfs.HdfsDataSegmentKiller;
import org.apache.druid.storage.hdfs.HdfsDataSegmentPusher;
import org.apache.druid.storage.hdfs.HdfsDataSegmentPusherConfig;
@@ -59,6 +67,30 @@ public static DataSegmentPusher createPusher(WriterContext param) {
getHdfsHadoopConfiguration(param.getHdfsCoreSitePath(), param.getHdfsHdfsSitePath(), param.getHdfsDefaultFS()),
MAPPER
);
+ } else if (param.isAzureDeepStorage()) {
+ LocalAzureAccountConfig azureAccountConfig = new LocalAzureAccountConfig();
+ azureAccountConfig.setAccount(param.getAzureAccount());
+ if (param.getAzureKey() != null && !param.getAzureKey().isEmpty()) {
+ azureAccountConfig.setKey(param.getAzureKey());
+ }
+ if (param.getAzureSharedAccessStorageToken() != null && !param.getAzureSharedAccessStorageToken().isEmpty()) {
+ azureAccountConfig.setSharedAccessStorageToken(param.getAzureSharedAccessStorageToken());
+ }
+ if (param.getAzureEndpointSuffix() != null && !param.getAzureEndpointSuffix().isEmpty()) {
+ azureAccountConfig.setEndpointSuffix(param.getAzureEndpointSuffix());
+ }
+ if (param.getAzureManagedIdentityClientId() != null && !param.getAzureManagedIdentityClientId().isEmpty()) {
+ azureAccountConfig.setManagedIdentityClientId(param.getAzureManagedIdentityClientId());
+ }
+ azureAccountConfig.setUseAzureCredentialsChain(param.getAzureUseAzureCredentialsChain());
+ azureAccountConfig.setProtocol(param.getAzureProtocol());
+ azureAccountConfig.setMaxTries(param.getAzureMaxTries());
+ LocalAzureClientFactory azureClientFactory = new LocalAzureClientFactory(azureAccountConfig);
+ AzureStorage azureStorage = new AzureStorage(azureClientFactory);
+ AzureDataSegmentConfig azureDataSegmentConfig = new AzureDataSegmentConfig();
+ azureDataSegmentConfig.setContainer(param.getAzureContainer());
+ azureDataSegmentConfig.setPrefix(param.getAzurePrefix());
+ return new AzureDataSegmentPusher(azureStorage, azureAccountConfig, azureDataSegmentConfig);
} else {
ServerSideEncryptingAmazonS3 serverSideEncryptingAmazonS3 = getAmazonS3().get();
S3DataSegmentPusherConfig s3Config = new S3DataSegmentPusherConfig();
@@ -84,6 +116,34 @@ public static DataSegmentKiller createKiller(WriterContext param) {
getHdfsHadoopConfiguration(param.getHdfsCoreSitePath(), param.getHdfsHdfsSitePath(), param.getHdfsDefaultFS())
)
);
+ } else if (param.isAzureDeepStorage()) {
+
+ LocalAzureAccountConfig azureAccountConfig = new LocalAzureAccountConfig();
+ azureAccountConfig.setAccount(param.getAzureAccount());
+ if (param.getAzureKey() != null && !param.getAzureKey().isEmpty()) {
+ azureAccountConfig.setKey(param.getAzureKey());
+ }
+ if (param.getAzureSharedAccessStorageToken() != null && !param.getAzureSharedAccessStorageToken().isEmpty()) {
+ azureAccountConfig.setSharedAccessStorageToken(param.getAzureSharedAccessStorageToken());
+ }
+ if (param.getAzureEndpointSuffix() != null && !param.getAzureEndpointSuffix().isEmpty()) {
+ azureAccountConfig.setEndpointSuffix(param.getAzureEndpointSuffix());
+ }
+ if (param.getAzureManagedIdentityClientId() != null && !param.getAzureManagedIdentityClientId().isEmpty()) {
+ azureAccountConfig.setManagedIdentityClientId(param.getAzureManagedIdentityClientId());
+ }
+ azureAccountConfig.setUseAzureCredentialsChain(param.getAzureUseAzureCredentialsChain());
+ azureAccountConfig.setProtocol(param.getAzureProtocol());
+ azureAccountConfig.setMaxTries(param.getAzureMaxTries());
+ LocalAzureClientFactory azureClientFactory = new LocalAzureClientFactory(azureAccountConfig);
+ AzureStorage azureStorage = new AzureStorage(azureClientFactory);
+ AzureDataSegmentConfig azureDataSegmentConfig = new AzureDataSegmentConfig();
+ azureDataSegmentConfig.setContainer(param.getAzureContainer());
+ azureDataSegmentConfig.setPrefix(param.getAzurePrefix());
+ AzureInputDataConfig azureInputDataConfig = new AzureInputDataConfig();
+ azureInputDataConfig.setMaxListingLength(param.getAzureMaxListingLength());
+ LocalAzureCloudBlobIterableFactory azureFactory = new LocalAzureCloudBlobIterableFactory();
+ return new AzureDataSegmentKiller(azureDataSegmentConfig, azureInputDataConfig, azureAccountConfig, azureStorage, azureFactory);
} else {
Supplier<ServerSideEncryptingAmazonS3> serverSideEncryptingAmazonS3 = getAmazonS3();
S3DataSegmentPusherConfig s3Config = new S3DataSegmentPusherConfig();
diff --git a/src/main/java/com/rovio/ingest/util/azure/LocalAzureAccountConfig.java b/src/main/java/com/rovio/ingest/util/azure/LocalAzureAccountConfig.java
new file mode 100644
index 0000000..6ca12a0
--- /dev/null
+++ b/src/main/java/com/rovio/ingest/util/azure/LocalAzureAccountConfig.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2021 Rovio Entertainment Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.rovio.ingest.util.azure;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.apache.druid.storage.azure.AzureAccountConfig;
+
+public class LocalAzureAccountConfig extends AzureAccountConfig {
+ @JsonProperty
+ private String managedIdentityClientId;
+
+ @SuppressWarnings("unused") // Used by Jackson deserialization?
+ public void setManagedIdentityClientId(String managedIdentityClientId) {
+ this.managedIdentityClientId = managedIdentityClientId;
+ }
+
+ public String getManagedIdentityClientId() {
+ return managedIdentityClientId;
+ }
+}
diff --git a/src/main/java/com/rovio/ingest/util/azure/LocalAzureClientFactory.java b/src/main/java/com/rovio/ingest/util/azure/LocalAzureClientFactory.java
new file mode 100644
index 0000000..4cf7936
--- /dev/null
+++ b/src/main/java/com/rovio/ingest/util/azure/LocalAzureClientFactory.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2021 Rovio Entertainment Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.rovio.ingest.util.azure;
+
+import com.azure.core.http.policy.ExponentialBackoffOptions;
+import com.azure.core.http.policy.RetryOptions;
+import com.azure.identity.DefaultAzureCredentialBuilder;
+import com.azure.storage.blob.BlobContainerClient;
+import com.azure.storage.blob.BlobServiceClient;
+import com.azure.storage.blob.BlobServiceClientBuilder;
+import com.azure.storage.blob.batch.BlobBatchClient;
+import com.azure.storage.blob.batch.BlobBatchClientBuilder;
+import com.azure.storage.common.StorageSharedKeyCredential;
+
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.druid.storage.azure.AzureClientFactory;
+
+public class LocalAzureClientFactory extends AzureClientFactory {
+
+ private final LocalAzureAccountConfig config;
+ private final Map<Integer, BlobServiceClient> cachedBlobServiceClients;
+
+ public LocalAzureClientFactory(LocalAzureAccountConfig config) {
+ super(config);
+ this.config = config;
+ this.cachedBlobServiceClients = new HashMap<>();
+ }
+
+ // It's okay to store clients in a map here because all the configs for specifying azure retries are static, and there are only 2 of them.
+ // The 2 configs are AzureAccountConfig.maxTries and AzureOutputConfig.maxRetry.
+ // We will only ever have at most 2 clients in cachedBlobServiceClients.
+ public BlobServiceClient getBlobServiceClient(Integer retryCount) {
+ if (!cachedBlobServiceClients.containsKey(retryCount)) {
+ BlobServiceClientBuilder clientBuilder = getAuthenticatedBlobServiceClientBuilder()
+ .retryOptions(new RetryOptions(
+ new ExponentialBackoffOptions()
+ .setMaxRetries(retryCount != null ? retryCount : config.getMaxTries())
+ .setBaseDelay(Duration.ofMillis(1000))
+ .setMaxDelay(Duration.ofMillis(60000))
+ ));
+ cachedBlobServiceClients.put(retryCount, clientBuilder.buildClient());
+ }
+
+ return cachedBlobServiceClients.get(retryCount);
+ }
+
+ // Mainly here to make testing easier.
+ public BlobBatchClient getBlobBatchClient(BlobContainerClient blobContainerClient) {
+ return new BlobBatchClientBuilder(blobContainerClient).buildClient();
+ }
+
+ private BlobServiceClientBuilder getAuthenticatedBlobServiceClientBuilder() {
+ BlobServiceClientBuilder clientBuilder = new BlobServiceClientBuilder()
+ .endpoint(config.getProtocol() + "://" + config.getAccount() + "." + config.getBlobStorageEndpoint());
+
+ if (config.getKey() != null) {
+ clientBuilder.credential(new StorageSharedKeyCredential(config.getAccount(), config.getKey()));
+ } else if (config.getSharedAccessStorageToken() != null) {
+ clientBuilder.sasToken(config.getSharedAccessStorageToken());
+ } else if (config.getUseAzureCredentialsChain()) {
+ // The credential chain may not end up using the managed identity client id; setting it here is a no-op in that case.
+ DefaultAzureCredentialBuilder defaultAzureCredentialBuilder = new DefaultAzureCredentialBuilder()
+ .managedIdentityClientId(config.getManagedIdentityClientId());
+ clientBuilder.credential(defaultAzureCredentialBuilder.build());
+ }
+ return clientBuilder;
+ }
+}
diff --git a/src/main/java/com/rovio/ingest/util/azure/LocalAzureCloudBlobIterableFactory.java b/src/main/java/com/rovio/ingest/util/azure/LocalAzureCloudBlobIterableFactory.java
new file mode 100644
index 0000000..4b0b841
--- /dev/null
+++ b/src/main/java/com/rovio/ingest/util/azure/LocalAzureCloudBlobIterableFactory.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2021 Rovio Entertainment Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.rovio.ingest.util.azure;
+
+import org.apache.druid.storage.azure.AzureCloudBlobIterable;
+import org.apache.druid.storage.azure.AzureCloudBlobIterableFactory;
+import org.apache.druid.storage.azure.AzureCloudBlobIteratorFactory;
+
+import java.net.URI;
+
+public class LocalAzureCloudBlobIterableFactory implements AzureCloudBlobIterableFactory {
+ AzureCloudBlobIteratorFactory azureCloudBlobIteratorFactory;
+
+ @Override
+ public AzureCloudBlobIterable create(Iterable<URI> prefixes, int maxListingLength) {
+ return new AzureCloudBlobIterable(azureCloudBlobIteratorFactory, prefixes, maxListingLength);
+ }
+}
diff --git a/src/test/java/com/rovio/ingest/DruidDeepStorageAzureTest.java b/src/test/java/com/rovio/ingest/DruidDeepStorageAzureTest.java
new file mode 100644
index 0000000..8028a35
--- /dev/null
+++ b/src/test/java/com/rovio/ingest/DruidDeepStorageAzureTest.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2021 Rovio Entertainment Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.rovio.ingest;
+
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+import java.util.HashMap;
+import java.util.Map;
+
+@Testcontainers
+public class DruidDeepStorageAzureTest {
+ public static String azureAccount = "user";
+ public static String azureKey = "key";
+ public static String azureContainer = "container";
+ public static String azurePrefix = "prefix";
+ public static String azureProtocol = "http";
+ public static Integer azurePort = 10000;
+ public static final DockerImageName AZURE_IMAGE = DockerImageName.parse("mcr.microsoft.com/azure-storage/azurite:latest");
+ @Container
+ public static GenericContainer<?> AZURE = getAzureContainer();
+
+ public static GenericContainer<?> getAzureContainer() {
+ return new GenericContainer<>(AZURE_IMAGE)
+ .withEnv("AZURITE_ACCOUNTS", azureAccount + ":" + azureKey)
+ .withExposedPorts(azurePort);
+ }
+
+ public static Integer getAzureMappedPort() {
+ return AZURE.getMappedPort(azurePort);
+ }
+
+ public static String getAzureHost() {
+ return AZURE.getHost();
+ }
+
+ public static String getAzureEndPointSuffix() {
+ return getAzureHost() + ":" + getAzureMappedPort();
+ }
+
+ public static Map<String, String> getAzureOptions() {
+ Map<String, String> options = new HashMap<>();
+ options.put(WriterContext.ConfKeys.DEEP_STORAGE_AZURE_ACCOUNT, azureAccount);
+ options.put(WriterContext.ConfKeys.DEEP_STORAGE_AZURE_KEY, azureKey);
+ options.put(WriterContext.ConfKeys.DEEP_STORAGE_AZURE_CONTAINER, azureContainer);
+ options.put(WriterContext.ConfKeys.DEEP_STORAGE_AZURE_PREFIX, azurePrefix);
+ options.put(WriterContext.ConfKeys.DEEP_STORAGE_AZURE_PROTOCOL, azureProtocol);
+ options.put(WriterContext.ConfKeys.DEEP_STORAGE_AZURE_ENDPOINTSUFFIX, getAzureEndPointSuffix());
+ return options;
+ }
+}
diff --git a/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsAzureSpec.scala b/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsAzureSpec.scala
new file mode 100644
index 0000000..c0b354f
--- /dev/null
+++ b/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsAzureSpec.scala
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2021 Rovio Entertainment Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.rovio.ingest
+
+import org.scalatest._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Dataset, SaveMode, SparkSession}
+import org.junit.runner.RunWith
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.junit.JUnitRunner
+import org.scalatest.flatspec.AnyFlatSpec
+
+import scala.collection.JavaConverters._
+
+// This is needed for mvn test. It wouldn't find this test otherwise.
+@RunWith(classOf[JUnitRunner])
+class DruidDatasetExtensionsAzureSpec extends AnyFlatSpec with Matchers with BeforeAndAfter with BeforeAndAfterEach {
+
+ before {
+ DruidDeepStorageAzureTest.AZURE.start()
+ DruidSourceBaseTest.MYSQL.start()
+ DruidSourceBaseTest.prepareDatabase(DruidSourceBaseTest.MYSQL)
+ }
+
+ after {
+ DruidSourceBaseTest.MYSQL.stop()
+ DruidDeepStorageAzureTest.AZURE.stop()
+ }
+
+ // Could instead try assertSmallDataFrameEquality from
+ // https://github.com/MrPowers/spark-fast-tests
+ // for now avoid the dependency and collect as an array & use toSet to ignore order
+ def assertEqual[T](expected: Dataset[T], actual: Dataset[T]): Assertion =
+ actual.collect().toSet should be(expected.collect().toSet)
+
+ lazy val spark: SparkSession = {
+ SparkSession.builder()
+ .appName("Spark/MLeap Parity Tests")
+ .config("spark.sql.session.timeZone", "UTC")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .master("local[2]")
+ .getOrCreate()
+ }
+
+ import spark.implicits._
+
+ // import the implicit classes – those provide the functions being tested!
+ import com.rovio.ingest.extensions.DruidDatasetExtensions._
+
+ it should "save the dataset to druid on Azure" in {
+
+ DruidSourceBaseTest.setUpDb(DruidSourceBaseTest.MYSQL)
+
+ // create Data source options
+ val options: Map[String, String] = DruidSourceBaseTest.getDataSourceOptions(DruidSourceBaseTest.MYSQL).asScala.toMap ++
+ DruidDeepStorageAzureTest.getAzureOptions.asScala.toMap
+
+ val ds = Seq(
+ // same as content of data.csv
+ KpiRow(date = "2019-10-17", country = "US", dau = 50, revenue = 100.0, is_segmented = true),
+ KpiRow(date = "2019-10-17", country = "GB", dau = 20, revenue = 20.0, is_segmented = true),
+ KpiRow(date = "2019-10-17", country = "DE", dau = 20, revenue = 20.0, is_segmented = true),
+ KpiRow(date = "2019-10-16", country = "US", dau = 50, revenue = 100.0, is_segmented = false),
+ KpiRow(date = "2019-10-16", country = "FI", dau = 20, revenue = 20.0, is_segmented = false),
+ KpiRow(date = "2019-10-16", country = "GB", dau = 20, revenue = 20.0, is_segmented = false),
+ KpiRow(date = "2019-10-16", country = "DE", dau = 20, revenue = 20.0, is_segmented = false)
+ ).toDS
+ .withColumn("date", 'date.cast(DataTypes.TimestampType))
+
+ // note how we can call .repartitionByDruidSegmentSize directly on Dataset[Row]
+ // the nice thing is this allows continuous method chaining on Dataset without breaking the chain
+ ds.repartitionByDruidSegmentSize("date", rowsPerSegment = 2)
+ .write
+ .mode(SaveMode.Overwrite)
+ .options(options)
+ .druid("target-datasource-name-in-druid-on-azure", timeColumn = "date")
+
+ }
+}