
feat: GCS backend - Add support for ZSTD decompression
mi-char committed Dec 20, 2023
1 parent 518e33a commit 547ca3e
Showing 6 changed files with 194 additions and 10 deletions.
1 change: 1 addition & 0 deletions build.gradle
@@ -14,6 +14,7 @@ ext {
metricsVersion = "3.0.2"
http4sVersion = "0.22.15"
gcsVersion = "2.30.1"
zstdVersion = "1.5.5-11"
monixVersion = "3.4.1" // Used only in tests.
}

2 changes: 2 additions & 0 deletions core/build.gradle
@@ -13,5 +13,7 @@ dependencies {

api "com.avast.metrics:metrics-scala_2.13:$metricsVersion"

implementation "com.github.luben:zstd-jni:$zstdVersion"

testImplementation "io.monix:monix_2.13:$monixVersion"
}
@@ -0,0 +1,52 @@
package com.avast.clients.storage.compression

import com.github.luben.zstd.{ZstdDecompressCtx, ZstdInputStream}

import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.channels.Channels

class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream {
  private val decompressCtx = new ZstdDecompressCtx()
  private val outputChannel = Channels.newChannel(outputStream)
  private var closed = false

  override def write(chunk: Array[Byte]): Unit = {
    if (closed) {
      throw new IllegalStateException("Stream is closed")
    }

    val inputBuffer = ByteBuffer.allocateDirect(chunk.length)
    val outputBuffer = ByteBuffer.allocateDirect(ZstdInputStream.recommendedDOutSize().toInt)

    inputBuffer.put(chunk)
    inputBuffer.flip()

    // Feed the compressed chunk to the streaming decompressor and forward every
    // produced block of decompressed bytes to the underlying output stream.
    while (inputBuffer.hasRemaining) {
      decompressCtx.decompressDirectByteBufferStream(outputBuffer, inputBuffer)

      outputBuffer.flip()

      while (outputBuffer.hasRemaining) {
        outputChannel.write(outputBuffer)
      }

      outputBuffer.clear()
    }
  }

  override def write(chunk: Array[Byte], offset: Int, length: Int): Unit = {
    write(chunk.slice(offset, offset + length))
  }

  override def write(b: Int): Unit = {
    write(Array(b.toByte))
  }

  override def close(): Unit = {
    if (!closed) {
      closed = true
      decompressCtx.close()
      outputChannel.close()
    }
  }
}
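
A minimal usage sketch of the class above (editor's illustration, not part of the commit; the input file name and chunk size are arbitrary): compressed data can be written in chunks of any size, and the decompressed bytes end up in the wrapped output stream.

```scala
import com.avast.clients.storage.compression.ZstdDecompressOutputStream

import java.io.{ByteArrayOutputStream, FileInputStream}

object ZstdDecompressExample extends App {
  val target = new ByteArrayOutputStream()
  val decompressStream = new ZstdDecompressOutputStream(target)

  // Feed the compressed input in 4 KiB chunks; any chunking works,
  // the stream keeps the zstd decompression state between writes.
  val input = new FileInputStream("payload.zst") // illustrative path
  try {
    val buffer = new Array[Byte](4 * 1024)
    var read = input.read(buffer)
    while (read != -1) {
      decompressStream.write(buffer, 0, read)
      read = input.read(buffer)
    }
  } finally {
    input.close()
    decompressStream.close()
  }

  val decompressed: Array[Byte] = target.toByteArray
  println(s"Decompressed ${decompressed.length} bytes")
}
```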
@@ -0,0 +1,60 @@
package com.avast.clients.storage.compression

import com.avast.scala.hashes.Sha256
import com.github.luben.zstd.Zstd
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatestplus.junit.JUnitRunner

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.security.MessageDigest

@RunWith(classOf[JUnitRunner])
class ZstdDecompressOutputStreamTest extends FunSuite {
  private def computeSha256(data: Array[Byte]): Sha256 = {
    Sha256(MessageDigest.getInstance("SHA-256").digest(data))
  }

  private def generateData(size: Int): Array[Byte] = {
    val builder = Array.newBuilder[Byte]
    var i = 0
    while (i < size) {
      builder += (i % 256).toByte
      i += 1
    }
    builder.result()
  }

  test("decompress zstd stream") {
    val chunkSize = 4 * 1024
    val testCases = Seq(0, 1, chunkSize, 10 * 1024 * 1024)

    for (testCase <- testCases) {
      println(s"Test case: $testCase")

      val original_data = generateData(testCase)
      val original_sha256 = computeSha256(original_data)

      val compressed_data = Zstd.compress(original_data, 9)

      val sourceStream = ByteBuffer.wrap(compressed_data)
      val targetStream = new ByteArrayOutputStream()

      val decompressStream = new ZstdDecompressOutputStream(targetStream)

      while (sourceStream.hasRemaining) {
        val chunkSize = math.min(sourceStream.remaining(), 4 * 1024)
        val chunk = new Array[Byte](chunkSize)
        sourceStream.get(chunk)
        decompressStream.write(chunk)
      }

      decompressStream.close()

      val result = targetStream.toByteArray

      assert(original_sha256 == computeSha256(result))
    }
  }
}
6 changes: 6 additions & 0 deletions gcs/README.md
@@ -25,6 +25,12 @@ GCS backends supports multiple ways of authentication:
* Reading credential file from default paths (see https://cloud.google.com/docs/authentication/application-default-credentials#personal)
* For all options see https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to

### Object decompression

The GCS backend supports decompression of objects. Whether an object is decompressed is determined by the `comp-type` header in the object's metadata:
if the header is present and contains the value `zstd`, the object is decompressed on the fly; otherwise the object is downloaded as is.

The only supported compression algorithm is currently `zstd`.
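
For illustration, here is one way an object could be uploaded so that the backend decompresses it, sketched with the google-cloud-storage Java client and zstd-jni (not part of this commit; the bucket name and object path are placeholders). The backend also reads an `original-size` metadata entry, when present, to report the uncompressed size from `head`:

```scala
import com.github.luben.zstd.Zstd
import com.google.cloud.storage.{BlobId, BlobInfo, StorageOptions}

import scala.jdk.CollectionConverters._

object UploadCompressedObject extends App {
  val storage = StorageOptions.getDefaultInstance.getService

  val original: Array[Byte] = java.nio.file.Files.readAllBytes(java.nio.file.Paths.get("some-file.bin"))
  val compressed: Array[Byte] = Zstd.compress(original, 9)

  // "comp-type" switches on decompression during download,
  // "original-size" lets head() report the uncompressed size.
  val metadata = Map(
    "comp-type" -> "zstd",
    "original-size" -> original.length.toString
  ).asJava

  val blobInfo = BlobInfo
    .newBuilder(BlobId.of("my-bucket", "ab/cd/ef/<sha256-hex>"))
    .setMetadata(metadata)
    .build()

  storage.create(blobInfo, compressed)
}
```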

### Client initialization

@@ -4,6 +4,7 @@ import better.files.File
import cats.effect.implicits.catsEffectSyntaxBracket
import cats.effect.{Blocker, ContextShift, Resource, Sync}
import cats.syntax.all._
import com.avast.clients.storage.compression.ZstdDecompressOutputStream
import com.avast.clients.storage.gcs.GcsStorageBackend.composeBlobPath
import com.avast.clients.storage.{ConfigurationException, GetResult, HeadResult, StorageBackend, StorageException}
import com.avast.scala.hashes.Sha256
@@ -17,7 +18,7 @@ import pureconfig.generic.ProductHint
import pureconfig.generic.auto._
import pureconfig.{CamelCase, ConfigFieldMapping}

import java.io.{ByteArrayInputStream, FileInputStream}
import java.io.{ByteArrayInputStream, FileInputStream, OutputStream}
import java.nio.charset.StandardCharsets
import java.nio.file.StandardOpenOption
import java.security.{DigestOutputStream, MessageDigest}
@@ -34,7 +35,12 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
blob <- getBlob(sha256)
result = blob match {
case Some(blob) =>
HeadResult.Exists(blob.getSize)
blob.getMetadata.get(GcsStorageBackend.OriginalSizeHeader) match {
case null =>
HeadResult.Exists(blob.getSize)
case originalSize =>
HeadResult.Exists(originalSize.toLong)
}
case None =>
HeadResult.NotFound
}
@@ -81,19 +87,24 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
}

private def receiveStreamedFile(blob: Blob, destination: File, expectedHash: Sha256): F[Either[StorageException, GetResult]] = {
def getCompressionType: Option[String] = {
Option(blob.getMetadata.get(GcsStorageBackend.CompressionTypeHeader)).map(_.toLowerCase)
}

Sync[F].delay(logger.debug(s"Downloading streamed data to $destination")) >>
blocker
.delay(destination.newOutputStream(FileStreamOpenOptions))
.bracket { fileStream =>
Sync[F]
.delay(new DigestOutputStream(fileStream, MessageDigest.getInstance("SHA-256")))
.bracket { stream =>
blocker.delay(blob.downloadTo(stream)).flatMap { _ =>
Sync[F].delay {
(blob.getSize, Sha256(stream.getMessageDigest.digest))
}
getCompressionType match {
case None =>
receiveRawStream(blob, fileStream)
case Some("zstd") =>
receiveZstdStream(blob, fileStream)
case Some(unknownCompressionType) =>
Sync[F].raiseError[(Long, Sha256)] {
new NotImplementedError(s"Unknown compression type $unknownCompressionType")
}
}(stream => blocker.delay(stream.close()))
}
}(fileStream => blocker.delay(fileStream.close()))
.map[Either[StorageException, GetResult]] {
case (size, hash) =>
@@ -109,6 +120,26 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
}
}

private def receiveRawStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = {
Sync[F]
.delay(new DigestOutputStream(targetStream, MessageDigest.getInstance("SHA-256")))
.bracket { hashingStream =>
val countingStream = new GcsStorageBackend.CountingOutputStream(hashingStream)
blocker.delay(blob.downloadTo(countingStream)).flatMap { _ =>
Sync[F].delay {
(countingStream.length, Sha256(hashingStream.getMessageDigest.digest))
}
}
}(hashingStream => blocker.delay(hashingStream.close()))
}

private def receiveZstdStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = {
Sync[F]
.delay(new ZstdDecompressOutputStream(targetStream))
.bracket { decompressionStream =>
receiveRawStream(blob, decompressionStream)
}(decompressionStream => blocker.delay(decompressionStream.close()))
}

override def close(): Unit = {
()
}
@@ -117,6 +148,9 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
object GcsStorageBackend {
private val DefaultConfig = ConfigFactory.defaultReference().getConfig("gcsBackendDefaults")

private val CompressionTypeHeader = "comp-type"
private val OriginalSizeHeader = "original-size"

def fromConfig[F[_]: Sync: ContextShift](config: Config,
blocker: Blocker): Either[ConfigurationException, Resource[F, GcsStorageBackend[F]]] = {

@@ -148,6 +182,35 @@ object GcsStorageBackend {
String.join("/", sha256Hex.substring(0, 2), sha256Hex.substring(2, 4), sha256Hex.substring(4, 6), sha256Hex)
}

private[gcs] class CountingOutputStream(target: OutputStream) extends OutputStream {
private var count: Long = 0

def length: Long = count

override def write(b: Int): Unit = {
target.write(b)
count += 1
}

override def write(b: Array[Byte]): Unit = {
target.write(b)
count += b.length
}

override def write(b: Array[Byte], off: Int, len: Int): Unit = {
target.write(b, off, len)
count += len
}

override def flush(): Unit = {
target.flush()
}

override def close(): Unit = {
target.close()
}
}

def prepareStorageClient[F[_]: Sync: ContextShift](conf: GcsBackendConfiguration,
blocker: Blocker): Either[ConfigurationException, Storage] = {
Either
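For orientation, a simplified sketch of how the output streams are chained during a `zstd` download (mirroring `receiveZstdStream` and `receiveRawStream` without the cats-effect brackets; the helper name and file path are hypothetical): bytes arriving from GCS are counted and hashed as downloaded, then decompressed into the target file.

```scala
package com.avast.clients.storage.gcs

import com.avast.clients.storage.compression.ZstdDecompressOutputStream
import com.google.cloud.storage.Blob

import java.io.FileOutputStream
import java.security.{DigestOutputStream, MessageDigest}

object ZstdDownloadSketch {
  // Returns (number of bytes downloaded from GCS, SHA-256 of those downloaded bytes);
  // the decompressed content is written to targetPath.
  def downloadZstdBlob(blob: Blob, targetPath: String): (Long, Array[Byte]) = {
    val fileStream = new FileOutputStream(targetPath)
    val decompressionStream = new ZstdDecompressOutputStream(fileStream)
    val hashingStream = new DigestOutputStream(decompressionStream, MessageDigest.getInstance("SHA-256"))
    // CountingOutputStream is private[gcs], hence this sketch lives in the same package.
    val countingStream = new GcsStorageBackend.CountingOutputStream(hashingStream)
    try {
      blob.downloadTo(countingStream)
      (countingStream.length, hashingStream.getMessageDigest.digest)
    } finally {
      countingStream.close() // closing propagates down the chain to the file stream
    }
  }
}
```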
