Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Azure Batch worker pool supports managed identity #5670

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/reference/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,11 @@ The following settings are available:
`azure.batch.pools.<name>.lowPriority`
: Enable the use of low-priority VMs (default: `false`).

`azure.batch.pools.<name>.managedIdentityId`
: :::{versionadded} 25.01.0-edge
:::
: Specify the pool has a managed identity attached. This will be passed to the task as the environment variable `NXF_AZURE_MI_CLIENT_ID`.

`azure.batch.pools.<name>.maxVmCount`
: Specify the max of virtual machine when using auto scale option.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import com.azure.compute.batch.models.ContainerConfiguration
import com.azure.compute.batch.models.ContainerRegistryReference
import com.azure.compute.batch.models.ContainerType
import com.azure.compute.batch.models.ElevationLevel
import com.azure.compute.batch.models.EnvironmentSetting
import com.azure.compute.batch.models.MetadataItem
import com.azure.compute.batch.models.MountConfiguration
import com.azure.compute.batch.models.NetworkConfiguration
Expand Down Expand Up @@ -434,15 +435,23 @@ class AzBatchService implements Closeable {

log.trace "[AZURE BATCH] Submitting task: $taskId, cpus=${task.config.getCpus()}, mem=${task.config.getMemory()?:'-'}, slots: $slots"

// Add environment variables for managed identity if configured
final env = [:] as Map<String,String>
if( pool?.opts?.managedIdentityId ) {
env.put('AZCOPY_AUTO_LOGIN_TYPE', 'MSI')
env.put('AZCOPY_MSI_CLIENT_ID', pool.opts.managedIdentityId)
}

return new BatchTaskCreateContent(taskId, cmd)
.setUserIdentity(userIdentity(pool.opts.privileged, pool.opts.runAs, AutoUserScope.TASK))
.setContainerSettings(containerOpts)
.setResourceFiles(resourceFileUrls(task, sas))
.setOutputFiles(outputFileUrls(task, sas))
.setRequiredSlots(slots)
.setConstraints(constraints)


.setEnvironmentSettings(env.collect { name, value ->
new EnvironmentSetting(name).setValue(value)
})
}

