diff --git a/modules/nextflow/src/main/groovy/nextflow/fusion/FusionConfig.groovy b/modules/nextflow/src/main/groovy/nextflow/fusion/FusionConfig.groovy index b17812baea..45d1bd489a 100644 --- a/modules/nextflow/src/main/groovy/nextflow/fusion/FusionConfig.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/fusion/FusionConfig.groovy @@ -40,6 +40,7 @@ class FusionConfig { final static public String FUSION_PATH = '/usr/bin/fusion' + final static private String PRODUCT_NAME = 'fusion' final static private Pattern VERSION_JSON = ~/https:\/\/.*\/releases\/v(\d+(?:\.\w+)*)-(\w*)\.json$/ final private Boolean enabled @@ -122,6 +123,17 @@ class FusionConfig { return null } + /** + * Return the Fusion SKU string + * + * @return A string representing the Fusion SKU + */ + String sku() { + return enabled + ? PRODUCT_NAME + : null + } + String version() { return enabled ? retrieveFusionVersion(this.containerConfigUrl ?: DEFAULT_FUSION_AMD64_URL) diff --git a/modules/nextflow/src/main/groovy/nextflow/platform/PlatformHelper.groovy b/modules/nextflow/src/main/groovy/nextflow/platform/PlatformHelper.groovy new file mode 100644 index 0000000000..52461fde91 --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/platform/PlatformHelper.groovy @@ -0,0 +1,74 @@ +package nextflow.platform + +import groovy.transform.CompileStatic + +/** + * Helper methods for Platform-related operations + * + * @author Alberto Miranda + */ +@CompileStatic +class PlatformHelper { + + /** + * Get the configured Platform API endpoint: if the endpoint is not provided in the configuration, we fallback to the + * environment variable `TOWER_API_ENDPOINT`. If neither is provided, we fallback to the default endpoint. + * + * @param opts the configuration options for Platform (e.g. `session.config.navigate('tower')`) + * @param env the applicable environment variables + * @return the Platform API endpoint + */ + static String getEndpoint(Map opts, Map env) { + def result = opts.endpoint as String + if( !result || result=='-' ) + result = env.get('TOWER_API_ENDPOINT') ?: 'https://api.cloud.seqera.io' + return result.stripEnd('/') + } + + /** + * Return the configured Platform access token: if `TOWER_WORKFLOW_ID` is provided in the environment, it means + * we are running in a Platform-made run and we should ONLY retrieve the token from the environment. Otherwise, + * check the configuration or fallback to the environment. If no token is found, return null. + * + * @param opts the configuration options for Platform (e.g. `session.config.navigate('tower')`) + * @param env the applicable environment variables + * @return the Platform access token + */ + static String getAccessToken(Map opts, Map env) { + final token = env.get('TOWER_WORKFLOW_ID') + ? env.get('TOWER_ACCESS_TOKEN') + : opts.containsKey('accessToken') ? opts.accessToken as String : env.get('TOWER_ACCESS_TOKEN') + return token + } + + /** + * Return the configured Platform refresh token: if `TOWER_WORKFLOW_ID` is provided in the environment, it means + * we are running in a Platform-made run and we should ONLY retrieve the token from the environment. Otherwise, + * check the configuration or fallback to the environment. If no token is found, return null. + * + * @param opts the configuration options for Platform (e.g. `session.config.navigate('tower')`) + * @param env the applicable environment variables + * @return the Platform refresh token + */ + static String getRefreshToken(Map opts, Map env) { + final token = env.get('TOWER_WORKFLOW_ID') + ? env.get('TOWER_REFRESH_TOKEN') + : opts.containsKey('refreshToken') ? opts.refreshToken as String : env.get('TOWER_REFRESH_TOKEN') + return token + } + + /** + * Return the Platform Workspace ID: if `TOWER_WORKFLOW_ID` is provided in the environment, it means we are running + * in a Platform-made run and we should ONLY retrieve the workspace ID from the environment. Otherwise, check the + * configuration or fallback to the environment. If no workspace ID is found, return null. + * @param opts + * @param env + * @return + */ + static String getWorkspaceId(Map opts, Map env) { + final workspaceId = env.get('TOWER_WORKFLOW_ID') + ? env.get('TOWER_WORKSPACE_ID') + : opts.workspaceId as Long ?: env.get('TOWER_WORKSPACE_ID') as Long + return workspaceId + } +} diff --git a/plugins/nf-tower/build.gradle b/plugins/nf-tower/build.gradle index e7757c8883..75bb588b68 100644 --- a/plugins/nf-tower/build.gradle +++ b/plugins/nf-tower/build.gradle @@ -39,4 +39,6 @@ dependencies { testImplementation(testFixtures(project(":nextflow"))) testImplementation "org.apache.groovy:groovy:4.0.24" testImplementation "org.apache.groovy:groovy-nio:4.0.24" + // wiremock required by TowerFusionEnvTest + testImplementation "org.wiremock:wiremock:3.5.4" } diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerFusionEnv.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerFusionEnv.groovy new file mode 100644 index 0000000000..0fd858bbcd --- /dev/null +++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerFusionEnv.groovy @@ -0,0 +1,304 @@ +package io.seqera.tower.plugin + +import com.google.common.cache.Cache +import com.google.common.cache.CacheBuilder +import com.google.common.util.concurrent.UncheckedExecutionException +import com.google.gson.Gson +import com.google.gson.JsonSyntaxException +import dev.failsafe.Failsafe +import dev.failsafe.RetryPolicy +import dev.failsafe.event.EventListener +import dev.failsafe.event.ExecutionAttemptedEvent +import dev.failsafe.function.CheckedSupplier +import groovy.transform.CompileStatic +import groovy.util.logging.Slf4j +import io.seqera.tower.plugin.exception.BadResponseException +import io.seqera.tower.plugin.exception.UnauthorizedException +import io.seqera.tower.plugin.exchange.LicenseTokenRequest +import io.seqera.tower.plugin.exchange.LicenseTokenResponse +import nextflow.Global +import nextflow.Session +import nextflow.SysEnv +import nextflow.exception.AbortOperationException +import nextflow.fusion.FusionConfig +import nextflow.fusion.FusionEnv +import nextflow.platform.PlatformHelper +import nextflow.util.Threads +import org.pf4j.Extension + +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.time.Duration +import java.time.temporal.ChronoUnit +import java.util.concurrent.Executors +import java.util.function.Predicate + +/** + * Environment provider for Platform-specific environment variables. + * + * @author Alberto Miranda + */ +@Slf4j +@Extension +@CompileStatic +class TowerFusionEnv implements FusionEnv { + + // The path relative to the Platform endpoint where license-scoped JWT tokens are obtained + private static final String LICENSE_TOKEN_PATH = 'license/token/' + + // Server errors that should trigger a retry + private static final List SERVER_ERRORS = [408, 429, 500, 502, 503, 504] + + // Default connection timeout for HTTP requests + private static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.of(30, ChronoUnit.SECONDS) + + // Default retry policy settings for HTTP requests: delay, max delay, attempts, and jitter + private static final Duration DEFAULT_RETRY_POLICY_DELAY = Duration.of(450, ChronoUnit.MILLIS) + private static final Duration DEFAULT_RETRY_POLICY_MAX_DELAY = Duration.of(90, ChronoUnit.SECONDS) + private static final int DEFAULT_RETRY_POLICY_MAX_ATTEMPTS = 10 + private static final double DEFAULT_RETRY_POLICY_JITTER = 0.5 + + // The HttpClient instance used to send requests + private final HttpClient httpClient = newDefaultHttpClient() + + // The RetryPolicy instance used to retry requests + private final RetryPolicy retryPolicy = newDefaultRetryPolicy(SERVER_ERRORS) + + // Time-to-live for cached tokens + private Duration tokenTTL = Duration.of(1, ChronoUnit.HOURS) + + // Cache used for storing license tokens + private Cache tokenCache = CacheBuilder.newBuilder() + .expireAfterWrite(tokenTTL) + .build() + + // Nextflow session + private final Session session + + // Platform endpoint to use for requests + private final String endpoint + + // Platform access token to use for requests + private final String accessToken + + /** + * Constructor for the class. It initializes the session, endpoint, and access token. + */ + TowerFusionEnv() { + this.session = Global.session as Session + final towerConfig = session.config.navigate('tower') as Map ?: [:] + final env = SysEnv.get() + this.endpoint = PlatformHelper.getEndpoint(towerConfig, env) + this.accessToken = PlatformHelper.getAccessToken(towerConfig, env) + } + + /** + * Return any environment variables relevant to Fusion execution. This method is called + * by {@link nextflow.fusion.FusionEnvProvider#getEnvironment} to determine which + * environment variables are needed for the current run. + * + * @param scheme The scheme for which the environment variables are needed (currently unused) + * @param config The Fusion configuration object + * @return A map of environment variables + */ + @Override + Map getEnvironment(String scheme, FusionConfig config) { + + final product = config.sku() + final version = config.version() + + try { + final token = getLicenseToken(product, version) + return [ + FUSION_LICENSE_TOKEN: token, + ] + } catch (Exception e) { + log.debug("Error retrieving Fusion license information: ${e.message}") + return Map.of() + } + } + + /** + * Send a request to Platform to obtain a license-scoped JWT for Fusion. The request is authenticated using the + * Platform access token provided in the configuration of the current session. + * + * @throws AbortOperationException if a Platform access token cannot be found + * + * @return The signed JWT token + */ + protected String getLicenseToken(String product, String version) throws AbortOperationException { + if (accessToken == null) { + throw new AbortOperationException("Missing personal access token -- Make sure there's a variable TOWER_ACCESS_TOKEN in your environment") + } + + final req = new LicenseTokenRequest( + product: product, + version: version + ) + + try { + final key = '${product}-${version}' + def resp = tokenCache.get( + key, + () -> sendRequest(req) + ) as LicenseTokenResponse + + if( resp.expirationDate.before(new Date()) ) { + log.debug "Cached token already expired; refreshing" + resp = sendRequest(req) + tokenCache.put(key, resp) + } + return resp.signedToken + } catch (UncheckedExecutionException e) { + throw e.getCause() + } + } + + /************************************************************************** + * Helper methods + *************************************************************************/ + + /** + * Create a new HttpClient instance with default settings + * @return The new HttpClient instance + */ + private static HttpClient newDefaultHttpClient() { + final builder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .followRedirects(HttpClient.Redirect.NEVER) + .cookieHandler(new CookieManager()) + .connectTimeout(DEFAULT_CONNECTION_TIMEOUT) + // use virtual threads executor if enabled + if ( Threads.useVirtual() ) { + builder.executor(Executors.newVirtualThreadPerTaskExecutor()) + } + // build and return the new client + return builder.build() + } + + /** + * Create a new RetryPolicy instance with default settings and the given list of retryable errors. With this policy, + * a request is retried on IOExceptions and any server errors defined in errorsToRetry. The number of retries, delay, + * max delay, and jitter are controlled by the corresponding values defined at class level. + * + * @return The new RetryPolicy instance + */ + private static RetryPolicy> newDefaultRetryPolicy(List errorsToRetry) { + + final retryOnException = (e -> e instanceof IOException) as Predicate + final retryOnStatusCode = ((HttpResponse resp) -> resp.statusCode() in errorsToRetry) as Predicate> + + final listener = new EventListener>>() { + @Override + void accept(ExecutionAttemptedEvent event) throws Throwable { + def msg = "connection failure - attempt: ${event.attemptCount}" + if (event.lastResult != null) + msg += "; response: ${event.lastResult}" + if (event.lastFailure != null) + msg += "; exception: [${event.lastFailure.class.name}] ${event.lastFailure.message}" + log.debug(msg) + } + } + return RetryPolicy.> builder() + .handleIf(retryOnException) + .handleResultIf(retryOnStatusCode) + .withBackoff(DEFAULT_RETRY_POLICY_DELAY.toMillis(), DEFAULT_RETRY_POLICY_MAX_DELAY.toMillis(), ChronoUnit.MILLIS) + .withMaxAttempts(DEFAULT_RETRY_POLICY_MAX_ATTEMPTS) + .withJitter(DEFAULT_RETRY_POLICY_JITTER) + .onRetry(listener) + .build() + } + + /** + * Send an HTTP request and return the response. This method automatically retries the request according to the + * given RetryPolicy. + * + * @param req The HttpRequest to send + * @return The HttpResponse received + */ + private HttpResponse safeHttpSend(HttpRequest req, RetryPolicy policy) { + return Failsafe.with(policy).get( + () -> { + log.debug "Request: method:=${req.method()}; uri:=${req.uri()}; request:=${req}" + final resp = httpClient.send(req, HttpResponse.BodyHandlers.ofString()) + log.debug "Response: statusCode:=${resp.statusCode()}; body:=${resp.body()}" + return resp + } as CheckedSupplier + ) as HttpResponse + } + + /** + * Create a {@link HttpRequest} representing a {@link LicenseTokenRequest} object + * + * @param req The LicenseTokenRequest object + * @return The resulting HttpRequest object + */ + private HttpRequest makeHttpRequest(LicenseTokenRequest req) { + return HttpRequest.newBuilder() + .uri(URI.create("${endpoint}/${LICENSE_TOKEN_PATH}").normalize()) + .header('Content-Type', 'application/json') + .header('Authorization', "Bearer ${accessToken}") + .POST( + HttpRequest.BodyPublishers.ofString( + serializeToJson(req) + ) + ) + .build() + } + + /** + * Serialize a {@link LicenseTokenRequest} object into a JSON string + * + * @param req The LicenseTokenRequest object + * @return The resulting JSON string + */ + private static String serializeToJson(LicenseTokenRequest req) { + return new Gson().toJson(req) + } + + /** + * Parse a JSON string into a {@link LicenseTokenResponse} object + * + * @param resp The String containing the JSON representation of the LicenseTokenResponse object + * @return The resulting LicenseTokenResponse object + * + * @throws JsonSyntaxException if the JSON string is not well-formed + */ + private static LicenseTokenResponse parseLicenseTokenResponse(String resp) throws JsonSyntaxException { + return new Gson().fromJson(resp, LicenseTokenResponse.class) + } + + /** + * Request a license token from Platform. + * + * @param req The LicenseTokenRequest object + * @return The LicenseTokenResponse object + * + * @throws AbortOperationException if a Platform access token cannot be found + * @throws UnauthorizedException if the access token is invalid + * @throws BadResponseException if the response is not as expected + * @throws IllegalStateException if the request cannot be sent + */ + private LicenseTokenResponse sendRequest(LicenseTokenRequest req) throws AbortOperationException, UnauthorizedException, BadResponseException, IllegalStateException { + + final httpReq = makeHttpRequest(req) + + try { + final resp = safeHttpSend(httpReq, retryPolicy) + + if( resp.statusCode() == 200 ) { + final ret = parseLicenseTokenResponse(resp.body()) + return ret + } + + if( resp.statusCode() == 401 ) { + throw new UnauthorizedException("Unauthorized [401] - Verify you have provided a valid access token") + } + + throw new BadResponseException("Invalid response: ${httpReq.method()} ${httpReq.uri()} [${resp.statusCode()}] ${resp.body()}") + } catch (IOException e) { + throw new IllegalStateException("Unable to send request to '${httpReq.uri()}' : ${e.message}") + } + } +} diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/exception/BadResponseException.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exception/BadResponseException.groovy new file mode 100644 index 0000000000..4de9a30882 --- /dev/null +++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exception/BadResponseException.groovy @@ -0,0 +1,7 @@ +package io.seqera.tower.plugin.exception + +import groovy.transform.InheritConstructors + +@InheritConstructors +class BadResponseException extends RuntimeException{ +} diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/exception/UnauthorizedException.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exception/UnauthorizedException.groovy new file mode 100644 index 0000000000..6269a825b5 --- /dev/null +++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exception/UnauthorizedException.groovy @@ -0,0 +1,7 @@ +package io.seqera.tower.plugin.exception + +import groovy.transform.InheritConstructors + +@InheritConstructors +class UnauthorizedException extends RuntimeException { +} diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/LicenseTokenRequest.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/LicenseTokenRequest.groovy new file mode 100644 index 0000000000..831b6641a8 --- /dev/null +++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/LicenseTokenRequest.groovy @@ -0,0 +1,22 @@ +package io.seqera.tower.plugin.exchange + +import groovy.transform.CompileStatic +import groovy.transform.EqualsAndHashCode +import groovy.transform.ToString + +/** + * Models a REST request to obtain a license-scoped JWT token from Platform + * + * @author Alberto Miranda + */ +@EqualsAndHashCode +@ToString(includeNames = true, includePackage = false) +@CompileStatic +class LicenseTokenRequest { + + /** The product code */ + String product + + /** The product version */ + String version +} diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/LicenseTokenResponse.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/LicenseTokenResponse.groovy new file mode 100644 index 0000000000..a7c41454e9 --- /dev/null +++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/exchange/LicenseTokenResponse.groovy @@ -0,0 +1,23 @@ +package io.seqera.tower.plugin.exchange + +import groovy.transform.CompileStatic +import groovy.transform.ToString + +/** + * Models a REST response containing a license-scoped JWT token from Platform + * + * @author Alberto Miranda + */ +@CompileStatic +@ToString(includeNames = true, includePackage = false) +class LicenseTokenResponse { + /** + * The signed JWT token + */ + String signedToken + + /** + * The expiration date of the token + */ + Date expirationDate +} diff --git a/plugins/nf-tower/src/resources/META-INF/extensions.idx b/plugins/nf-tower/src/resources/META-INF/extensions.idx index 68286228bc..8a61ce9a9c 100644 --- a/plugins/nf-tower/src/resources/META-INF/extensions.idx +++ b/plugins/nf-tower/src/resources/META-INF/extensions.idx @@ -10,3 +10,4 @@ # io.seqera.tower.plugin.TowerFactory +io.seqera.tower.plugin.TowerFusionEnv diff --git a/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerFusionEnvTest.groovy b/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerFusionEnvTest.groovy new file mode 100644 index 0000000000..a7e47e683b --- /dev/null +++ b/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerFusionEnvTest.groovy @@ -0,0 +1,425 @@ +package io.seqera.tower.plugin + +import com.github.tomakehurst.wiremock.WireMockServer +import com.github.tomakehurst.wiremock.client.WireMock +import groovy.json.JsonOutput +import io.seqera.tower.plugin.exception.UnauthorizedException +import nextflow.Global +import nextflow.Session +import nextflow.SysEnv +import nextflow.exception.AbortOperationException +import nextflow.fusion.FusionConfig +import spock.lang.Shared +import spock.lang.Specification + +import java.time.temporal.ChronoUnit + +/** + * Test cases for the TowerFusionEnv class. + * + * @author Alberto Miranda + */ +class TowerFusionEnvTest extends Specification { + + @Shared + WireMockServer wireMockServer + + def setupSpec() { + wireMockServer = new WireMockServer(18080) + wireMockServer.start() + } + + def cleanupSpec() { + wireMockServer.stop() + } + + def setup() { + wireMockServer.resetAll() + SysEnv.push([:]) // <-- ensure the system host env does not interfere + } + + def cleanup() { + SysEnv.pop() // <-- restore the system host env + } + + + def 'should return the endpoint from the config'() { + given: 'a session' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint: 'https://tower.nf' + ] + ] + } + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the endpoint has the expected value' + provider.endpoint == 'https://tower.nf' + } + + def 'should return the endpoint from the environment'() { + setup: + SysEnv.push(['TOWER_API_ENDPOINT': 'https://tower.nf']) + Global.session = Mock(Session) { + config >> [:] + } + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the endpoint has the expected value' + provider.endpoint == 'https://tower.nf' + + cleanup: + SysEnv.pop() + } + + def 'should return the default endpoint'() { + when: 'session config is empty' + Global.session = Mock(Session) { + config >> [ + tower: [:] + ] + } + def provider = new TowerFusionEnv() + + then: 'the endpoint has the expected value' + provider.endpoint == TowerClient.DEF_ENDPOINT_URL + + when: 'session config is null' + Global.session = Mock(Session) { + config >> null + } + + then: 'the endpoint has the expected value' + provider.endpoint == TowerClient.DEF_ENDPOINT_URL + + when: 'session config is missing' + Global.session = Mock(Session) { + config >> [:] + } + + then: 'the endpoint has the expected value' + provider.endpoint == TowerClient.DEF_ENDPOINT_URL + + when: 'session.config.tower.endpoint is not defined' + Global.session = Mock(Session) { + config >> [ + tower: [:] + ] + } + + then: 'the endpoint has the expected value' + provider.endpoint == TowerClient.DEF_ENDPOINT_URL + + when: 'session.config.tower.endpoint is null' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint: null + ] + ] + } + + then: 'the endpoint has the expected value' + + when: 'session.config.tower.endpoint is empty' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint: '' + ] + ] + } + + then: 'the endpoint has the expected value' + + when: 'session.config.tower.endpoint is defined as "-"' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint: '-' + ] + ] + } + + then: 'the endpoint has the expected value' + } + + def 'should return the access token from the config'() { + given: 'a session' + Global.session = Mock(Session) { + config >> [ + tower: [ + accessToken: 'abc123' + ] + ] + } + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the access token has the expected value' + provider.accessToken == 'abc123' + } + + def 'should return the access token from the environment'() { + setup: + Global.session = Mock(Session) { + config >> [:] + } + SysEnv.push(['TOWER_ACCESS_TOKEN': 'abc123']) + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the access token has the expected value' + provider.accessToken == 'abc123' + + cleanup: + SysEnv.pop() + } + + def 'should prefer the access token from the config'() { + setup: + Global.session = Mock(Session) { + config >> [ + tower: [ + accessToken: 'abc123' + ] + ] + } + SysEnv.push(['TOWER_ACCESS_TOKEN': 'xyz789']) + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the access token has the expected value' + provider.accessToken == 'abc123' + + cleanup: + SysEnv.pop() + } + + def 'should prefer the access token from the config despite being null'() { + setup: + Global.session = Mock(Session) { + config >> [ + tower: [ + accessToken: null + ] + ] + } + SysEnv.push(['TOWER_ACCESS_TOKEN': 'xyz789']) + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the access token has the expected value' + provider.accessToken == null + + cleanup: + SysEnv.pop() + } + + def 'should prefer the access token from the environment if TOWER_WORKFLOW_ID is set'() { + setup: + Global.session = Mock(Session) { + config >> [ + tower: [ + accessToken: 'abc123' + ] + ] + } + SysEnv.push(['TOWER_ACCESS_TOKEN' : 'xyz789', 'TOWER_WORKFLOW_ID': '123']) + + when: 'the provider is created' + def provider = new TowerFusionEnv() + + then: 'the access token has the expected value' + provider.accessToken == 'xyz789' + + cleanup: + SysEnv.pop() + } + + def 'should get a license token'() { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint : 'http://localhost:18080', + accessToken: 'abc123' + ] + ] + } + def provider = new TowerFusionEnv() + + and: 'a mock endpoint returning a valid token' + final now = new Date().toInstant() + final expirationDate = JsonOutput.toJson(Date.from(now.plus(1, ChronoUnit.DAYS))) + wireMockServer.stubFor( + WireMock.post(WireMock.urlEqualTo("/license/token/")) + .withHeader('Authorization', WireMock.equalTo('Bearer abc123')) + .willReturn( + WireMock.aResponse() + .withStatus(200) + .withHeader('Content-Type', 'application/json') + .withBody('{"signedToken":"xyz789", "expirationDate":' + expirationDate + '}') + ) + ) + + when: 'a license token is requested' + final token = provider.getLicenseToken(PRODUCT, VERSION) + + then: 'the token has the expected value' + token == 'xyz789' + + and: 'the request is correct' + wireMockServer.verify(1, WireMock.postRequestedFor(WireMock.urlEqualTo("/license/token/")) + .withHeader('Authorization', WireMock.equalTo('Bearer abc123'))) + + where: + PRODUCT | VERSION + 'some-product' | 'some-version' + 'some-product' | null + null | 'some-version' + null | null + } + + def 'should fail getting a token if the Platform configuration is missing'() { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [:] + } + def provider = new TowerFusionEnv() + + when: 'a license token is requested' + provider.getLicenseToken('some-product', 'some-version') + + then: 'an exception is thrown' + final ex = thrown(AbortOperationException) + ex.message == 'Missing personal access token -- Make sure there\'s a variable TOWER_ACCESS_TOKEN in your environment' + } + + def 'should fail getting a token if the Platform configuration is empty'() { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [ + tower: [:] + ] + } + def provider = new TowerFusionEnv() + + when: 'a license token is requested' + provider.getLicenseToken('some-product', 'some-version') + + then: 'an exception is thrown' + final ex = thrown(AbortOperationException) + ex.message == 'Missing personal access token -- Make sure there\'s a variable TOWER_ACCESS_TOKEN in your environment' + } + + def 'should fail getting a token if the Platform access token is missing'() { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint: 'http://localhost:18080' + ] + ] + } + def provider = new TowerFusionEnv() + + when: 'a license token is requested' + provider.getLicenseToken('some-product', 'some-version') + + then: 'an exception is thrown' + final ex = thrown(AbortOperationException) + ex.message == 'Missing personal access token -- Make sure there\'s a variable TOWER_ACCESS_TOKEN in your environment' + } + + def 'should throw UnauthorizedException if getting a token fails with 401'() { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint : 'http://localhost:18080', + accessToken: 'abc123' + ] + ] + } + def provider = new TowerFusionEnv() + + and: 'a mock endpoint returning an error' + wireMockServer.stubFor( + WireMock.post(WireMock.urlEqualTo("/license/token/")) + .withHeader('Authorization', WireMock.equalTo('Bearer abc123')) + .willReturn( + WireMock.aResponse() + .withStatus(401) + .withHeader('Content-Type', 'application/json') + .withBody('{"error":"Unauthorized"}') + ) + ) + + when: 'a license token is requested' + provider.getLicenseToken('some-product', 'some-version') + + then: 'an exception is thrown' + thrown(UnauthorizedException) + } + + def 'should return a valid environment' () { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [:] + } + def provider = Spy(TowerFusionEnv) + + when: 'the environment is requested' + def env = provider.getEnvironment('s3', Mock(FusionConfig)) + + then: 'the environment has the expected values' + 1 * provider.getLicenseToken(_, _) >> 'xyz789' + env == [FUSION_LICENSE_TOKEN: 'xyz789'] + } + + def 'should return an empty environment if no Platform config is available' () { + given: 'a session with no config for Platform' + Global.session = Mock(Session) { + config >> [:] + } + + when: 'the environment is requested' + def provider = new TowerFusionEnv() + def env = provider.getEnvironment('-', Mock(FusionConfig)) + + then: 'the environment is empty' + env == [:] + } + + def 'should return an empty environment if the license token cannot be obtained' () { + given: 'a TowerFusionEnv provider' + Global.session = Mock(Session) { + config >> [ + tower: [ + endpoint : 'http://localhost:18080', + accessToken: 'abc123' + ] + ] + } + def provider = Spy(TowerFusionEnv) + + when: 'the environment is requested' + def env = provider.getEnvironment('s3', Mock(FusionConfig)) + + then: 'the environment has the expected values' + 1 * provider.getLicenseToken(_, _) >> { + throw new Exception('error') + } + env == [:] + } +} diff --git a/plugins/nf-wave/src/main/io/seqera/wave/plugin/config/TowerConfig.groovy b/plugins/nf-wave/src/main/io/seqera/wave/plugin/config/TowerConfig.groovy index 5a0877039f..ee8a126857 100644 --- a/plugins/nf-wave/src/main/io/seqera/wave/plugin/config/TowerConfig.groovy +++ b/plugins/nf-wave/src/main/io/seqera/wave/plugin/config/TowerConfig.groovy @@ -19,6 +19,7 @@ package io.seqera.wave.plugin.config import groovy.transform.CompileStatic import groovy.transform.ToString +import nextflow.platform.PlatformHelper /** * Model Tower config accessed by Wave @@ -40,51 +41,10 @@ class TowerConfig { final String workflowId TowerConfig(Map opts, Map env) { - this.accessToken = accessToken0(opts, env) - this.refreshToken = refreshToken0(opts, env) - this.workspaceId = workspaceId0(opts, env) as Long - this.endpoint = endpoint0(opts, env) + this.accessToken = PlatformHelper.getAccessToken(opts, env) + this.refreshToken = PlatformHelper.getRefreshToken(opts, env) + this.workspaceId = PlatformHelper.getWorkspaceId(opts, env) as Long + this.endpoint = PlatformHelper.getEndpoint(opts, env) this.workflowId = env.get('TOWER_WORKFLOW_ID') } - - private String endpoint0(Map opts, Map env) { - def result = opts.endpoint as String - if( !result || result=='-' ) - result = env.get('TOWER_API_ENDPOINT') ?: 'https://api.cloud.seqera.io' - return result.stripEnd('/') - } - - private String accessToken0(Map opts, Map env) { - // when 'TOWER_WORKFLOW_ID' is provided in the env, it's a tower made launch - // therefore the access token should only be taken from the env - // otherwise check into the config file and fallback in the env - // see also - // https://github.com/nextflow-io/nextflow/blob/master/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy#L369-L377 - def token = env.get('TOWER_WORKFLOW_ID') - ? env.get('TOWER_ACCESS_TOKEN') - : opts.containsKey('accessToken') ? opts.accessToken as String : env.get('TOWER_ACCESS_TOKEN') - return token - } - - private String refreshToken0(Map opts, Map env) { - // when 'TOWER_WORKFLOW_ID' is provided in the env, it's a tower made launch - // therefore the access token should only be taken from the env - // otherwise check into the config file and fallback in the env - // see also - // https://github.com/nextflow-io/nextflow/blob/master/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy#L369-L377 - def token = env.get('TOWER_WORKFLOW_ID') - ? env.get('TOWER_REFRESH_TOKEN') - : opts.containsKey('refreshToken') ? opts.refreshToken as String : env.get('TOWER_REFRESH_TOKEN') - return token - } - - private String workspaceId0(Map opts, Map env) { - // when 'TOWER_WORKFLOW_ID' is provided in the env, it's a tower made launch - // therefore the workspace should only be taken from the env - // otherwise check into the config file and fallback in the env - def workspaceId = env.get('TOWER_WORKFLOW_ID') - ? env.get('TOWER_WORKSPACE_ID') - : opts.workspaceId as Long ?: env.get('TOWER_WORKSPACE_ID') as Long - return workspaceId - } }