From ff8862561b82c5ed07675deb03d02db13f6ca2cc Mon Sep 17 00:00:00 2001 From: Ben Zhang Date: Sat, 28 Dec 2024 19:06:39 +0000 Subject: [PATCH] Support zstd compression --- src/docker_unpack/cli.py | 10 ++- src/docker_unpack/utils.py | 90 ++++++++++++++++++++++++++- tests/integration/compression/test.sh | 4 +- 3 files changed, 98 insertions(+), 6 deletions(-) diff --git a/src/docker_unpack/cli.py b/src/docker_unpack/cli.py index b0e8759..7f37bf2 100644 --- a/src/docker_unpack/cli.py +++ b/src/docker_unpack/cli.py @@ -11,7 +11,7 @@ from watcloud_utils.typer import app, typer from ._version import __version__ -from .utils import generate_env, generate_runscript +from .utils import generate_env, generate_runscript, MyTarFile, StreamProxy @app.command() @@ -30,7 +30,8 @@ def unpack(input_file: typer.FileBinaryRead, output_dir: Path): with tempfile.TemporaryDirectory() as temp_dir: logger.info(f"Extracting tar file to {temp_dir=}") - with tarfile.open(fileobj=input_file, mode="r|*") as tar: + input_file_proxy = StreamProxy(input_file) + with MyTarFile.open(fileobj=input_file_proxy, mode=f"r{'|' if input_file_proxy.supports_streaming() else ':'}{input_file_proxy.getcomptype()}") as tar: tar.extractall(temp_dir) manifest_path = Path(temp_dir) / "manifest.json" @@ -55,7 +56,10 @@ def unpack(input_file: typer.FileBinaryRead, output_dir: Path): layer_path = Path(temp_dir) / layer logger.info(f"Extracting {layer_path=}") - with tarfile.open(layer_path) as tar: + with open(layer_path, "rb") as f: + comptype = StreamProxy(f).getcomptype() + + with MyTarFile.open(layer_path, f"r:{comptype}") as tar: for member in tar: basename = os.path.basename(member.name) diff --git a/src/docker_unpack/utils.py b/src/docker_unpack/utils.py index e868526..7afaf21 100644 --- a/src/docker_unpack/utils.py +++ b/src/docker_unpack/utils.py @@ -1,7 +1,11 @@ import os +import tarfile +import typing from pathlib import Path + from watcloud_utils.logging import logger + def escape(value): """Escapes special characters in a string for use in a shell script.""" return value.replace('"', r"\"").replace("'", r"\'") @@ -120,4 +124,88 @@ def generate_env(root_path: Path, img_config: dict): os.fsync(f.fileno()) # Set executable permissions - os.chmod(env_path, 0o755) \ No newline at end of file + os.chmod(env_path, 0o755) + + +class StreamProxy: + """ + A stream wrapper to detect compression type. + + Derived from https://github.com/python/cpython/blob/2cf396c368a188e9142843e566ce6d8e6eb08999/Lib/tarfile.py#L574-L598 + """ + + def __init__(self, fileobj): + self.fileobj = fileobj + self.buf = self.fileobj.read(tarfile.BLOCKSIZE) + + def read(self, size): + self.read = self.fileobj.read + return self.buf + + def getcomptype(self): + if self.buf.startswith(b"\x1f\x8b\x08"): + return "gz" + elif self.buf[0:3] == b"BZh" and self.buf[4:10] == b"1AY&SY": + return "bz2" + elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")): + return "xz" + elif self.buf.startswith(b"\x28\xb5\x2f\xfd"): + return "zst" + else: + return "tar" + + def supports_streaming(self): + comptype = self.getcomptype() + return comptype not in ("zst",) + + def close(self): + self.fileobj.close() + + +class MyTarFile(tarfile.TarFile): + """ + A custom TarFile class that supports more compression types. + + Derived from: + - https://github.com/python/cpython/issues/81276#issuecomment-1966037544 + """ + + OPEN_METH = {"zst": "zstopen"} | tarfile.TarFile.OPEN_METH + + @classmethod + def zstopen( + cls, + name: str , + mode: typing.Literal["r", "w", "x"] = "r", + fileobj: typing.Optional[typing.BinaryIO] = None, + ) -> tarfile.TarFile: + if mode not in ("r", "w", "x"): + raise NotImplementedError(f"mode `{mode}' not implemented for zst") + try: + import zstandard + except ImportError: + raise tarfile.CompressionError("zstandard module not available") + if mode == "r": + zfobj = zstandard.open(fileobj or name, "rb") + else: + zfobj = zstandard.open( + fileobj or name, + mode + "b", + cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1), + ) + try: + print(f"calling taropen with {name=}, {mode=}, {zfobj=}") + tarobj = cls.taropen(name, mode, zfobj) + except (OSError, EOFError, zstandard.ZstdError) as exc: + zfobj.close() + if mode == "r": + raise tarfile.ReadError("not a zst file") from exc + raise + except: + zfobj.close() + raise + # Setting the _extfileobj attribute is important to signal a need to + # close this object and thus flush the compressed stream. + # Unfortunately, tarfile.pyi doesn't know about it. + tarobj._extfileobj = False # type: ignore + return tarobj diff --git a/tests/integration/compression/test.sh b/tests/integration/compression/test.sh index 987d1cb..7dcd3ac 100755 --- a/tests/integration/compression/test.sh +++ b/tests/integration/compression/test.sh @@ -10,7 +10,7 @@ trap 'echo "Error on line $LINENO: $BASH_COMMAND"; exit 1' ERR docker pull alpine # MARK: Compressing the whole package -for compression in gzip bzip2 xz; do +for compression in zstd gzip bzip2 xz; do echo "Testing package compression with $compression" __tmpdir=$(mktemp -d) docker save alpine | $compression > "$__tmpdir/image.tar" @@ -22,7 +22,7 @@ done # MARK: Compressing the layers docker buildx create --name compression-test --driver docker-container --use -for compression in gzip estargz; do +for compression in zstd gzip estargz; do echo "Testing layer compression with $compression" __tmpdir=$(mktemp -d)