diff --git a/docs/source/preparing_datasets/basic_dataset_conversion.md b/docs/source/preparing_datasets/basic_dataset_conversion.md index 58b214078..f2146adac 100644 --- a/docs/source/preparing_datasets/basic_dataset_conversion.md +++ b/docs/source/preparing_datasets/basic_dataset_conversion.md @@ -43,7 +43,7 @@ out = ('/local/data', 'oci://bucket/data') | Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 | | Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 | | Image | 'pil' | `PIL` | raw PIL image class ([link]((https://pillow.readthedocs.io/en/stable/reference/Image.html))) | -| Image | 'jpeg' | `JPEG` | PIL image as JPEG | +| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG, quality between 0 and 100 | | Image | 'png' | `PNG` | PIL image as PNG | | Pickle | 'pkl' | `Pickle` | arbitrary Python objects | | JSON | 'json' | `JSON` | arbitrary data as JSON | @@ -52,7 +52,7 @@ Here's an example where the field `x` is an image, and `y` is a class label, as ```python column = { - 'x': 'jpeg', + 'x': 'jpeg:90', 'y': 'int8', } ``` diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index c28d0058d..ed2bc6ca9 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -464,7 +464,29 @@ def decode(self, data: bytes) -> Image.Image: class JPEG(Encoding): - """Store PIL image as JPEG.""" + """Store PIL image as JPEG. Optionally specify quality.""" + + def __init__(self, quality: int = 75): + if not isinstance(quality, int): + raise ValueError('JPEG quality must be an integer') + if not (0 <= quality <= 100): + raise ValueError('JPEG quality must be between 0 and 100') + self.quality = quality + + @classmethod + def from_str(cls, config: str) -> Self: + """Parse this encoding from string. + + Args: + text (str): The string to parse. + + Returns: + Self: The initialized Encoding. + """ + if config == '': + return cls() + else: + return cls(int(config)) def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) @@ -474,7 +496,7 @@ def encode(self, obj: Image.Image) -> bytes: return f.read() else: out = BytesIO() - obj.save(out, format='JPEG') + obj.save(out, format='JPEG', quality=self.quality) return out.getvalue() def decode(self, data: bytes) -> Image.Image: diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 47fe2a6b2..376a3d2be 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -194,6 +194,30 @@ def test_jpeg_encode_decode(self, mode: str): dec_data = dec_data.convert('I') assert isinstance(dec_data, Image.Image) + @pytest.mark.parametrize('mode', ['L', 'RGB']) + def test_jpeg_encode_decode_with_quality(self, mode: str): + jpeg_enc = mdsEnc.JPEG(quality=50) + assert jpeg_enc.size is None + + # Creating the (32 x 32) NumPy Array with random values + np_data = np.random.randint(255, size=(32, 32), dtype=np.uint32) + # Default image mode of PIL Image is 'I' + img = Image.fromarray(np_data).convert(mode) + + # Test encode + enc_data = jpeg_enc.encode(img) + assert isinstance(enc_data, bytes) + + # Test decode + dec_data = jpeg_enc.decode(enc_data) + dec_data = dec_data.convert('I') + assert isinstance(dec_data, Image.Image) + + @pytest.mark.parametrize('quality', [-1, 101, 'foo']) + def test_jpeg_encode_decode_with_quality_invalid(self, quality: Any): + with pytest.raises(ValueError): + mdsEnc.JPEG(quality=quality) + @pytest.mark.parametrize('mode', ['L', 'RGB']) def test_jpegfile_encode_decode(self, mode: str): jpeg_enc = mdsEnc.JPEG() @@ -224,6 +248,7 @@ def test_jpeg_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): jpeg_enc = mdsEnc.JPEG() _ = jpeg_enc.encode(data) + @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_png_encode_decode(self, mode: str):