Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add jpeg quality option #818

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/preparing_datasets/basic_dataset_conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ out = ('/local/data', 'oci://bucket/data')
| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 |
| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 |
| Image | 'pil' | `PIL` | raw PIL image class ([link]((https://pillow.readthedocs.io/en/stable/reference/Image.html))) |
| Image | 'jpeg' | `JPEG` | PIL image as JPEG |
| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG, quality between 0 and 100 |
| Image | 'png' | `PNG` | PIL image as PNG |
| Pickle | 'pkl' | `Pickle` | arbitrary Python objects |
| JSON | 'json' | `JSON` | arbitrary data as JSON |
Expand All @@ -52,7 +52,7 @@ Here's an example where the field `x` is an image, and `y` is a class label, as
<!--pytest.mark.skip-->
```python
column = {
'x': 'jpeg',
'x': 'jpeg:90',
'y': 'int8',
}
```
Expand Down
26 changes: 24 additions & 2 deletions streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,29 @@ def decode(self, data: bytes) -> Image.Image:


class JPEG(Encoding):
"""Store PIL image as JPEG."""
"""Store PIL image as JPEG. Optionally specify quality."""

def __init__(self, quality: int = 75):
if not isinstance(quality, int):
raise ValueError('JPEG quality must be an integer')
if not (0 <= quality <= 100):
raise ValueError('JPEG quality must be between 0 and 100')
self.quality = quality

@classmethod
def from_str(cls, config: str) -> Self:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably also add a test for this too then, to also confirm what it should look like when used. Thanks!

"""Parse this encoding from string.

Args:
text (str): The string to parse.

Returns:
Self: The initialized Encoding.
"""
if config == '':
return cls()
else:
return cls(int(config))

def encode(self, obj: Image.Image) -> bytes:
self._validate(obj, Image.Image)
Expand All @@ -474,7 +496,7 @@ def encode(self, obj: Image.Image) -> bytes:
return f.read()
else:
out = BytesIO()
obj.save(out, format='JPEG')
obj.save(out, format='JPEG', quality=self.quality)
return out.getvalue()

def decode(self, data: bytes) -> Image.Image:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def test_jpeg_encode_decode(self, mode: str):
dec_data = dec_data.convert('I')
assert isinstance(dec_data, Image.Image)

@pytest.mark.parametrize('mode', ['L', 'RGB'])
def test_jpeg_encode_decode_with_quality(self, mode: str):
jpeg_enc = mdsEnc.JPEG(quality=50)
assert jpeg_enc.size is None

# Creating the (32 x 32) NumPy Array with random values
np_data = np.random.randint(255, size=(32, 32), dtype=np.uint32)
# Default image mode of PIL Image is 'I'
img = Image.fromarray(np_data).convert(mode)

# Test encode
enc_data = jpeg_enc.encode(img)
assert isinstance(enc_data, bytes)

# Test decode
dec_data = jpeg_enc.decode(enc_data)
dec_data = dec_data.convert('I')
assert isinstance(dec_data, Image.Image)

@pytest.mark.parametrize('quality', [-1, 101, 'foo'])
def test_jpeg_encode_decode_with_quality_invalid(self, quality: Any):
with pytest.raises(ValueError):
mdsEnc.JPEG(quality=quality)

@pytest.mark.parametrize('mode', ['L', 'RGB'])
def test_jpegfile_encode_decode(self, mode: str):
jpeg_enc = mdsEnc.JPEG()
Expand Down Expand Up @@ -224,6 +248,7 @@ def test_jpeg_encode_invalid_data(self, data: Any):
with pytest.raises(AttributeError):
jpeg_enc = mdsEnc.JPEG()
_ = jpeg_enc.encode(data)


@pytest.mark.parametrize('mode', ['I', 'L', 'RGB'])
def test_png_encode_decode(self, mode: str):
Expand Down