diff --git a/src/main/java/io/kestra/plugin/gcp/runner/Batch.java b/src/main/java/io/kestra/plugin/gcp/runner/Batch.java index 49b7a6fc..fb9d977a 100644 --- a/src/main/java/io/kestra/plugin/gcp/runner/Batch.java +++ b/src/main/java/io/kestra/plugin/gcp/runner/Batch.java @@ -22,10 +22,7 @@ import lombok.*; import lombok.experimental.SuperBuilder; -import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.file.Path; import java.time.*; import java.util.*; @@ -242,35 +239,16 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List filesToUploadWithOutputDir = new ArrayList<>(filesToUpload); - if (outputDirectoryEnabled) { - String outputDirName = (batchWorkingDirectory.relativize((Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR)) + "/").substring(1); - filesToUploadWithOutputDir.add(outputDirName); - } - try (Storage storage = storage(runContext, credentials)) { - for (String relativePath : filesToUploadWithOutputDir) { - BlobInfo destination = BlobInfo.newBuilder(BlobId.of( - renderedBucket, - workingDirectoryToBlobPath + Path.of("/" + relativePath) - )).build(); - Path filePath = runContext.resolve(Path.of(relativePath)); - if (relativePath.endsWith("/")) { - storage.create(destination); - continue; - } - - try (var fileInputStream = new FileInputStream(filePath.toFile()); - var writer = storage.writer(destination)) { - byte[] buffer = new byte[BUFFER_SIZE]; - int limit; - while ((limit = fileInputStream.read(buffer)) >= 0) { - writer.write(ByteBuffer.wrap(buffer, 0, limit)); - } - } - } - } + GcsUtils.of(projectId, credentials).uploadFiles(runContext, + filesToUpload, + renderedBucket, + batchWorkingDirectory, + outputDirectory, + outputDirectoryEnabled + ); } var taskBuilder = TaskSpec.newBuilder(); @@ -287,7 +265,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List= 0) { - fileOutputStream.write(buffer, 0, limit); - } - } - } - - if (outputDirectoryEnabled) { - Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR); - Page outputDirEntries = storage.list(renderedBucket, Storage.BlobListOption.prefix(batchOutputDirectory.toString().substring(1))); - outputDirEntries.iterateAll().forEach(blob -> { - Path relativeBlobPathFromOutputDir = Path.of(batchOutputDirectory.toString().substring(1)).relativize(Path.of(blob.getBlobId().getName())); - storage.downloadTo( - blob.getBlobId(), - taskCommands.getOutputDirectory().resolve(relativeBlobPathFromOutputDir) - ); - }); - } - } + GcsUtils.of(projectId, credentials).downloadFile( + runContext, + taskCommands, + filesToDownload, + renderedBucket, + batchWorkingDirectory, + outputDirectory, + outputDirectoryEnabled + ); } return new RunnerResult(0, taskCommands.getLogConsumer()); } } finally { if (hasBucket && delete) { - try (Storage storage = storage(runContext, credentials)) { + try (Storage storage = GcsUtils.of(projectId, credentials).storage(runContext)) { Page list = storage.list(renderedBucket, Storage.BlobListOption.prefix(workingDirectoryToBlobPath)); list.iterateAll().forEach(blob -> storage.delete(blob.getBlobId())); storage.delete(BlobInfo.newBuilder(BlobId.of(renderedBucket, workingDirectoryToBlobPath)).build().getBlobId()); @@ -464,18 +423,6 @@ private boolean isTerminated(JobStatus.State state) { return state == JobStatus.State.SUCCEEDED || state == JobStatus.State.DELETION_IN_PROGRESS || isFailed(state); } - private Storage storage(RunContext runContext, GoogleCredentials credentials) throws IllegalVariableEvaluationException { - VersionProvider versionProvider = runContext.getApplicationContext().getBean(VersionProvider.class); - - return StorageOptions - .newBuilder() - .setCredentials(credentials) - .setProjectId(runContext.render(projectId)) - .setHeaderProvider(() -> Map.of("user-agent", "Kestra/" + versionProvider.getVersion())) - .build() - .getService(); - } - @Override protected Map runnerAdditionalVars(RunContext runContext, TaskCommands taskCommands) throws IllegalVariableEvaluationException { Map additionalVars = new HashMap<>(); diff --git a/src/main/java/io/kestra/plugin/gcp/runner/GcsUtils.java b/src/main/java/io/kestra/plugin/gcp/runner/GcsUtils.java index 7f257fcb..afc3d5d5 100644 --- a/src/main/java/io/kestra/plugin/gcp/runner/GcsUtils.java +++ b/src/main/java/io/kestra/plugin/gcp/runner/GcsUtils.java @@ -47,17 +47,19 @@ public void downloadFile(RunContext runContext, Path outputDirectory, boolean outputDirectoryEnabled) throws Exception { try (Storage storage = storage(runContext)) { - for (String relativePath : filesToDownload) { - BlobInfo source = BlobInfo.newBuilder(BlobId.of( - bucket, - removeLeadingSlash(workingDirectory.toString()) + Path.of("/" + relativePath) - )).build(); - try (var fileOutputStream = new FileOutputStream(runContext.resolve(Path.of(relativePath)).toFile()); - var reader = storage.reader(source.getBlobId())) { - byte[] buffer = new byte[BUFFER_SIZE]; - int limit; - while ((limit = reader.read(ByteBuffer.wrap(buffer))) >= 0) { - fileOutputStream.write(buffer, 0, limit); + if (filesToDownload != null) { + for (String relativePath : filesToDownload) { + BlobInfo source = BlobInfo.newBuilder(BlobId.of( + bucket, + removeLeadingSlash(workingDirectory.toString()) + Path.of("/" + relativePath) + )).build(); + try (var fileOutputStream = new FileOutputStream(runContext.resolve(Path.of(relativePath)).toFile()); + var reader = storage.reader(source.getBlobId())) { + byte[] buffer = new byte[BUFFER_SIZE]; + int limit; + while ((limit = reader.read(ByteBuffer.wrap(buffer))) >= 0) { + fileOutputStream.write(buffer, 0, limit); + } } } } @@ -70,7 +72,9 @@ public void downloadFile(RunContext runContext, BlobId blobId = blob.getBlobId(); if (!blobId.getName().endsWith("/")) { Path relativeBlobPathFromOutputDir = outputDirPath.relativize(Path.of(blobId.getName())); - storage.downloadTo(blobId, taskCommands.getOutputDirectory().resolve(relativeBlobPathFromOutputDir)); + Path outputFile = taskCommands.getOutputDirectory().resolve(relativeBlobPathFromOutputDir); + outputFile.getParent().toFile().mkdirs(); + storage.downloadTo(blobId, outputFile); } }); } diff --git a/src/test/java/io/kestra/plugin/gcp/runner/BatchTest.java b/src/test/java/io/kestra/plugin/gcp/runner/BatchTest.java index fb483723..da4e941e 100644 --- a/src/test/java/io/kestra/plugin/gcp/runner/BatchTest.java +++ b/src/test/java/io/kestra/plugin/gcp/runner/BatchTest.java @@ -34,4 +34,9 @@ protected TaskRunner taskRunner() { .completionCheckInterval(Duration.ofMillis(100)) .build(); } + + @Override + protected boolean needsToSpecifyWorkingDirectory() { + return true; + } } \ No newline at end of file