From 002e7dffa2f0c688a1b1f54464edbe8507294255 Mon Sep 17 00:00:00 2001 From: Alex Cabrera Date: Mon, 28 Oct 2024 16:42:01 -0700 Subject: [PATCH 1/4] add jpeg quality option --- .../basic_dataset_conversion.md | 92 ++++++--- streaming/base/format/mds/encodings.py | 195 ++++++++++-------- 2 files changed, 176 insertions(+), 111 deletions(-) diff --git a/docs/source/preparing_datasets/basic_dataset_conversion.md b/docs/source/preparing_datasets/basic_dataset_conversion.md index 58b214078..d23237bb5 100644 --- a/docs/source/preparing_datasets/basic_dataset_conversion.md +++ b/docs/source/preparing_datasets/basic_dataset_conversion.md @@ -7,11 +7,13 @@ This guide covers how to convert your raw data to MDS format using {class}`strea Use {class}`streaming.MDSWriter` to convert raw data to MDS format. MDSWriter is like a native file writer; instead of writing the content line by line, MDSWriter writes the data sample by sample. It writes the data into shard files in a sequential manner (for example, `shard.00000.mds`, then `shard.00001.mds`, and so on). Configure {class}`streaming.MDSWriter` according to your requirements with the parameters below: 1. An `out` parameter is an output directory to save shard files. The `out` directory can be specified in three ways: - * **Local path**: Shard files are stored locally. - * **Remote path**: A local temporary directory is created to cache the shard files, and when shard creation is complete, they are uploaded to the remote location. - * **`(local_dir, remote_dir)` tuple**: Shard files are saved in the specified `local_dir` and uploaded to `remote_dir`. + +- **Local path**: Shard files are stored locally. +- **Remote path**: A local temporary directory is created to cache the shard files, and when shard creation is complete, they are uploaded to the remote location. +- **`(local_dir, remote_dir)` tuple**: Shard files are saved in the specified `local_dir` and uploaded to `remote_dir`. 
+ ```python out = '/local/data' out = 's3://bucket/data' # Will create a temporary local dir @@ -22,37 +24,39 @@ out = ('/local/data', 'oci://bucket/data') 3. A `column` parameter is a `dict` mapping a feature name or label name with a streaming supported encoding type. `MDSWriter` encodes your data to bytes, and at training time, data gets decoded back automatically to its original form. The `index.json` file stores `column` metadata for decoding. Supported encoding formats are: -| Category | Name | Class | Notes | -|--------------------|---------------|--------------|--------------------------| -| Encoding | 'bytes' | `Bytes` | no-op encoding | -| Encoding | 'str' | `Str` | stores in UTF-8 | -| Encoding | 'int' | `Int` | Python `int`, uses `numpy.int64` for encoding | -| Numpy Array | 'ndarray:dtype:shape' | `NDArray(dtype: Optional[str] = None, shape: Optional[Tuple[int]] = None)` | uses `numpy.ndarray` | -| Numpy Unsigned Int | 'uint8' | `UInt8` | uses `numpy.uint8` | -| Numpy Unsigned Int | 'uint16' | `UInt16` | uses `numpy.uint16` | -| Numpy Unsigned Int | 'uint32' | `Uint32` | uses `numpy.uint32` | -| Numpy Unsigned Int | 'uint64' | `Uint64` | uses `numpy.uint64` | -| Numpy Signed Int | 'int8' | `Int8` | uses `numpy.int8` | -| Numpy Signed Int | 'int16' | `Int16` | uses `numpy.int16` | -| Numpy Signed Int | 'int32' | `Int32` | uses `numpy.int32` | -| Numpy Signed Int | 'int64' | `Int64` | uses `numpy.int64` | -| Numpy Float | 'float16' | `Float16` | uses `numpy.float16` | -| Numpy Float | 'float32' | `Float32` | uses `numpy.float32` | -| Numpy Float | 'float64' | `Float64` | uses `numpy.float64` | -| Numerical String | 'str_int' | `StrInt` | stores in UTF-8 | -| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 | -| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 | -| Image | 'pil' | `PIL` | raw PIL image class ([link]((https://pillow.readthedocs.io/en/stable/reference/Image.html))) | -| Image | 'jpeg' | `JPEG` | PIL image as 
JPEG | -| Image | 'png' | `PNG` | PIL image as PNG | -| Pickle | 'pkl' | `Pickle` | arbitrary Python objects | -| JSON | 'json' | `JSON` | arbitrary data as JSON | +| Category | Name | Class | Notes | +| ------------------ | --------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | +| Encoding | 'bytes' | `Bytes` | no-op encoding | +| Encoding | 'str' | `Str` | stores in UTF-8 | +| Encoding | 'int' | `Int` | Python `int`, uses `numpy.int64` for encoding | +| Numpy Array | 'ndarray:dtype:shape' | `NDArray(dtype: Optional[str] = None, shape: Optional[Tuple[int]] = None)` | uses `numpy.ndarray` | +| Numpy Unsigned Int | 'uint8' | `UInt8` | uses `numpy.uint8` | +| Numpy Unsigned Int | 'uint16' | `UInt16` | uses `numpy.uint16` | +| Numpy Unsigned Int | 'uint32' | `Uint32` | uses `numpy.uint32` | +| Numpy Unsigned Int | 'uint64' | `Uint64` | uses `numpy.uint64` | +| Numpy Signed Int | 'int8' | `Int8` | uses `numpy.int8` | +| Numpy Signed Int | 'int16' | `Int16` | uses `numpy.int16` | +| Numpy Signed Int | 'int32' | `Int32` | uses `numpy.int32` | +| Numpy Signed Int | 'int64' | `Int64` | uses `numpy.int64` | +| Numpy Float | 'float16' | `Float16` | uses `numpy.float16` | +| Numpy Float | 'float32' | `Float32` | uses `numpy.float32` | +| Numpy Float | 'float64' | `Float64` | uses `numpy.float64` | +| Numerical String | 'str_int' | `StrInt` | stores in UTF-8 | +| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 | +| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 | +| Image | 'pil' | `PIL` | raw PIL image class ([link](<(https://pillow.readthedocs.io/en/stable/reference/Image.html)>)) | +| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG | +| Image | 'png' | `PNG` | PIL image as PNG | +| Pickle | 'pkl' | `Pickle` | arbitrary Python objects | +| JSON | 'json' | `JSON` | arbitrary data as JSON | Here's an 
example where the field `x` is an image, and `y` is a class label, as an integer. + + ```python column = { - 'x': 'jpeg', + 'x': 'jpeg:95', 'y': 'int8', } ``` @@ -60,6 +64,7 @@ column = { If the data type you need is not listed in the above table, then you can write your own data type class with `encode` and `decode` methods in it and patch it inside streaming. For example, let's say, you wanted to add a `complex128` data type (64 bits each for real and imaginary parts): + ```python import numpy as np from typing import Any @@ -77,13 +82,15 @@ class Complex128(Encoding): _encodings['complex128'] = Complex128 ``` -4. An optional shard `size_limit`, in bytes, for each *uncompressed* shard file. This defaults to 67 MB. Specify this as a number of bytes, either directly as an `int`, or a human-readable suffix: +4. An optional shard `size_limit`, in bytes, for each _uncompressed_ shard file. This defaults to 67 MB. Specify this as a number of bytes, either directly as an `int`, or a human-readable suffix: + ```python size_limit = 1024 # 1kB limit, as an int size_limit = '1kb' # 1kB limit, as a human-readable string ``` + Shard file size depends on the dataset size, but generally, too small of a shard size creates a ton of shard files and heavy network overheads, and too large of a shard size creates fewer shard files, but downloads are less balanced. A shard size of between 50-100MB works well in practice. 5. An optional `compression` algorithm name (and level) if you would like to compress the shard files. This can reduce egress costs during training. StreamingDataset will uncompress shard files upon download during training. You can control whether to keep compressed shard files locally during training with the `keep_zip` flag -- more information [here](../dataset_configuration/shard_retrieval.md#Keeping-compressed-shards). 
@@ -101,10 +108,12 @@ Supported compression algorithms: The compression algorithm to use, if any, is specified by passing `code` or `code:level` as a string. For example: + ```python compression = 'zstd' # zstd, defaults to level 3. compression = 'zstd:9' # zstd, specifying level 9. ``` + The higher the level, the higher the compression ratio. However, using higher compression levels will impact the compression speed. In our experience, `zstd` is optimal over the time-size Pareto frontier. Compression is most beneficial for text, whereas it is less helpful for other modalities like images. 6. An optional `hashes` list of algorithm names, used to verify data integrity. Hashes are saved in the `index.json` file. Hash verification during training is controlled with the `validate_hash` argument more information [here](../dataset_configuration/shard_retrieval.md#Hash-validation). @@ -139,6 +148,7 @@ Available non-cryptographic hash functions: As an example: + ```python hashes = ['sha256', 'xxh64'] ``` @@ -172,31 +182,41 @@ class RandomClassificationDataset: ``` Here, we write shards to a local directory. You can specify a remote path as well. + + ```python output_dir = 'test_output_dir' ``` Specify the column encoding types for each sample and label: + + ```python columns = {'x': 'pkl', 'y': 'int64'} ``` Optionally, specify a compression algorithm and level: + + ```python compression = 'zstd:7' # compress shards with ZStandard, level 7 ``` Optionally, specify a list of hash algorithms for verification: + + ```python hashes = ['sha1'] # Use only SHA1 hashing on each shard ``` Optionally, provide a shard size limit, after which a new shard starts. In this small example, we use 10kb, but for production datasets 50-100MB is more appropriate. + + ```python # Here we use a human-readable string, but we could also # pass in an int specifying the number of bytes. 
@@ -204,7 +224,9 @@ limit = '10kb' ``` It's time to call the {class}`streaming.MDSWriter` with the above initialized parameters and write the samples by iterating over a dataset. + + ```python from streaming.base import MDSWriter @@ -215,7 +237,9 @@ with MDSWriter(out=output_dir, columns=columns, compression=compression, hashes= ``` Clean up after ourselves. + + ``` from shutil import rmtree @@ -223,7 +247,9 @@ rmtree(output_dir) ``` Once the dataset has been written, the output directory contains an index.json file that contains shard metadata, the shard files themselves. For example, + + ```bash dirname ├── index.json @@ -234,6 +260,7 @@ dirname ## Example: Writing `ndarray`s to MDS format Here, we show how to write `ndarray`s to MDS format in three ways: + 1. dynamic shape and dtype 2. dynamic shape but fixed dtype 3. fixed shape and dtype @@ -269,6 +296,7 @@ for i in range(dataset.num_samples): The streaming encoding type, as the value in the `columns` dict, should be `ndarray:dtype`. So in this example, it is `ndarray:int16`. + ```python # Write to MDS with MDSWriter(out='my_dataset2/', @@ -293,6 +321,7 @@ for i in range(dataset.num_samples): The streaming encoding type, as the value in the `columns` dict, should be `ndarray:dtype:shape`. So in this example, it is `ndarray:int16:3,3,3`. + ```python # Write to MDS with MDSWriter(out='my_dataset3/', @@ -313,6 +342,7 @@ for i in range(dataset.num_samples): We can see that the dataset is more efficiently serialized when we are more specific about array shape and datatype: + ```python import subprocess @@ -327,7 +357,9 @@ subprocess.run(['du', '-sh', 'my_dataset3']) ``` Clean up after ourselves. 
+ + ```python from shutil import rmtree diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index c28d0058d..b2fda25ca 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -17,15 +17,21 @@ from typing_extensions import Self __all__ = [ - 'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode', - 'is_mds_encoding_safe' + "get_mds_encoded_size", + "get_mds_encodings", + "is_mds_encoding", + "mds_decode", + "mds_encode", + "is_mds_encoding_safe", ] class Encoding(ABC): """Encodes and decodes between objects of a certain type and raw bytes.""" - size: Optional[int] = None # Fixed size in bytes of encoded data (None if variable size). + size: Optional[int] = ( + None # Fixed size in bytes of encoded data (None if variable size). + ) @abstractmethod def encode(self, obj: Any) -> bytes: @@ -55,7 +61,8 @@ def decode(self, data: bytes) -> Any: def _validate(data: Any, expected_type: Any) -> None: if not isinstance(data, expected_type): raise AttributeError( - f'data should be of type {expected_type}, but instead, found as {type(data)}') + f"data should be of type {expected_type}, but instead, found as {type(data)}" + ) class Bytes(Encoding): @@ -74,10 +81,10 @@ class Str(Encoding): def encode(self, obj: str) -> bytes: self._validate(obj, str) - return obj.encode('utf-8') + return obj.encode("utf-8") def decode(self, data: bytes) -> str: - return data.decode('utf-8') + return data.decode("utf-8") class Int(Encoding): @@ -117,10 +124,10 @@ class NDArray(Encoding): # Integer <4 -> shape dtype. _int2shape_dtype = { - 0: 'uint8', - 1: 'uint16', - 2: 'uint32', - 3: 'uint64', + 0: "uint8", + 1: "uint16", + 2: "uint32", + 3: "uint64", } # Shape dtype -> integer <4. @@ -128,24 +135,26 @@ class NDArray(Encoding): # Integer <256 -> value dtype. 
_int2value_dtype = { - 8: 'uint8', - 9: 'int8', - 16: 'uint16', - 17: 'int16', - 18: 'float16', - 32: 'uint32', - 33: 'int32', - 34: 'float32', - 64: 'uint64', - 65: 'int64', - 66: 'float64', + 8: "uint8", + 9: "int8", + 16: "uint16", + 17: "int16", + 18: "float16", + 32: "uint32", + 33: "int32", + 34: "float32", + 64: "uint64", + 65: "int64", + 66: "float64", } # Value dtype -> integer <256. _value_dtype2int = {v: k for k, v in _int2value_dtype.items()} @classmethod - def _get_static_size(cls, dtype: Optional[str], shape: Optional[tuple[int]]) -> Optional[int]: + def _get_static_size( + cls, dtype: Optional[str], shape: Optional[tuple[int]] + ) -> Optional[int]: """Get the fixed size of the column in bytes, if applicable. Args: @@ -179,14 +188,14 @@ def from_str(cls, text: str) -> Self: Returns: Self: The initialized Encoding. """ - args = text.split(':') if text else [] + args = text.split(":") if text else [] assert len(args) in {0, 1, 2} if 1 <= len(args): dtype = args[0] else: dtype = None if 2 <= len(args): - shape = tuple(map(int, args[1].split(','))) + shape = tuple(map(int, args[1].split(","))) else: shape = None return cls(dtype, shape) @@ -203,20 +212,20 @@ def _rightsize_shape_dtype(cls, shape: npt.NDArray[np.int64]) -> str: """ if len(shape) == 0: raise ValueError( - 'Attempting to encode a scalar with NDArray encoding. Please use a scalar encoding.' + "Attempting to encode a scalar with NDArray encoding. Please use a scalar encoding." ) if shape.min() <= 0: - raise ValueError('All dimensions must be greater than zero.') + raise ValueError("All dimensions must be greater than zero.") x = shape.max() if x < (1 << 8): - return 'uint8' + return "uint8" elif x < (1 << 16): - return 'uint16' + return "uint16" elif x < (1 << 32): - return 'uint32' + return "uint32" else: - return 'uint64' + return "uint64" def encode(self, obj: npt.NDArray) -> bytes: """Encode the given data from the original object to bytes. 
@@ -232,22 +241,24 @@ def encode(self, obj: npt.NDArray) -> bytes: # Encode dtype, if not given in header. dtype_int = self._value_dtype2int.get(obj.dtype.name) if dtype_int is None: - raise ValueError(f'Unsupported dtype: {obj.dtype.name}.') + raise ValueError(f"Unsupported dtype: {obj.dtype.name}.") if self.dtype is None: part = bytes([dtype_int]) parts.append(part) else: if obj.dtype != self.dtype: - raise ValueError(f'Wrong dtype: expected {self.dtype}, got {obj.dtype.name}.') + raise ValueError( + f"Wrong dtype: expected {self.dtype}, got {obj.dtype.name}." + ) if obj.size == 0: - raise ValueError('Attempting to encode a numpy array with 0 elements.') + raise ValueError("Attempting to encode a numpy array with 0 elements.") # Encode shape, if not given in header. if self.shape is None: ndim = len(obj.shape) if 64 <= ndim: - raise ValueError('Array has too many axes: maximum 63, got {ndim}.') + raise ValueError("Array has too many axes: maximum 63, got {ndim}.") shape_arr = np.array(obj.shape, np.int64) shape_dtype = self._rightsize_shape_dtype(shape_arr) shape_dtype_int = self._shape_dtype2int[shape_dtype] @@ -258,13 +269,13 @@ def encode(self, obj: npt.NDArray) -> bytes: parts.append(part) else: if obj.shape != self.shape: - raise ValueError('Wrong shape: expected {self.shape}, got {obj.shape}.') + raise ValueError("Wrong shape: expected {self.shape}, got {obj.shape}.") # Encode the array values. part = obj.tobytes() parts.append(part) - return b''.join(parts) + return b"".join(parts) def decode(self, data: bytes) -> npt.NDArray: """Decode the given data from bytes to the original object. @@ -296,7 +307,7 @@ def decode(self, data: bytes) -> npt.NDArray: shape_dtype = self._int2shape_dtype[shape_dtype_int] shape_dtype_nbytes = 2**shape_dtype_int size = ndim * shape_dtype_nbytes - shape = np.frombuffer(data[index:index + size], shape_dtype) + shape = np.frombuffer(data[index : index + size], shape_dtype) index += size # Decode the array values. 
@@ -411,10 +422,10 @@ class StrInt(StrEncoding): def encode(self, obj: int) -> bytes: self._validate(obj, int) - return str(obj).encode('utf-8') + return str(obj).encode("utf-8") def decode(self, data: bytes) -> int: - return int(data.decode('utf-8')) + return int(data.decode("utf-8")) class StrFloat(Encoding): @@ -422,10 +433,10 @@ class StrFloat(Encoding): def encode(self, obj: float) -> bytes: self._validate(obj, float) - return str(obj).encode('utf-8') + return str(obj).encode("utf-8") def decode(self, data: bytes) -> float: - return float(data.decode('utf-8')) + return float(data.decode("utf-8")) class StrDecimal(Encoding): @@ -433,10 +444,10 @@ class StrDecimal(Encoding): def encode(self, obj: Decimal) -> bytes: self._validate(obj, Decimal) - return str(obj).encode('utf-8') + return str(obj).encode("utf-8") def decode(self, data: bytes) -> Decimal: - return Decimal(data.decode('utf-8')) + return Decimal(data.decode("utf-8")) class PIL(Encoding): @@ -447,7 +458,7 @@ class PIL(Encoding): def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) - mode = obj.mode.encode('utf-8') + mode = obj.mode.encode("utf-8") width, height = obj.size raw = obj.tobytes() ints = np.array([width, height, len(mode)], np.uint32) @@ -457,7 +468,7 @@ def decode(self, data: bytes) -> Image.Image: idx = 3 * 4 width, height, mode_size = np.frombuffer(data[:idx], np.uint32) idx2 = idx + mode_size - mode = data[idx:idx2].decode('utf-8') + mode = data[idx:idx2].decode("utf-8") size = width, height raw = data[idx2:] return Image.frombytes(mode, size, raw) # pyright: ignore @@ -466,21 +477,43 @@ def decode(self, data: bytes) -> Image.Image: class JPEG(Encoding): """Store PIL image as JPEG.""" + def __init__(self, quality: int = 75): + assert 0 <= quality <= 100 + self.quality = quality + def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) - if isinstance(obj, JpegImageFile) and hasattr(obj, 'filename'): + if isinstance(obj, JpegImageFile) and 
hasattr(obj, "filename"): # read the source file to prevent lossy re-encoding - with open(obj.filename, 'rb') as f: + with open(obj.filename, "rb") as f: return f.read() else: out = BytesIO() - obj.save(out, format='JPEG') + obj.save(out, format="JPEG", quality=self.quality) return out.getvalue() def decode(self, data: bytes) -> Image.Image: inp = BytesIO(data) return Image.open(inp) + @classmethod + def from_str(cls, text: str) -> Self: + """Parse this encoding from string. + + Args: + text (str): The string to parse. + + Returns: + Self: The initialized Encoding. + """ + args = text.split(":") if text else [] + assert len(args) in {0, 1} + if len(args) == 1: + quality = int(args[0]) + else: + quality = 75 + return cls(quality) + class PNG(Encoding): """Store PIL image as PNG.""" @@ -488,7 +521,7 @@ class PNG(Encoding): def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) out = BytesIO() - obj.save(out, format='PNG') + obj.save(out, format="PNG") return out.getvalue() def decode(self, data: bytes) -> Image.Image: @@ -514,47 +547,47 @@ def encode(self, obj: Any) -> bytes: obj = obj.tolist() data = json.dumps(obj) self._is_valid(obj, data) - return data.encode('utf-8') + return data.encode("utf-8") def decode(self, data: bytes) -> Any: - return json.loads(data.decode('utf-8')) + return json.loads(data.decode("utf-8")) def _is_valid(self, original: Any, converted: Any) -> None: try: json.loads(converted) except json.decoder.JSONDecodeError as e: - e.msg = f'Invalid JSON data: {original}' + e.msg = f"Invalid JSON data: {original}" raise # Encodings (name -> class). 
_encodings = { - 'bytes': Bytes, - 'str': Str, - 'int': Int, - 'ndarray': NDArray, - 'uint8': UInt8, - 'uint16': UInt16, - 'uint32': UInt32, - 'uint64': UInt64, - 'int8': Int8, - 'int16': Int16, - 'int32': Int32, - 'int64': Int64, - 'float16': Float16, - 'float32': Float32, - 'float64': Float64, - 'str_int': StrInt, - 'str_float': StrFloat, - 'str_decimal': StrDecimal, - 'pil': PIL, - 'jpeg': JPEG, - 'png': PNG, - 'pkl': Pickle, - 'json': JSON, + "bytes": Bytes, + "str": Str, + "int": Int, + "ndarray": NDArray, + "uint8": UInt8, + "uint16": UInt16, + "uint32": UInt32, + "uint64": UInt64, + "int8": Int8, + "int16": Int16, + "int32": Int32, + "int64": Int64, + "float16": Float16, + "float32": Float32, + "float64": Float64, + "str_int": StrInt, + "str_float": StrFloat, + "str_decimal": StrDecimal, + "pil": PIL, + "jpeg": JPEG, + "png": PNG, + "pkl": Pickle, + "json": JSON, } -_unsafe_encodings = {'pkl'} +_unsafe_encodings = {"pkl"} def get_mds_encodings() -> set[str]: @@ -575,14 +608,14 @@ def _get_coder(encoding: str) -> Optional[Encoding]: Returns: Encoding: The coder. 
""" - index = encoding.find(':') + index = encoding.find(":") if index == -1: cls = _encodings.get(encoding) if cls is None: return None return cls() name = encoding[:index] - config = encoding[index + 1:] + config = encoding[index + 1 :] return _encodings[name].from_str(config) @@ -625,7 +658,7 @@ def mds_encode(encoding: str, obj: Any) -> bytes: return obj coder = _get_coder(encoding) if coder is None: - raise ValueError(f'Unsupported encoding: {encoding}.') + raise ValueError(f"Unsupported encoding: {encoding}.") return coder.encode(obj) @@ -641,7 +674,7 @@ def mds_decode(encoding: str, data: bytes) -> Any: """ coder = _get_coder(encoding) if coder is None: - raise ValueError(f'Unsupported encoding: {encoding}.') + raise ValueError(f"Unsupported encoding: {encoding}.") return coder.decode(data) @@ -656,5 +689,5 @@ def get_mds_encoded_size(encoding: str) -> Optional[int]: """ coder = _get_coder(encoding) if coder is None: - raise ValueError(f'Unsupported encoding: {encoding}.') + raise ValueError(f"Unsupported encoding: {encoding}.") return coder.size From 6a8dc146845c255083e89d9099f270a113c1a203 Mon Sep 17 00:00:00 2001 From: Alex Cabrera Date: Tue, 29 Oct 2024 10:16:13 -0700 Subject: [PATCH 2/4] formatting --- .../basic_dataset_conversion.md | 92 +++----- streaming/base/format/mds/encodings.py | 199 +++++++++--------- 2 files changed, 124 insertions(+), 167 deletions(-) diff --git a/docs/source/preparing_datasets/basic_dataset_conversion.md b/docs/source/preparing_datasets/basic_dataset_conversion.md index d23237bb5..f2146adac 100644 --- a/docs/source/preparing_datasets/basic_dataset_conversion.md +++ b/docs/source/preparing_datasets/basic_dataset_conversion.md @@ -7,13 +7,11 @@ This guide covers how to convert your raw data to MDS format using {class}`strea Use {class}`streaming.MDSWriter` to convert raw data to MDS format. MDSWriter is like a native file writer; instead of writing the content line by line, MDSWriter writes the data sample by sample. 
It writes the data into shard files in a sequential manner (for example, `shard.00000.mds`, then `shard.00001.mds`, and so on). Configure {class}`streaming.MDSWriter` according to your requirements with the parameters below: 1. An `out` parameter is an output directory to save shard files. The `out` directory can be specified in three ways: - -- **Local path**: Shard files are stored locally. -- **Remote path**: A local temporary directory is created to cache the shard files, and when shard creation is complete, they are uploaded to the remote location. -- **`(local_dir, remote_dir)` tuple**: Shard files are saved in the specified `local_dir` and uploaded to `remote_dir`. + * **Local path**: Shard files are stored locally. + * **Remote path**: A local temporary directory is created to cache the shard files, and when shard creation is complete, they are uploaded to the remote location. + * **`(local_dir, remote_dir)` tuple**: Shard files are saved in the specified `local_dir` and uploaded to `remote_dir`. - ```python out = '/local/data' out = 's3://bucket/data' # Will create a temporary local dir @@ -24,39 +22,37 @@ out = ('/local/data', 'oci://bucket/data') 3. A `column` parameter is a `dict` mapping a feature name or label name with a streaming supported encoding type. `MDSWriter` encodes your data to bytes, and at training time, data gets decoded back automatically to its original form. The `index.json` file stores `column` metadata for decoding. 
Supported encoding formats are: -| Category | Name | Class | Notes | -| ------------------ | --------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | -| Encoding | 'bytes' | `Bytes` | no-op encoding | -| Encoding | 'str' | `Str` | stores in UTF-8 | -| Encoding | 'int' | `Int` | Python `int`, uses `numpy.int64` for encoding | -| Numpy Array | 'ndarray:dtype:shape' | `NDArray(dtype: Optional[str] = None, shape: Optional[Tuple[int]] = None)` | uses `numpy.ndarray` | -| Numpy Unsigned Int | 'uint8' | `UInt8` | uses `numpy.uint8` | -| Numpy Unsigned Int | 'uint16' | `UInt16` | uses `numpy.uint16` | -| Numpy Unsigned Int | 'uint32' | `Uint32` | uses `numpy.uint32` | -| Numpy Unsigned Int | 'uint64' | `Uint64` | uses `numpy.uint64` | -| Numpy Signed Int | 'int8' | `Int8` | uses `numpy.int8` | -| Numpy Signed Int | 'int16' | `Int16` | uses `numpy.int16` | -| Numpy Signed Int | 'int32' | `Int32` | uses `numpy.int32` | -| Numpy Signed Int | 'int64' | `Int64` | uses `numpy.int64` | -| Numpy Float | 'float16' | `Float16` | uses `numpy.float16` | -| Numpy Float | 'float32' | `Float32` | uses `numpy.float32` | -| Numpy Float | 'float64' | `Float64` | uses `numpy.float64` | -| Numerical String | 'str_int' | `StrInt` | stores in UTF-8 | -| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 | -| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 | -| Image | 'pil' | `PIL` | raw PIL image class ([link](<(https://pillow.readthedocs.io/en/stable/reference/Image.html)>)) | -| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG | -| Image | 'png' | `PNG` | PIL image as PNG | -| Pickle | 'pkl' | `Pickle` | arbitrary Python objects | -| JSON | 'json' | `JSON` | arbitrary data as JSON | +| Category | Name | Class | Notes | +|--------------------|---------------|--------------|--------------------------| +| Encoding | 'bytes' 
| `Bytes` | no-op encoding | +| Encoding | 'str' | `Str` | stores in UTF-8 | +| Encoding | 'int' | `Int` | Python `int`, uses `numpy.int64` for encoding | +| Numpy Array | 'ndarray:dtype:shape' | `NDArray(dtype: Optional[str] = None, shape: Optional[Tuple[int]] = None)` | uses `numpy.ndarray` | +| Numpy Unsigned Int | 'uint8' | `UInt8` | uses `numpy.uint8` | +| Numpy Unsigned Int | 'uint16' | `UInt16` | uses `numpy.uint16` | +| Numpy Unsigned Int | 'uint32' | `Uint32` | uses `numpy.uint32` | +| Numpy Unsigned Int | 'uint64' | `Uint64` | uses `numpy.uint64` | +| Numpy Signed Int | 'int8' | `Int8` | uses `numpy.int8` | +| Numpy Signed Int | 'int16' | `Int16` | uses `numpy.int16` | +| Numpy Signed Int | 'int32' | `Int32` | uses `numpy.int32` | +| Numpy Signed Int | 'int64' | `Int64` | uses `numpy.int64` | +| Numpy Float | 'float16' | `Float16` | uses `numpy.float16` | +| Numpy Float | 'float32' | `Float32` | uses `numpy.float32` | +| Numpy Float | 'float64' | `Float64` | uses `numpy.float64` | +| Numerical String | 'str_int' | `StrInt` | stores in UTF-8 | +| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 | +| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 | +| Image | 'pil' | `PIL` | raw PIL image class ([link]((https://pillow.readthedocs.io/en/stable/reference/Image.html))) | +| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG, quality between 0 and 100 | +| Image | 'png' | `PNG` | PIL image as PNG | +| Pickle | 'pkl' | `Pickle` | arbitrary Python objects | +| JSON | 'json' | `JSON` | arbitrary data as JSON | Here's an example where the field `x` is an image, and `y` is a class label, as an integer. - - ```python column = { - 'x': 'jpeg:95', + 'x': 'jpeg:90', 'y': 'int8', } ``` @@ -64,7 +60,6 @@ column = { If the data type you need is not listed in the above table, then you can write your own data type class with `encode` and `decode` methods in it and patch it inside streaming. 
For example, let's say, you wanted to add a `complex128` data type (64 bits each for real and imaginary parts): - ```python import numpy as np from typing import Any @@ -82,15 +77,13 @@ class Complex128(Encoding): _encodings['complex128'] = Complex128 ``` -4. An optional shard `size_limit`, in bytes, for each _uncompressed_ shard file. This defaults to 67 MB. Specify this as a number of bytes, either directly as an `int`, or a human-readable suffix: +4. An optional shard `size_limit`, in bytes, for each *uncompressed* shard file. This defaults to 67 MB. Specify this as a number of bytes, either directly as an `int`, or a human-readable suffix: - ```python size_limit = 1024 # 1kB limit, as an int size_limit = '1kb' # 1kB limit, as a human-readable string ``` - Shard file size depends on the dataset size, but generally, too small of a shard size creates a ton of shard files and heavy network overheads, and too large of a shard size creates fewer shard files, but downloads are less balanced. A shard size of between 50-100MB works well in practice. 5. An optional `compression` algorithm name (and level) if you would like to compress the shard files. This can reduce egress costs during training. StreamingDataset will uncompress shard files upon download during training. You can control whether to keep compressed shard files locally during training with the `keep_zip` flag -- more information [here](../dataset_configuration/shard_retrieval.md#Keeping-compressed-shards). @@ -108,12 +101,10 @@ Supported compression algorithms: The compression algorithm to use, if any, is specified by passing `code` or `code:level` as a string. For example: - ```python compression = 'zstd' # zstd, defaults to level 3. compression = 'zstd:9' # zstd, specifying level 9. ``` - The higher the level, the higher the compression ratio. However, using higher compression levels will impact the compression speed. In our experience, `zstd` is optimal over the time-size Pareto frontier. 
Compression is most beneficial for text, whereas it is less helpful for other modalities like images. 6. An optional `hashes` list of algorithm names, used to verify data integrity. Hashes are saved in the `index.json` file. Hash verification during training is controlled with the `validate_hash` argument more information [here](../dataset_configuration/shard_retrieval.md#Hash-validation). @@ -148,7 +139,6 @@ Available non-cryptographic hash functions: As an example: - ```python hashes = ['sha256', 'xxh64'] ``` @@ -182,41 +172,31 @@ class RandomClassificationDataset: ``` Here, we write shards to a local directory. You can specify a remote path as well. - - ```python output_dir = 'test_output_dir' ``` Specify the column encoding types for each sample and label: - - ```python columns = {'x': 'pkl', 'y': 'int64'} ``` Optionally, specify a compression algorithm and level: - - ```python compression = 'zstd:7' # compress shards with ZStandard, level 7 ``` Optionally, specify a list of hash algorithms for verification: - - ```python hashes = ['sha1'] # Use only SHA1 hashing on each shard ``` Optionally, provide a shard size limit, after which a new shard starts. In this small example, we use 10kb, but for production datasets 50-100MB is more appropriate. - - ```python # Here we use a human-readable string, but we could also # pass in an int specifying the number of bytes. @@ -224,9 +204,7 @@ limit = '10kb' ``` It's time to call the {class}`streaming.MDSWriter` with the above initialized parameters and write the samples by iterating over a dataset. - - ```python from streaming.base import MDSWriter @@ -237,9 +215,7 @@ with MDSWriter(out=output_dir, columns=columns, compression=compression, hashes= ``` Clean up after ourselves. - - ``` from shutil import rmtree @@ -247,9 +223,7 @@ rmtree(output_dir) ``` Once the dataset has been written, the output directory contains an index.json file that contains shard metadata, the shard files themselves. 
For example, - - ```bash dirname ├── index.json @@ -260,7 +234,6 @@ dirname ## Example: Writing `ndarray`s to MDS format Here, we show how to write `ndarray`s to MDS format in three ways: - 1. dynamic shape and dtype 2. dynamic shape but fixed dtype 3. fixed shape and dtype @@ -296,7 +269,6 @@ for i in range(dataset.num_samples): The streaming encoding type, as the value in the `columns` dict, should be `ndarray:dtype`. So in this example, it is `ndarray:int16`. - ```python # Write to MDS with MDSWriter(out='my_dataset2/', @@ -321,7 +293,6 @@ for i in range(dataset.num_samples): The streaming encoding type, as the value in the `columns` dict, should be `ndarray:dtype:shape`. So in this example, it is `ndarray:int16:3,3,3`. - ```python # Write to MDS with MDSWriter(out='my_dataset3/', @@ -342,7 +313,6 @@ for i in range(dataset.num_samples): We can see that the dataset is more efficiently serialized when we are more specific about array shape and datatype: - ```python import subprocess @@ -357,9 +327,7 @@ subprocess.run(['du', '-sh', 'my_dataset3']) ``` Clean up after ourselves. - - ```python from shutil import rmtree diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index b2fda25ca..723525260 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -17,21 +17,15 @@ from typing_extensions import Self __all__ = [ - "get_mds_encoded_size", - "get_mds_encodings", - "is_mds_encoding", - "mds_decode", - "mds_encode", - "is_mds_encoding_safe", + 'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode', + 'is_mds_encoding_safe' ] class Encoding(ABC): """Encodes and decodes between objects of a certain type and raw bytes.""" - size: Optional[int] = ( - None # Fixed size in bytes of encoded data (None if variable size). - ) + size: Optional[int] = None # Fixed size in bytes of encoded data (None if variable size). 
@abstractmethod def encode(self, obj: Any) -> bytes: @@ -61,8 +55,7 @@ def decode(self, data: bytes) -> Any: def _validate(data: Any, expected_type: Any) -> None: if not isinstance(data, expected_type): raise AttributeError( - f"data should be of type {expected_type}, but instead, found as {type(data)}" - ) + f'data should be of type {expected_type}, but instead, found as {type(data)}') class Bytes(Encoding): @@ -81,10 +74,10 @@ class Str(Encoding): def encode(self, obj: str) -> bytes: self._validate(obj, str) - return obj.encode("utf-8") + return obj.encode('utf-8') def decode(self, data: bytes) -> str: - return data.decode("utf-8") + return data.decode('utf-8') class Int(Encoding): @@ -124,10 +117,10 @@ class NDArray(Encoding): # Integer <4 -> shape dtype. _int2shape_dtype = { - 0: "uint8", - 1: "uint16", - 2: "uint32", - 3: "uint64", + 0: 'uint8', + 1: 'uint16', + 2: 'uint32', + 3: 'uint64', } # Shape dtype -> integer <4. @@ -135,26 +128,24 @@ class NDArray(Encoding): # Integer <256 -> value dtype. _int2value_dtype = { - 8: "uint8", - 9: "int8", - 16: "uint16", - 17: "int16", - 18: "float16", - 32: "uint32", - 33: "int32", - 34: "float32", - 64: "uint64", - 65: "int64", - 66: "float64", + 8: 'uint8', + 9: 'int8', + 16: 'uint16', + 17: 'int16', + 18: 'float16', + 32: 'uint32', + 33: 'int32', + 34: 'float32', + 64: 'uint64', + 65: 'int64', + 66: 'float64', } # Value dtype -> integer <256. _value_dtype2int = {v: k for k, v in _int2value_dtype.items()} @classmethod - def _get_static_size( - cls, dtype: Optional[str], shape: Optional[tuple[int]] - ) -> Optional[int]: + def _get_static_size(cls, dtype: Optional[str], shape: Optional[tuple[int]]) -> Optional[int]: """Get the fixed size of the column in bytes, if applicable. Args: @@ -188,14 +179,14 @@ def from_str(cls, text: str) -> Self: Returns: Self: The initialized Encoding. 
""" - args = text.split(":") if text else [] + args = text.split(':') if text else [] assert len(args) in {0, 1, 2} if 1 <= len(args): dtype = args[0] else: dtype = None if 2 <= len(args): - shape = tuple(map(int, args[1].split(","))) + shape = tuple(map(int, args[1].split(','))) else: shape = None return cls(dtype, shape) @@ -212,20 +203,20 @@ def _rightsize_shape_dtype(cls, shape: npt.NDArray[np.int64]) -> str: """ if len(shape) == 0: raise ValueError( - "Attempting to encode a scalar with NDArray encoding. Please use a scalar encoding." + 'Attempting to encode a scalar with NDArray encoding. Please use a scalar encoding.' ) if shape.min() <= 0: - raise ValueError("All dimensions must be greater than zero.") + raise ValueError('All dimensions must be greater than zero.') x = shape.max() if x < (1 << 8): - return "uint8" + return 'uint8' elif x < (1 << 16): - return "uint16" + return 'uint16' elif x < (1 << 32): - return "uint32" + return 'uint32' else: - return "uint64" + return 'uint64' def encode(self, obj: npt.NDArray) -> bytes: """Encode the given data from the original object to bytes. @@ -241,24 +232,22 @@ def encode(self, obj: npt.NDArray) -> bytes: # Encode dtype, if not given in header. dtype_int = self._value_dtype2int.get(obj.dtype.name) if dtype_int is None: - raise ValueError(f"Unsupported dtype: {obj.dtype.name}.") + raise ValueError(f'Unsupported dtype: {obj.dtype.name}.') if self.dtype is None: part = bytes([dtype_int]) parts.append(part) else: if obj.dtype != self.dtype: - raise ValueError( - f"Wrong dtype: expected {self.dtype}, got {obj.dtype.name}." - ) + raise ValueError(f'Wrong dtype: expected {self.dtype}, got {obj.dtype.name}.') if obj.size == 0: - raise ValueError("Attempting to encode a numpy array with 0 elements.") + raise ValueError('Attempting to encode a numpy array with 0 elements.') # Encode shape, if not given in header. 
if self.shape is None: ndim = len(obj.shape) if 64 <= ndim: - raise ValueError("Array has too many axes: maximum 63, got {ndim}.") + raise ValueError('Array has too many axes: maximum 63, got {ndim}.') shape_arr = np.array(obj.shape, np.int64) shape_dtype = self._rightsize_shape_dtype(shape_arr) shape_dtype_int = self._shape_dtype2int[shape_dtype] @@ -269,13 +258,13 @@ def encode(self, obj: npt.NDArray) -> bytes: parts.append(part) else: if obj.shape != self.shape: - raise ValueError("Wrong shape: expected {self.shape}, got {obj.shape}.") + raise ValueError('Wrong shape: expected {self.shape}, got {obj.shape}.') # Encode the array values. part = obj.tobytes() parts.append(part) - return b"".join(parts) + return b''.join(parts) def decode(self, data: bytes) -> npt.NDArray: """Decode the given data from bytes to the original object. @@ -307,7 +296,7 @@ def decode(self, data: bytes) -> npt.NDArray: shape_dtype = self._int2shape_dtype[shape_dtype_int] shape_dtype_nbytes = 2**shape_dtype_int size = ndim * shape_dtype_nbytes - shape = np.frombuffer(data[index : index + size], shape_dtype) + shape = np.frombuffer(data[index:index + size], shape_dtype) index += size # Decode the array values. 
@@ -422,10 +411,10 @@ class StrInt(StrEncoding): def encode(self, obj: int) -> bytes: self._validate(obj, int) - return str(obj).encode("utf-8") + return str(obj).encode('utf-8') def decode(self, data: bytes) -> int: - return int(data.decode("utf-8")) + return int(data.decode('utf-8')) class StrFloat(Encoding): @@ -433,10 +422,10 @@ class StrFloat(Encoding): def encode(self, obj: float) -> bytes: self._validate(obj, float) - return str(obj).encode("utf-8") + return str(obj).encode('utf-8') def decode(self, data: bytes) -> float: - return float(data.decode("utf-8")) + return float(data.decode('utf-8')) class StrDecimal(Encoding): @@ -444,10 +433,10 @@ class StrDecimal(Encoding): def encode(self, obj: Decimal) -> bytes: self._validate(obj, Decimal) - return str(obj).encode("utf-8") + return str(obj).encode('utf-8') def decode(self, data: bytes) -> Decimal: - return Decimal(data.decode("utf-8")) + return Decimal(data.decode('utf-8')) class PIL(Encoding): @@ -458,7 +447,7 @@ class PIL(Encoding): def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) - mode = obj.mode.encode("utf-8") + mode = obj.mode.encode('utf-8') width, height = obj.size raw = obj.tobytes() ints = np.array([width, height, len(mode)], np.uint32) @@ -468,7 +457,7 @@ def decode(self, data: bytes) -> Image.Image: idx = 3 * 4 width, height, mode_size = np.frombuffer(data[:idx], np.uint32) idx2 = idx + mode_size - mode = data[idx:idx2].decode("utf-8") + mode = data[idx:idx2].decode('utf-8') size = width, height raw = data[idx2:] return Image.frombytes(mode, size, raw) # pyright: ignore @@ -481,21 +470,6 @@ def __init__(self, quality: int = 75): assert 0 <= quality <= 100 self.quality = quality - def encode(self, obj: Image.Image) -> bytes: - self._validate(obj, Image.Image) - if isinstance(obj, JpegImageFile) and hasattr(obj, "filename"): - # read the source file to prevent lossy re-encoding - with open(obj.filename, "rb") as f: - return f.read() - else: - out = BytesIO() - 
obj.save(out, format="JPEG", quality=self.quality) - return out.getvalue() - - def decode(self, data: bytes) -> Image.Image: - inp = BytesIO(data) - return Image.open(inp) - @classmethod def from_str(cls, text: str) -> Self: """Parse this encoding from string. @@ -506,7 +480,7 @@ def from_str(cls, text: str) -> Self: Returns: Self: The initialized Encoding. """ - args = text.split(":") if text else [] + args = text.split(':') if text else [] assert len(args) in {0, 1} if len(args) == 1: quality = int(args[0]) @@ -514,6 +488,21 @@ def from_str(cls, text: str) -> Self: quality = 75 return cls(quality) + def encode(self, obj: Image.Image) -> bytes: + self._validate(obj, Image.Image) + if isinstance(obj, JpegImageFile) and hasattr(obj, 'filename'): + # read the source file to prevent lossy re-encoding + with open(obj.filename, 'rb') as f: + return f.read() + else: + out = BytesIO() + obj.save(out, format='JPEG', quality=self.quality) + return out.getvalue() + + def decode(self, data: bytes) -> Image.Image: + inp = BytesIO(data) + return Image.open(inp) + class PNG(Encoding): """Store PIL image as PNG.""" @@ -521,7 +510,7 @@ class PNG(Encoding): def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) out = BytesIO() - obj.save(out, format="PNG") + obj.save(out, format='PNG') return out.getvalue() def decode(self, data: bytes) -> Image.Image: @@ -547,47 +536,47 @@ def encode(self, obj: Any) -> bytes: obj = obj.tolist() data = json.dumps(obj) self._is_valid(obj, data) - return data.encode("utf-8") + return data.encode('utf-8') def decode(self, data: bytes) -> Any: - return json.loads(data.decode("utf-8")) + return json.loads(data.decode('utf-8')) def _is_valid(self, original: Any, converted: Any) -> None: try: json.loads(converted) except json.decoder.JSONDecodeError as e: - e.msg = f"Invalid JSON data: {original}" + e.msg = f'Invalid JSON data: {original}' raise # Encodings (name -> class). 
_encodings = { - "bytes": Bytes, - "str": Str, - "int": Int, - "ndarray": NDArray, - "uint8": UInt8, - "uint16": UInt16, - "uint32": UInt32, - "uint64": UInt64, - "int8": Int8, - "int16": Int16, - "int32": Int32, - "int64": Int64, - "float16": Float16, - "float32": Float32, - "float64": Float64, - "str_int": StrInt, - "str_float": StrFloat, - "str_decimal": StrDecimal, - "pil": PIL, - "jpeg": JPEG, - "png": PNG, - "pkl": Pickle, - "json": JSON, + 'bytes': Bytes, + 'str': Str, + 'int': Int, + 'ndarray': NDArray, + 'uint8': UInt8, + 'uint16': UInt16, + 'uint32': UInt32, + 'uint64': UInt64, + 'int8': Int8, + 'int16': Int16, + 'int32': Int32, + 'int64': Int64, + 'float16': Float16, + 'float32': Float32, + 'float64': Float64, + 'str_int': StrInt, + 'str_float': StrFloat, + 'str_decimal': StrDecimal, + 'pil': PIL, + 'jpeg': JPEG, + 'png': PNG, + 'pkl': Pickle, + 'json': JSON, } -_unsafe_encodings = {"pkl"} +_unsafe_encodings = {'pkl'} def get_mds_encodings() -> set[str]: @@ -608,14 +597,14 @@ def _get_coder(encoding: str) -> Optional[Encoding]: Returns: Encoding: The coder. 
""" - index = encoding.find(":") + index = encoding.find(':') if index == -1: cls = _encodings.get(encoding) if cls is None: return None return cls() name = encoding[:index] - config = encoding[index + 1 :] + config = encoding[index + 1:] return _encodings[name].from_str(config) @@ -658,7 +647,7 @@ def mds_encode(encoding: str, obj: Any) -> bytes: return obj coder = _get_coder(encoding) if coder is None: - raise ValueError(f"Unsupported encoding: {encoding}.") + raise ValueError(f'Unsupported encoding: {encoding}.') return coder.encode(obj) @@ -674,7 +663,7 @@ def mds_decode(encoding: str, data: bytes) -> Any: """ coder = _get_coder(encoding) if coder is None: - raise ValueError(f"Unsupported encoding: {encoding}.") + raise ValueError(f'Unsupported encoding: {encoding}.') return coder.decode(data) @@ -689,5 +678,5 @@ def get_mds_encoded_size(encoding: str) -> Optional[int]: """ coder = _get_coder(encoding) if coder is None: - raise ValueError(f"Unsupported encoding: {encoding}.") + raise ValueError(f'Unsupported encoding: {encoding}.') return coder.size From 5b8186b1fe7d32744bdf44a8e5ce06a4b6f0c870 Mon Sep 17 00:00:00 2001 From: Alex Cabrera Date: Fri, 1 Nov 2024 11:25:51 -0700 Subject: [PATCH 3/4] updates --- streaming/base/format/mds/encodings.py | 13 +++++++++---- tests/test_encodings.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 723525260..70466d170 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -464,10 +464,11 @@ def decode(self, data: bytes) -> Image.Image: class JPEG(Encoding): - """Store PIL image as JPEG.""" + """Store PIL image as JPEG. 
Optionally specify quality.""" def __init__(self, quality: int = 75): - assert 0 <= quality <= 100 + if not (0 <= quality <= 100): + raise ValueError('JPEG quality must be between 0 and 100') self.quality = quality @classmethod @@ -481,9 +482,13 @@ def from_str(cls, text: str) -> Self: Self: The initialized Encoding. """ args = text.split(':') if text else [] - assert len(args) in {0, 1} + if len(args) not in {0, 1}: + raise ValueError('JPEG encoding string must have 0 or 1 arguments') if len(args) == 1: - quality = int(args[0]) + try: + quality = int(args[0]) + except ValueError: + raise ValueError('JPEG quality must be an integer between 0 and 100') else: quality = 75 return cls(quality) diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 47fe2a6b2..376a3d2be 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -194,6 +194,30 @@ def test_jpeg_encode_decode(self, mode: str): dec_data = dec_data.convert('I') assert isinstance(dec_data, Image.Image) + @pytest.mark.parametrize('mode', ['L', 'RGB']) + def test_jpeg_encode_decode_with_quality(self, mode: str): + jpeg_enc = mdsEnc.JPEG(quality=50) + assert jpeg_enc.size is None + + # Creating the (32 x 32) NumPy Array with random values + np_data = np.random.randint(255, size=(32, 32), dtype=np.uint32) + # Default image mode of PIL Image is 'I' + img = Image.fromarray(np_data).convert(mode) + + # Test encode + enc_data = jpeg_enc.encode(img) + assert isinstance(enc_data, bytes) + + # Test decode + dec_data = jpeg_enc.decode(enc_data) + dec_data = dec_data.convert('I') + assert isinstance(dec_data, Image.Image) + + @pytest.mark.parametrize('quality', [-1, 101, 'foo']) + def test_jpeg_encode_decode_with_quality_invalid(self, quality: Any): + with pytest.raises(ValueError): + mdsEnc.JPEG(quality=quality) + @pytest.mark.parametrize('mode', ['L', 'RGB']) def test_jpegfile_encode_decode(self, mode: str): jpeg_enc = mdsEnc.JPEG() @@ -224,6 +248,7 @@ def test_jpeg_encode_invalid_data(self, 
data: Any): with pytest.raises(AttributeError): jpeg_enc = mdsEnc.JPEG() _ = jpeg_enc.encode(data) + @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_png_encode_decode(self, mode: str): From 45c2132fbd3e0621c403496b51b5b56e1c174a22 Mon Sep 17 00:00:00 2001 From: Alex Cabrera Date: Fri, 1 Nov 2024 11:25:55 -0700 Subject: [PATCH 4/4] updates --- streaming/base/format/mds/encodings.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 70466d170..ed2bc6ca9 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -467,12 +467,14 @@ class JPEG(Encoding): """Store PIL image as JPEG. Optionally specify quality.""" def __init__(self, quality: int = 75): + if not isinstance(quality, int): + raise ValueError('JPEG quality must be an integer') if not (0 <= quality <= 100): raise ValueError('JPEG quality must be between 0 and 100') self.quality = quality @classmethod - def from_str(cls, text: str) -> Self: + def from_str(cls, config: str) -> Self: """Parse this encoding from string. Args: @@ -481,17 +483,10 @@ def from_str(cls, text: str) -> Self: Returns: Self: The initialized Encoding. """ - args = text.split(':') if text else [] - if len(args) not in {0, 1}: - raise ValueError('JPEG encoding string must have 0 or 1 arguments') - if len(args) == 1: - try: - quality = int(args[0]) - except ValueError: - raise ValueError('JPEG quality must be an integer between 0 and 100') + if config == '': + return cls() else: - quality = 75 - return cls(quality) + return cls(int(config)) def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image)