diff --git a/build.gradle b/build.gradle index bb433c4..f9becf8 100644 --- a/build.gradle +++ b/build.gradle @@ -14,6 +14,7 @@ ext { metricsVersion = "3.0.2" http4sVersion = "0.22.15" gcsVersion = "2.30.1" + zstdVersion = "1.5.5-11" monixVersion = "3.4.1" // Used only in tests. } diff --git a/core/build.gradle b/core/build.gradle index 6cffc5c..2e6ecf1 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -13,5 +13,7 @@ dependencies { api "com.avast.metrics:metrics-scala_2.13:$metricsVersion" + implementation "com.github.luben:zstd-jni:$zstdVersion" + testImplementation "io.monix:monix_2.13:$monixVersion" } diff --git a/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala b/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala new file mode 100644 index 0000000..bdeb495 --- /dev/null +++ b/core/src/main/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStream.scala @@ -0,0 +1,52 @@ +package com.avast.clients.storage.compression + +import com.github.luben.zstd.{ZstdDecompressCtx, ZstdInputStream} + +import java.io.OutputStream +import java.nio.ByteBuffer +import java.nio.channels.Channels + +class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream { + private val decompressCtx = new ZstdDecompressCtx() + private val outputChannel = Channels.newChannel(outputStream) + private val closed = false + + override def write(chunk: Array[Byte]): Unit = { + if (closed) { + throw new IllegalStateException("Stream is closed") + } + + val inputBuffer = ByteBuffer.allocateDirect(chunk.length) + val outputBuffer = ByteBuffer.allocateDirect(ZstdInputStream.recommendedDOutSize().toInt) + + inputBuffer.put(chunk) + inputBuffer.flip() + + while (inputBuffer.hasRemaining) { + decompressCtx.decompressDirectByteBufferStream(outputBuffer, inputBuffer) + + outputBuffer.flip() + + while (outputBuffer.hasRemaining) { + outputChannel.write(outputBuffer) + } + + outputBuffer.clear() + } + } + + override def write(chunk: Array[Byte], offset: Int, length: Int): Unit = { + write(chunk.slice(offset, offset + length)) + } + + override def write(b: Int): Unit = { + write(Array(b.toByte)) + } + + override def close(): Unit = { + if (!closed) { + decompressCtx.close() + outputChannel.close() + } + } +} diff --git a/core/src/test/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStreamTest.scala b/core/src/test/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStreamTest.scala new file mode 100644 index 0000000..e49a1e0 --- /dev/null +++ b/core/src/test/scala/com/avast/clients/storage/compression/ZstdDecompressOutputStreamTest.scala @@ -0,0 +1,60 @@ +package com.avast.clients.storage.compression + +import com.avast.scala.hashes.Sha256 +import com.github.luben.zstd.Zstd +import org.junit.runner.RunWith +import org.scalatest.FunSuite +import org.scalatestplus.junit.JUnitRunner + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer +import java.security.MessageDigest + +@RunWith(classOf[JUnitRunner]) +class ZstdDecompressOutputStreamTest extends FunSuite { + private def computeSha256(data: Array[Byte]): Sha256 = { + Sha256(MessageDigest.getInstance("SHA-256").digest(data)) + } + + private def generateData(size: Int): Array[Byte] = { + val builder = Array.newBuilder[Byte] + var i = 0 + while (i < size) { + builder += (i % 256).toByte + i += 1 + } + builder.result() + } + + test("decompress zstd stream") { + val chunkSize = 4 * 1024 + val testCases = Seq(0, 1, chunkSize, 10 * 1024 * 1024) + + for (testCase <- testCases) { + println(s"Test case: $testCase") + + val original_data = generateData(testCase) + val original_sha256 = computeSha256(original_data) + + val compressed_data = Zstd.compress(original_data, 9) + + val sourceStream = ByteBuffer.wrap(compressed_data) + val targetStream = new ByteArrayOutputStream() + + val decompressStream = new ZstdDecompressOutputStream(targetStream) + + while (sourceStream.hasRemaining) { + val chunkSize = math.min(sourceStream.remaining(), 4 * 1024) + val chunk = new Array[Byte](chunkSize) + sourceStream.get(chunk) + decompressStream.write(chunk) + } + + decompressStream.close() + + val result = targetStream.toByteArray + + assert(original_sha256 == computeSha256(result)) + } + } +} diff --git a/gcs/README.md b/gcs/README.md index 586ebac..e36f410 100644 --- a/gcs/README.md +++ b/gcs/README.md @@ -25,6 +25,12 @@ GCS backends supports multiple ways of authentication: * Reading credential file from default paths (see https://cloud.google.com/docs/authentication/application-default-credentials#personal) * For all options see https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to +### Object decompression + +GCS backend supports decompression of objects. The decision whether to decompress object or not is based on the `comp-type` header in the object's metadata. +If the header is present and contains `zstd` value, the object is decompressed on the fly. Otherwise the object is downloaded as is. + +The only supported compression algorithm is currently `zstd`. ### Client initialization diff --git a/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala b/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala index 6c5db95..c71529f 100644 --- a/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala +++ b/gcs/src/main/scala/com/avast/clients/storage/gcs/GcsStorageBackend.scala @@ -4,6 +4,7 @@ import better.files.File import cats.effect.implicits.catsEffectSyntaxBracket import cats.effect.{Blocker, ContextShift, Resource, Sync} import cats.syntax.all._ +import com.avast.clients.storage.compression.ZstdDecompressOutputStream import com.avast.clients.storage.gcs.GcsStorageBackend.composeBlobPath import com.avast.clients.storage.{ConfigurationException, GetResult, HeadResult, StorageBackend, StorageException} import com.avast.scala.hashes.Sha256 @@ -17,7 +18,7 @@ import pureconfig.generic.ProductHint import pureconfig.generic.auto._ import pureconfig.{CamelCase, ConfigFieldMapping} -import java.io.{ByteArrayInputStream, FileInputStream} +import java.io.{ByteArrayInputStream, FileInputStream, OutputStream} import java.nio.charset.StandardCharsets import java.nio.file.StandardOpenOption import java.security.{DigestOutputStream, MessageDigest} @@ -34,7 +35,12 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket blob <- getBlob(sha256) result = blob match { case Some(blob) => - HeadResult.Exists(blob.getSize) + blob.getMetadata.get(GcsStorageBackend.OriginalSizeHeader) match { + case null => + HeadResult.Exists(blob.getSize) + case originalSize => + HeadResult.Exists(originalSize.toLong) + } case None => HeadResult.NotFound } @@ -81,19 +87,24 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket } private def receiveStreamedFile(blob: Blob, destination: File, expectedHash: Sha256): F[Either[StorageException, GetResult]] = { + def getCompressionType: Option[String] = { + Option(blob.getMetadata.get(GcsStorageBackend.CompressionTypeHeader)).map(_.toLowerCase) + } + Sync[F].delay(logger.debug(s"Downloading streamed data to $destination")) >> blocker .delay(destination.newOutputStream(FileStreamOpenOptions)) .bracket { fileStream => - Sync[F] - .delay(new DigestOutputStream(fileStream, MessageDigest.getInstance("SHA-256"))) - .bracket { stream => - blocker.delay(blob.downloadTo(stream)).flatMap { _ => - Sync[F].delay { - (blob.getSize, Sha256(stream.getMessageDigest.digest)) - } + getCompressionType match { + case None => + receiveRawStream(blob, fileStream) + case Some("zstd") => + receiveZstdStream(blob, fileStream) + case Some(unknownCompressionType) => + Sync[F].raiseError[(Long, Sha256)] { + new NotImplementedError(s"Unknown compression type $unknownCompressionType") } - }(stream => blocker.delay(stream.close())) + } }(fileStream => blocker.delay(fileStream.close())) .map[Either[StorageException, GetResult]] { case (size, hash) => @@ -109,6 +120,26 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket } } + private def receiveRawStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = { + Sync[F] + .delay(new DigestOutputStream(targetStream, MessageDigest.getInstance("SHA-256"))) + .bracket { hashingStream => + val countingStream = new GcsStorageBackend.CountingOutputStream(hashingStream) + blocker.delay(blob.downloadTo(countingStream)).flatMap { _ => + Sync[F].delay { + (countingStream.length, Sha256(hashingStream.getMessageDigest.digest)) + } + } + }(hashingStream => blocker.delay(hashingStream.close())) + } + + private def receiveZstdStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = { + Sync[F] + .delay(new ZstdDecompressOutputStream(targetStream)) + .bracket { decompressionStream => + receiveRawStream(blob, decompressionStream) + }(hashingStream => blocker.delay(hashingStream.close())) + } override def close(): Unit = { () } @@ -117,6 +148,9 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket object GcsStorageBackend { private val DefaultConfig = ConfigFactory.defaultReference().getConfig("gcsBackendDefaults") + private val CompressionTypeHeader = "comp-type" + private val OriginalSizeHeader = "original-size" + def fromConfig[F[_]: Sync: ContextShift](config: Config, blocker: Blocker): Either[ConfigurationException, Resource[F, GcsStorageBackend[F]]] = { @@ -148,6 +182,35 @@ object GcsStorageBackend { String.join("/", sha256Hex.substring(0, 2), sha256Hex.substring(2, 4), sha256Hex.substring(4, 6), sha256Hex) } + private[gcs] class CountingOutputStream(target: OutputStream) extends OutputStream { + private var count: Long = 0 + + def length: Long = count + + override def write(b: Int): Unit = { + target.write(b) + count += 1 + } + + override def write(b: Array[Byte]): Unit = { + target.write(b) + count += b.length + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + target.write(b, off, len) + count += len + } + + override def flush(): Unit = { + target.flush() + } + + override def close(): Unit = { + target.close() + } + } + def prepareStorageClient[F[_]: Sync: ContextShift](conf: GcsBackendConfiguration, blocker: Blocker): Either[ConfigurationException, Storage] = { Either