AzTaskKey runTask(String poolId, String jobId, TaskRun task) {
Expand Down Expand Up @@ -503,6 +512,13 @@ class AzBatchService implements Closeable {
List<OutputFile> result = new ArrayList<>(20)
result << destFile(TaskRun.CMD_EXIT, task.workDir, sas)
result << destFile(TaskRun.CMD_LOG, task.workDir, sas)
result << destFile(TaskRun.CMD_OUTFILE, task.workDir, sas)
result << destFile(TaskRun.CMD_ERRFILE, task.workDir, sas)
result << destFile(TaskRun.CMD_SCRIPT, task.workDir, sas)
result << destFile(TaskRun.CMD_RUN, task.workDir, sas)
result << destFile(TaskRun.CMD_STAGE, task.workDir, sas)
result << destFile(TaskRun.CMD_TRACE, task.workDir, sas)
result << destFile(TaskRun.CMD_ENV, task.workDir, sas)
return result
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,5 @@ class AzBatchTaskHandler extends TaskHandler implements FusionAwareTask {
}
return machineInfo
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ class AzFileCopyStrategy extends SimpleFileCopyStrategy {
final result = new StringBuilder()
final copy = environment ? new LinkedHashMap<String,String>(environment) : new LinkedHashMap<String,String>()
copy.remove('PATH')
copy.put('PATH', '$PWD/.nextflow-bin:$AZ_BATCH_NODE_SHARED_DIR/bin/:$PATH')
copy.put('AZCOPY_LOG_LOCATION', '$PWD/.azcopy_log')
copy.put('PATH', '$AZ_BATCH_TASK_DIR/.nextflow-bin:$AZ_BATCH_NODE_SHARED_DIR/bin/:$PATH')
copy.put('AZCOPY_LOG_LOCATION', '$AZ_BATCH_TASK_DIR/.azcopy_log')
copy.put('AZCOPY_JOB_PLAN_LOCATION', '$AZ_BATCH_TASK_DIR/.azcopy_log')
copy.put('AZ_SAS', sasToken)

// finally render the environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class AzPoolOpts implements CacheFunnel {
boolean lowPriority
AzStartTaskOpts startTask

String managedIdentityId

AzPoolOpts() {
this(Collections.emptyMap())
}
Expand All @@ -92,6 +94,7 @@ class AzPoolOpts implements CacheFunnel {
this.password = opts.password
this.virtualNetwork = opts.virtualNetwork
this.lowPriority = opts.lowPriority as boolean
this.managedIdentityId = opts.managedIdentityId
}

@Override
Expand All @@ -114,6 +117,7 @@ class AzPoolOpts implements CacheFunnel {
hasher.putBoolean(lowPriority)
hasher.putUnencodedChars(startTask.script ?: '')
hasher.putBoolean(startTask.privileged)
hasher.putUnencodedChars(managedIdentityId ?: '')
return hasher
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ class AzBashLib extends BashFunLib<AzBashLib> {
local base_name="$(basename "$name")"
local dir_name="$(dirname "$name")"

local auth_args=""
if [[ ! -z "$AZCOPY_MSI_CLIENT_ID" ]]; then
# When using managed identity, no additional args needed
auth_args=""
else
# Use SAS token authentication
auth_args="?$AZ_SAS"
fi

if [[ -d $name ]]; then
if [[ "$base_name" == "$name" ]]; then
azcopy cp "$name" "$target?$AZ_SAS" --recursive --block-blob-tier $AZCOPY_BLOCK_BLOB_TIER --block-size-mb $AZCOPY_BLOCK_SIZE_MB
azcopy cp "$name" "$target$auth_args" --recursive --block-blob-tier $AZCOPY_BLOCK_BLOB_TIER --block-size-mb $AZCOPY_BLOCK_SIZE_MB
else
azcopy cp "$name" "$target/$dir_name?$AZ_SAS" --recursive --block-blob-tier $AZCOPY_BLOCK_BLOB_TIER --block-size-mb $AZCOPY_BLOCK_SIZE_MB
azcopy cp "$name" "$target/$dir_name$auth_args" --recursive --block-blob-tier $AZCOPY_BLOCK_BLOB_TIER --block-size-mb $AZCOPY_BLOCK_SIZE_MB
fi
else
azcopy cp "$name" "$target/$name?$AZ_SAS" --block-blob-tier $AZCOPY_BLOCK_BLOB_TIER --block-size-mb $AZCOPY_BLOCK_SIZE_MB
azcopy cp "$name" "$target/$name$auth_args" --block-blob-tier $AZCOPY_BLOCK_BLOB_TIER --block-size-mb $AZCOPY_BLOCK_SIZE_MB
fi
}

Expand All @@ -79,12 +88,22 @@ class AzBashLib extends BashFunLib<AzBashLib> {
local target=$2
local basedir=$(dirname $2)
local ret

local auth_args=""
if [[ ! -z "$AZCOPY_MSI_CLIENT_ID" ]]; then
# When using managed identity, no additional args needed
auth_args=""
else
# Use SAS token authentication
auth_args="?$AZ_SAS"
fi

mkdir -p "$basedir"

ret=$(azcopy cp "$source?$AZ_SAS" "$target" 2>&1) || {
ret=$(azcopy cp "$source$auth_args" "$target" 2>&1) || {
## if fails check if it was trying to download a directory
mkdir -p $target
azcopy cp "$source/*?$AZ_SAS" "$target" --recursive >/dev/null || {
azcopy cp "$source/*$auth_args" "$target" --recursive >/dev/null || {
rm -rf $target
>&2 echo "Unable to download path: $source"
exit 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package nextflow.cloud.azure.batch

import nextflow.cloud.azure.config.AzPoolOpts
import nextflow.cloud.types.CloudMachineInfo
import nextflow.cloud.types.PriceModel
import nextflow.cloud.azure.batch.AzVmPoolSpec
import nextflow.exception.ProcessUnrecoverableException
import nextflow.executor.BashWrapperBuilder
import nextflow.executor.Executor
import nextflow.processor.TaskBean
import nextflow.processor.TaskConfig
import nextflow.processor.TaskProcessor
import nextflow.processor.TaskRun
Expand Down Expand Up @@ -84,5 +87,4 @@ class AzBatchTaskHandlerTest extends Specification {
trace.machineInfo.zone == 'west-eu'
trace.machineInfo.priceModel == PriceModel.standard
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

package nextflow.cloud.azure.config

import nextflow.cloud.azure.config.AzPoolOpts
import nextflow.util.Duration
import spock.lang.Specification
/**
Expand Down
Loading