Skip to content

Commit

Permalink
feat: GCS backend - Add missing test, fix stream composition + change…
Browse files Browse the repository at this point in the history
…s from the review
  • Loading branch information
mi-char committed Dec 20, 2023
1 parent 547ca3e commit c0b9d76
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@ import java.nio.channels.Channels
class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStream {
private val decompressCtx = new ZstdDecompressCtx()
private val outputChannel = Channels.newChannel(outputStream)
private val closed = false
private val outputBuffer = ByteBuffer.allocateDirect(ZstdInputStream.recommendedDOutSize().toInt)

private var closed = false

override def write(chunk: Array[Byte]): Unit = {
if (closed) {
throw new IllegalStateException("Stream is closed")
}

val inputBuffer = ByteBuffer.allocateDirect(chunk.length)
val outputBuffer = ByteBuffer.allocateDirect(ZstdInputStream.recommendedDOutSize().toInt)

inputBuffer.put(chunk)
inputBuffer.flip()
inputBuffer.rewind()

while (inputBuffer.hasRemaining) {
outputBuffer.clear()

decompressCtx.decompressDirectByteBufferStream(outputBuffer, inputBuffer)

outputBuffer.flip()

while (outputBuffer.hasRemaining) {
outputChannel.write(outputBuffer)
}

outputBuffer.clear()
}
}

Expand All @@ -45,7 +45,9 @@ class ZstdDecompressOutputStream(outputStream: OutputStream) extends OutputStrea

override def close(): Unit = {
if (!closed) {
closed = true
decompressCtx.close()
outputBuffer.clear()
outputChannel.close()
}
}
Expand Down
1 change: 1 addition & 0 deletions gcs/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ dependencies {
implementation "com.google.cloud:google-cloud-storage:$gcsVersion"

testImplementation "io.monix:monix_2.13:$monixVersion"
testImplementation "com.github.luben:zstd-jni:$zstdVersion"
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,11 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
}

private def receiveStreamedFile(blob: Blob, destination: File, expectedHash: Sha256): F[Either[StorageException, GetResult]] = {
def getCompressionType: Option[String] = {
Option(blob.getMetadata.get(GcsStorageBackend.CompressionTypeHeader)).map(_.toLowerCase)
}

Sync[F].delay(logger.debug(s"Downloading streamed data to $destination")) >>
blocker
.delay(destination.newOutputStream(FileStreamOpenOptions))
.bracket { fileStream =>
getCompressionType match {
case None =>
receiveRawStream(blob, fileStream)
case Some("zstd") =>
receiveZstdStream(blob, fileStream)
case Some(unknownCompressionType) =>
Sync[F].raiseError[(Long, Sha256)] {
new NotImplementedError(s"Unknown compression type $unknownCompressionType")
}
}
downloadBlobToFile(blob, fileStream)
}(fileStream => blocker.delay(fileStream.close()))
.map[Either[StorageException, GetResult]] {
case (size, hash) =>
Expand All @@ -120,26 +107,53 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
}
}

private def receiveRawStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = {
private def downloadBlobToFile(blob: Blob, fileStream: OutputStream): F[(Long, Sha256)] = {
def getCompressionType: Option[String] = {
Option(blob.getMetadata.get(GcsStorageBackend.CompressionTypeHeader)).map(_.toLowerCase)
}

Sync[F]
.delay(new DigestOutputStream(targetStream, MessageDigest.getInstance("SHA-256")))
.bracket { hashingStream =>
val countingStream = new GcsStorageBackend.CountingOutputStream(hashingStream)
blocker.delay(blob.downloadTo(countingStream)).flatMap { _ =>
.delay {
val countingStream = new GcsStorageBackend.CountingOutputStream(fileStream)
val hashingStream = new DigestOutputStream(countingStream, MessageDigest.getInstance("SHA-256"))
(countingStream, hashingStream)
}
.bracket {
case (countingStream, hashingStream) => {
getCompressionType match {
case None =>
downloadBlobToStream(blob, hashingStream)
case Some("zstd") =>
decodeZstdStream(blob, hashingStream)
case Some(unknown) =>
throw new IllegalArgumentException(s"Unknown compression type $unknown")
}
}.flatMap { _ =>
Sync[F].delay {
(countingStream.length, Sha256(hashingStream.getMessageDigest.digest))
}
}
}(hashingStream => blocker.delay(hashingStream.close()))
} {
case (hashingStream, countingStream) =>
Sync[F].delay {
hashingStream.close()
countingStream.close()
}
}
}

private def receiveZstdStream(blob: Blob, targetStream: OutputStream): F[(Long, Sha256)] = {
private def decodeZstdStream(blob: Blob, targetStream: DigestOutputStream): F[Unit] = {
Sync[F]
.delay(new ZstdDecompressOutputStream(targetStream))
.bracket { decompressionStream =>
receiveRawStream(blob, decompressionStream)
}(hashingStream => blocker.delay(hashingStream.close()))
downloadBlobToStream(blob, decompressionStream)
}(decompressionStream => Sync[F].delay(decompressionStream.close()))
}

/**
  * Streams the content of `blob` into `targetStream`.
  *
  * The download call (`blob.downloadTo`) is suspended and run through `blocker.delay`.
  * This method neither flushes nor closes `targetStream`; that is left to the caller.
  *
  * @param blob         blob whose content should be downloaded
  * @param targetStream stream receiving the downloaded bytes
  * @return an effect that completes with `Unit` once `downloadTo` returns
  */
private def downloadBlobToStream(blob: Blob, targetStream: OutputStream): F[Unit] =
  blocker.delay {
    blob.downloadTo(targetStream)
  }

/** Intentionally a no-op: this method performs no work and releases nothing. */
override def close(): Unit = ()
Expand All @@ -148,8 +162,8 @@ class GcsStorageBackend[F[_]: Sync: ContextShift](storageClient: Storage, bucket
object GcsStorageBackend {
private val DefaultConfig = ConfigFactory.defaultReference().getConfig("gcsBackendDefaults")

private val CompressionTypeHeader = "comp-type"
private val OriginalSizeHeader = "original-size"
private[gcs] val CompressionTypeHeader = "comp-type"
private[gcs] val OriginalSizeHeader = "original-size"

def fromConfig[F[_]: Sync: ContextShift](config: Config,
blocker: Blocker): Either[ConfigurationException, Resource[F, GcsStorageBackend[F]]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package com.avast.clients.storage.gcs

import better.files.File
import cats.effect.Blocker
import com.avast.clients.storage.gcs.TestImplicits.{StringOps, randomString}
import com.avast.clients.storage.gcs.TestImplicits._
import com.avast.clients.storage.{GetResult, HeadResult}
import com.avast.scala.hashes.Sha256
import com.github.luben.zstd.Zstd
import com.google.cloud.storage.{Blob, BlobId, Storage}
import monix.eval.Task
import monix.execution.Scheduler.Implicits.global
Expand All @@ -18,6 +19,7 @@ import org.scalatestplus.mockito.MockitoSugar

import java.io.OutputStream
import scala.concurrent.duration._
import scala.jdk.CollectionConverters.MapHasAsJava

@RunWith(classOf[JUnitRunner])
class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar {
Expand Down Expand Up @@ -52,6 +54,44 @@ class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar
assertResult(Right(HeadResult.Exists(fileSize)))(result)
}

// `head` against a blob whose metadata marks it as zstd-compressed: the mocked blob reports the
// compressed length via `getSize` and the uncompressed length via the original-size metadata entry;
// the test expects `HeadResult.Exists(expectedSize)`.
test("head-zstd") {
  val bucket = "bucket-tst"
  val expectedSize = 1001100
  val plainBytes = randomString(expectedSize).getBytes()
  val zstdBytes = Zstd.compress(plainBytes, 9)
  val expectedSha = plainBytes.sha256
  val shaHex = expectedSha.toString()

  // Blob name segments the backend is expected to request: three 2-char prefixes, then the full hash.
  val expectedSegments = List(shaHex.substring(0, 2), shaHex.substring(2, 4), shaHex.substring(4, 6), shaHex)

  val blobMetadata = Map(
    GcsStorageBackend.CompressionTypeHeader -> "zstd",
    GcsStorageBackend.OriginalSizeHeader -> plainBytes.length.toString
  ).asJava

  val mockedBlob = mock[Blob]
  when(mockedBlob.getSize).thenReturn(zstdBytes.length.toLong)
  when(mockedBlob.getMetadata).thenReturn(blobMetadata)

  val mockedStorage = mock[Storage]
  when(mockedStorage.get(any[BlobId]())).thenAnswer { invocation =>
    val requestedId = invocation.getArgument[BlobId](0)
    val requestedPath = requestedId.getName
    assertResult(bucket)(requestedId.getBucket)
    assertResult(expectedSegments)(requestedPath.split("/").toList)
    mockedBlob
  }

  val headResult = composeTestBackend(mockedStorage, bucket).head(expectedSha).runSyncUnsafe(10.seconds)

  assertResult(Right(HeadResult.Exists(expectedSize)))(headResult)
}

test("get") {
val fileSize = 1001200
val content = randomString(fileSize)
Expand All @@ -60,7 +100,6 @@ class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar
val bucketName = "bucket-tst"

val blob = mock[Blob]
when(blob.getSize).thenReturn(fileSize.toLong)
when(blob.downloadTo(any[OutputStream]())).thenAnswer { call =>
val outputStream = call.getArgument[OutputStream](0)
outputStream.write(content.getBytes())
Expand Down Expand Up @@ -90,6 +129,47 @@ class GcsStorageBackendTest extends FunSuite with ScalaFutures with MockitoSugar
}
}

// `get` against a blob whose metadata marks it as zstd-compressed: the mocked `downloadTo` writes the
// compressed bytes into the stream it is handed, and the test expects the destination file to hold
// the original content (checked by size and SHA-256).
test("get-zstd") {
  val bucket = "bucket-tst"
  val expectedSize = 1024 * 1024
  val plainBytes = randomString(expectedSize).getBytes()
  val zstdBytes = Zstd.compress(plainBytes, 9)
  val expectedSha = plainBytes.sha256
  val shaHex = expectedSha.toString()

  // Blob name segments the backend is expected to request: three 2-char prefixes, then the full hash.
  val expectedSegments = List(shaHex.substring(0, 2), shaHex.substring(2, 4), shaHex.substring(4, 6), shaHex)

  val mockedBlob = mock[Blob]
  when(mockedBlob.getMetadata).thenReturn(Map(GcsStorageBackend.CompressionTypeHeader -> "zstd").asJava)
  when(mockedBlob.downloadTo(any[OutputStream]())).thenAnswer { invocation =>
    val sink = invocation.getArgument[OutputStream](0)
    sink.write(zstdBytes)
  }

  val mockedStorage = mock[Storage]
  when(mockedStorage.get(any[BlobId]())).thenAnswer { invocation =>
    val requestedId = invocation.getArgument[BlobId](0)
    val requestedPath = requestedId.getName
    assertResult(bucket)(requestedId.getBucket)
    assertResult(expectedSegments)(requestedPath.split("/").toList)
    mockedBlob
  }

  File.usingTemporaryFile() { target =>
    val getResult = composeTestBackend(mockedStorage, bucket).get(expectedSha, target).runSyncUnsafe(10.seconds)
    assertResult(Right(GetResult.Downloaded(target, expectedSize)))(getResult)
    assertResult(expectedSha.toString.toLowerCase)(target.sha256.toLowerCase)
    assertResult(expectedSize)(target.size)
  }
}

test("composeObjectPath") {
val sha = Sha256("d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831")
assertResult("d0/5a/f9/d05af9a8494696906e8eec79843ca1e4bf408c280616a121ed92f9e92e2de831")(GcsStorageBackend.composeBlobPath(sha))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@ object TestImplicits {
Sha256(digest.digest(s.getBytes))
}
}

/** Test-only syntax adding a `sha256` helper to raw byte arrays. */
implicit class BytesOps(val bytes: Array[Byte]) extends AnyVal {

  /** Hashes the wrapped bytes with a fresh SHA-256 `MessageDigest` and wraps the digest in `Sha256`. */
  def sha256: Sha256 =
    Sha256(MessageDigest.getInstance("SHA-256").digest(bytes))
}
}

0 comments on commit c0b9d76

Please sign in to comment.