diff --git a/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala b/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala
index bdeb495..7205ebf 100644
--- a/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala
+++ b/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala
@@ -9,7 +9,9 @@ import java.nio.channels.Channels
 class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream {
   private val decompressCtx = new ZstdDecompressCtx()
   private val outputChannel = Channels.newChannel(outputStream)
-  private val closed = false
+  private val outputBuffer = ByteBuffer.allocateDirect(ZstdInputStream.recommendedDOutSize().toInt)
+
+  private var closed = false
 
   override def write(chunk: Array[Byte]): Unit = {
     if (closed) {
@@ -17,12 +19,12 @@ class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream
     }
 
     val inputBuffer = ByteBuffer.allocateDirect(chunk.length)
-    val outputBuffer = ByteBuffer.allocateDirect(ZstdInputStream.recommendedDOutSize().toInt)
-
     inputBuffer.put(chunk)
-    inputBuffer.flip()
+    inputBuffer.rewind()
 
     while (inputBuffer.hasRemaining) {
+      outputBuffer.clear()
+
       decompressCtx.decompressDirectByteBufferStream(outputBuffer, inputBuffer)
 
       outputBuffer.flip()
@@ -30,8 +32,6 @@ class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream
       while (outputBuffer.hasRemaining) {
         outputChannel.write(outputBuffer)
       }
-
-      outputBuffer.clear()
     }
   }
 
@@ -45,7 +45,9 @@ class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream
 
   override def close(): Unit = {
     if (!closed) {
+      closed = true
       decompressCtx.close()
+      outputBuffer.clear()
       outputChannel.close()
     }
   }
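Editor's sketch (not part of the patch): the hunks above hoist outputBuffer into a field so one direct buffer is reused across write calls, cleared before each decompressDirectByteBufferStream round instead of being allocated per chunk. A minimal round-trip exercising that streaming path, assuming zstd-jni on the classpath (it is added as a test dependency below; the object name is illustrative):

import java.io.ByteArrayOutputStream

import com.avast.clients.storage.compression.ZstdDecompressOutputStream
import com.github.luben.zstd.Zstd

object ZstdRoundTrip extends App {
  val original = Array.tabulate[Byte](1 << 20)(i => (i % 251).toByte)
  val compressed = Zstd.compress(original, 9)

  val sink = new ByteArrayOutputStream()
  val decompressor = new ZstdDecompressOutputStream(sink)
  try {
    // Feed the compressed bytes in 8 KiB chunks, as a streamed download would.
    compressed.grouped(8 * 1024).foreach(decompressor.write)
  } finally {
    decompressor.close()
  }

  assert(sink.toByteArray.sameElements(original))
}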
diff --git a/gcs/build.gradle b/gcs/build.gradle
index beb00a6..806a73d 100644
--- a/gcs/build.gradle
+++ b/gcs/build.gradle
@@ -6,4 +6,5 @@ dependencies {
     implementation "com.google.cloud:google-cloud-storage:$gcsVersion"
 
     testImplementation "io.monix:monix_2.13:$monixVersion"
+    testImplementation "com.github.luben:zstd-jni:$zstdVersion"
 }
\ No newline at end of file
diff --git a/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala b/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala
index c71529f..03c7e18 100644
--- a/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala
+++ b/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala
@@ -87,24 +87,11 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
   }
 
   private def receiveStreamedFile(blob: Blob, destination: File, expectedHash: Sha256): F[Either[StorageException, GetResult]] = {
-    def getCompressionType: Option[String] = {
-      Option(blob.getMetadata.get(GcsStorageBackend.CompressionTypeHeader)).map(_.toLowerCase)
-    }
-
     Sync[F].delay(logger.debug(s"Downloading streamed data to $destination")) >>
       blocker
        .delay(destination.newOutputStream(FileStreamOpenOptions))
        .bracket { fileStream =>
-          getCompressionType match {
-            case None =>
-              receiveRawStream(blob, fileStream)
-            case Some("zstd") =>
-              receiveZstdStream(blob, fileStream)
-            case Some(unknownCompressionType) =>
-              Sync[F].raiseError[(Long, Sha256)] {
-                new NotImplementedError(s"Unknown compression type $unknownCompressionType")
-              }
-          }
+          downloadBlobToFile(blob, fileStream)
        }(fileStream => blocker.delay(fileStream.close()))
        .map[Either[StorageException, GetResult]] {
          case (size, hash) =>
@@ -120,26 +107,53 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
     }
   }
 
-  private def receiveRawStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = {
+  private def downloadBlobToFile(blob: Blob, fileStream: OutputStream): F[(Long, Sha256)] = {
+    def getCompressionType: Option[String] = {
+      Option(blob.getMetadata.get(GcsStorageBackend.CompressionTypeHeader)).map(_.toLowerCase)
+    }
+
     Sync[F]
-      .delay(new DigestOutputStream(targetStream, MessageDigest.getInstance("SHA-256")))
-      .bracket { hashingStream =>
-        val countingStream = new GcsStorageBackend.CountingOutputStream(hashingStream)
-        blocker.delay(blob.downloadTo(countingStream)).flatMap { _ =>
+      .delay {
+        val countingStream = new GcsStorageBackend.CountingOutputStream(fileStream)
+        val hashingStream = new DigestOutputStream(countingStream, MessageDigest.getInstance("SHA-256"))
+        (countingStream, hashingStream)
+      }
+      .bracket {
+        case (countingStream, hashingStream) => {
+          getCompressionType match {
+            case None =>
+              downloadBlobToStream(blob, hashingStream)
+            case Some("zstd") =>
+              decodeZstdStream(blob, hashingStream)
+            case Some(unknown) =>
+              throw new IllegalArgumentException(s"Unknown compression type $unknown")
+          }
+        }.flatMap { _ =>
          Sync[F].delay {
            (countingStream.length, Sha256(hashingStream.getMessageDigest.digest))
          }
        }
-      }(hashingStream => blocker.delay(hashingStream.close()))
+      } {
+        case (countingStream, hashingStream) =>
+          Sync[F].delay {
+            hashingStream.close()
+            countingStream.close()
+          }
+      }
   }
 
-  private def receiveZstdStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = {
+  private def decodeZstdStream(blob: Blob, targetStream: DigestOutputStream): F[Unit] = {
     Sync[F]
      .delay(new ZstdDecompressOutputStream(targetStream))
      .bracket { decompressionStream =>
-        receiveRawStream(blob, decompressionStream)
-      }(hashingStream => blocker.delay(hashingStream.close()))
+        downloadBlobToStream(blob, decompressionStream)
+      }(decompressionStream => Sync[F].delay(decompressionStream.close()))
   }
+
+  private def downloadBlobToStream(blob: Blob, targetStream: OutputStream): F[Unit] = {
+    blocker.delay(blob.downloadTo(targetStream))
+  }
+
   override def close(): Unit = {
     ()
   }
@@ -148,8 +162,8 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
 object GcsStorageBackend {
   private val DefaultConfig = ConfigFactory.defaultReference().getConfig("gcsBackendDefaults")
 
-  private val CompressionTypeHeader = "comp-type"
-  private val OriginalSizeHeader = "original-size"
+  private[gcs] val CompressionTypeHeader = "comp-type"
+  private[gcs] val OriginalSizeHeader = "original-size"
 
   def fromConfig[F[_]: Sync: ContextShift](config: Config,
                                            blocker: Blocker): Either[ConfigurationException, Resource[F, GcsStorageBackend[F]]] = {
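Editor's sketch (not part of the patch): downloadBlobToFile now builds the whole sink up front and brackets the pair, so the counting and digest streams are closed even when the download fails; the zstd decoder, when present, sits in front of the digest, so the returned length and SHA-256 always describe the decompressed bytes written to the file. A simplified synchronous equivalent, with a stand-in CountingOutputStream (the real one is GcsStorageBackend.CountingOutputStream and the real code runs in F):

import java.io.{FilterOutputStream, OutputStream}
import java.security.{DigestOutputStream, MessageDigest}

import com.avast.clients.storage.compression.ZstdDecompressOutputStream

object SinkSketch {
  // Stand-in for GcsStorageBackend.CountingOutputStream: counts bytes written through it.
  class CountingOutputStream(out: OutputStream) extends FilterOutputStream(out) {
    private var count = 0L
    def length: Long = count
    override def write(b: Int): Unit = { out.write(b); count += 1 }
    override def write(b: Array[Byte], off: Int, len: Int): Unit = { out.write(b, off, len); count += len }
  }

  // Bytes flow download -> [zstd decode] -> digest -> count -> file, so the
  // length and SHA-256 describe the decompressed content that lands on disk.
  def buildSink(fileStream: OutputStream, compressionType: Option[String]): (CountingOutputStream, DigestOutputStream, OutputStream) = {
    val counting = new CountingOutputStream(fileStream)
    val hashing = new DigestOutputStream(counting, MessageDigest.getInstance("SHA-256"))
    val downloadTarget: OutputStream = compressionType match {
      case None         => hashing // raw blob: digest the stored bytes directly
      case Some("zstd") => new ZstdDecompressOutputStream(hashing) // decode before digesting
      case Some(other)  => throw new IllegalArgumentException(s"Unknown compression type $other")
    }
    (counting, hashing, downloadTarget) // blob.downloadTo writes into downloadTarget
  }
}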
diff --git a/gcs/src/test/scala/com/avast/clients/storage/gcs/GcsStorageBackendTest.scala b/gcs/src/test/scala/com/avast/clients/storage/gcs/GcsStorageBackendTest.scala
index 75ff7dc..b93d3ce 100644
--- a/gcs/src/test/scala/com/avast/clients/storage/gcs/GcsStorageBackendTest.scala
+++ b/gcs/src/test/scala/com/avast/clients/storage/gcs/GcsStorageBackendTest.scala
@@ -2,9 +2,10 @@ package com.avast.clients.storage.gcs
 
 import better.files.File
 import cats.effect.Blocker
-import com.avast.clients.storage.gcs.TestImplicits.{StringOps, randomString}
+import com.avast.clients.storage.gcs.TestImplicits._
 import com.avast.clients.storage.{GetResult, HeadResult}
 import com.avast.scala.hashes.Sha256
+import com.github.luben.zstd.Zstd
 import com.google.cloud.storage.{Blob, BlobId, Storage}
 import monix.eval.Task
 import monix.execution.Scheduler.Implicits.global
@@ -18,6 +19,7 @@ import org.scalatestplus.mockito.MockitoSugar
 
 import java.io.OutputStream
 import scala.concurrent.duration._
+import scala.jdk.CollectionConverters.MapHasAsJava
 
 @RunWith(classOf[JUnitRunner])
 class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar {
@@ -52,6 +54,44 @@ class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar
     assertResult(Right(HeadResult.Exists(fileSize)))(result)
   }
 
+  test("head-zstd") {
+    val fileSize = 1001100
+    val originalContent = randomString(fileSize).getBytes()
+    val compressedContent = Zstd.compress(originalContent, 9)
+    val sha = originalContent.sha256
+    val shaStr = sha.toString()
+    val bucketName = "bucket-tst"
+
+    val blob = mock[Blob]
+    when(blob.getSize).thenReturn(compressedContent.length.toLong)
+    when(blob.getMetadata).thenReturn {
+      Map(
+        GcsStorageBackend.CompressionTypeHeader -> "zstd",
+        GcsStorageBackend.OriginalSizeHeader -> originalContent.length.toString
+      ).asJava
+    }
+
+    val storageClient = mock[Storage]
+    when(storageClient.get(any[BlobId]())).thenAnswer { call =>
+      val blobId = call.getArgument[BlobId](0)
+      val blobPath = blobId.getName
+      assertResult(bucketName)(blobId.getBucket)
+      assertResult {
+        List(
+          shaStr.substring(0, 2),
+          shaStr.substring(2, 4),
+          shaStr.substring(4, 6),
+          shaStr,
+        )
+      }(blobPath.split("/").toList)
+      blob
+    }
+
+    val result = composeTestBackend(storageClient, bucketName).head(sha).runSyncUnsafe(10.seconds)
+
+    assertResult(Right(HeadResult.Exists(fileSize)))(result)
+  }
+
   test("get") {
     val fileSize = 1001200
     val content = randomString(fileSize)
@@ -60,7 +100,6 @@ class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar
     val bucketName = "bucket-tst"
 
     val blob = mock[Blob]
-    when(blob.getSize).thenReturn(fileSize.toLong)
     when(blob.downloadTo(any[OutputStream]())).thenAnswer { call =>
       val outputStream = call.getArgument[OutputStream](0)
      outputStream.write(content.getBytes())
@@ -90,6 +129,47 @@ class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar
     }
   }
 
+  test("get-zstd") {
+    val fileSize = 1024 * 1024
+    val originalContent = randomString(fileSize).getBytes()
+    val compressedContent = Zstd.compress(originalContent, 9)
+    val sha = originalContent.sha256
+    val shaStr = sha.toString()
+    val bucketName = "bucket-tst"
+
+    val blob = mock[Blob]
+    when(blob.getMetadata).thenReturn {
+      Map(GcsStorageBackend.CompressionTypeHeader -> "zstd").asJava
+    }
+    when(blob.downloadTo(any[OutputStream]())).thenAnswer { call =>
+      val outputStream = call.getArgument[OutputStream](0)
+      outputStream.write(compressedContent)
+    }
+
+    val storageClient = mock[Storage]
+    when(storageClient.get(any[BlobId]())).thenAnswer { call =>
+      val blobId = call.getArgument[BlobId](0)
+      val blobPath = blobId.getName
+      assertResult(bucketName)(blobId.getBucket)
+      assertResult {
+        List(
+          shaStr.substring(0, 2),
+          shaStr.substring(2, 4),
+          shaStr.substring(4, 6),
+          shaStr,
+        )
+      }(blobPath.split("/").toList)
+      blob
+    }
+
+    File.usingTemporaryFile() { file =>
+      val result = composeTestBackend(storageClient, bucketName).get(sha, file).runSyncUnsafe(10.seconds)
+      assertResult(Right(GetResult.Downloaded(file, fileSize)))(result)
+      assertResult(sha.toString.toLowerCase)(file.sha256.toLowerCase)
+      assertResult(fileSize)(file.size)
+    }
+  }
+
   test("composeObjectPath") {
     val sha = Sha256("d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831")
     assertResult("d0/5a/f9/d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831")(GcsStorageBackend.composeBlobPath(sha))
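Editor's note (not part of the patch): head-zstd expects head to report the original size taken from the original-size metadata (blob.getSize returns the compressed length), while get-zstd expects the downloaded file to match the pre-compression content and hash. Both tests also pin the blob path sharding that the composeObjectPath test asserts; a standalone mirror of that scheme (the helper name is hypothetical, the real composeBlobPath takes a Sha256):

// Standalone mirror of the path scheme: three two-character shards from the
// leading hex of the SHA-256, then the full hash.
def composeBlobPath(shaHex: String): String =
  List(shaHex.substring(0, 2), shaHex.substring(2, 4), shaHex.substring(4, 6), shaHex).mkString("/")

composeBlobPath("d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831")
// "d0/5a/f9/d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831"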
assertResult("d0/5a/f9/d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831")(GcsStorageBackend.composeBlobPath(sha)) diff --git a/gcs/src/test/scala/com/avast/clients/storage/gcs/TestImplicits.scala b/gcs/src/test/scala/com/avast/clients/storage/gcs/TestImplicits.scala index 0c1579b..7b9bd2d 100644 --- a/gcs/src/test/scala/com/avast/clients/storage/gcs/TestImplicits.scala +++ b/gcs/src/test/scala/com/avast/clients/storage/gcs/TestImplicits.scala @@ -15,4 +15,11 @@ object TestImplicits { Sha256(digest.digest(s.getBytes)) } } + + implicit class BytesOps(val bytes: Array[Byte]) extends AnyVal { + def sha256: Sha256 = { + val digest = MessageDigest.getInstance("SHA-256") + Sha256(digest.digest(bytes)) + } + } }