Skip to content

Commit

Permalink
feat: support mp3 and ogg in mime types
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Dec 22, 2024
1 parent a11af34 commit add73de
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion disnake/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def to_file(
# if the filename doesn't have an extension (e.g. widget member avatars),
# try to infer it from the data
if not os.path.splitext(filename)[1]:
ext = utils._get_extension_for_image(data)
ext = utils._get_extension_for_data(data)
if ext:
filename += ext

Expand Down
5 changes: 1 addition & 4 deletions disnake/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -5100,10 +5100,7 @@ async def create_soundboard_sound(
:class:`GuildSoundboardSound`
The newly created soundboard sound.
"""
# TODO: consider trying to determine correct mime type, or leave it at images for now and keep using octet-stream here?
sound_data = await utils._assetbytes_to_base64_data(
sound, mime_type="application/octet-stream"
)
sound_data = await utils._assetbytes_to_base64_data(sound)
emoji_name, emoji_id = PartialEmoji._emoji_to_name_id(emoji)

data = await self._state.http.create_guild_soundboard_sound(
Expand Down
2 changes: 1 addition & 1 deletion disnake/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,7 @@ def create_guild_sticker(
initial_bytes = file.fp.read(16)

try:
mime_type = utils._get_mime_type_for_image(initial_bytes)
mime_type = utils._get_mime_type_for_data(initial_bytes)
except ValueError:
if initial_bytes.startswith(b"{"):
mime_type = "application/json"
Expand Down
30 changes: 18 additions & 12 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,12 @@ def _maybe_cast(value: V, converter: Callable[[V], T], default: T = None) -> Opt
"image/jpeg": ".jpg",
"image/gif": ".gif",
"image/webp": ".webp",
"audio/mpeg": ".mp3",
"audio/ogg": ".ogg",
}


def _get_mime_type_for_image(data: _BytesLike) -> str:
def _get_mime_type_for_data(data: _BytesLike) -> str:
if data[0:8] == b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A":
return "image/png"
elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"):
Expand All @@ -520,43 +522,47 @@ def _get_mime_type_for_image(data: _BytesLike) -> str:
return "image/gif"
elif data[0:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
elif data[0:3] == b"ID3" or data[0:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"):
# n.b. this doesn't support the unofficial MPEG-2.5 frame header (which starts with 0xFFEx).
# Discord also doesn't accept it.
return "audio/mpeg"
elif data[0:4] == b"OggS":
return "audio/ogg"
else:
raise ValueError("Unsupported image type given")
raise ValueError("Unsupported file type provided")


def _bytes_to_base64_data(data: _BytesLike, *, mime_type: Optional[str] = None) -> str:
def _bytes_to_base64_data(data: _BytesLike) -> str:
fmt = "data:{mime};base64,{data}"
mime = mime_type or _get_mime_type_for_image(data)
mime = _get_mime_type_for_data(data)
b64 = b64encode(data).decode("ascii")
return fmt.format(mime=mime, data=b64)


def _get_extension_for_image(data: _BytesLike) -> Optional[str]:
def _get_extension_for_data(data: _BytesLike) -> Optional[str]:
try:
mime_type = _get_mime_type_for_image(data)
mime_type = _get_mime_type_for_data(data)
except ValueError:
return None
return _mime_type_extensions.get(mime_type)


@overload
async def _assetbytes_to_base64_data(data: None, *, mime_type: Optional[str] = None) -> None:
async def _assetbytes_to_base64_data(data: None) -> None:
...


@overload
async def _assetbytes_to_base64_data(data: AssetBytes, *, mime_type: Optional[str] = None) -> str:
async def _assetbytes_to_base64_data(data: AssetBytes) -> str:
...


async def _assetbytes_to_base64_data(
data: Optional[AssetBytes], *, mime_type: Optional[str] = None
) -> Optional[str]:
async def _assetbytes_to_base64_data(data: Optional[AssetBytes]) -> Optional[str]:
if data is None:
return None
if not isinstance(data, (bytes, bytearray, memoryview)):
data = await data.read()
return _bytes_to_base64_data(data, mime_type=mime_type)
return _bytes_to_base64_data(data)


if HAS_ORJSON:
Expand Down
25 changes: 11 additions & 14 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,20 @@ def test_maybe_cast() -> None:
(b"\x47\x49\x46\x38\x37\x61", "image/gif", ".gif"),
(b"\x47\x49\x46\x38\x39\x61", "image/gif", ".gif"),
(b"RIFFxxxxWEBP", "image/webp", ".webp"),
(b"ID3", "audio/mpeg", ".mp3"),
(b"\xFF\xF3", "audio/mpeg", ".mp3"),
(b"OggS", "audio/ogg", ".ogg"),
],
)
def test_mime_type_valid(data, expected_mime, expected_ext) -> None:
for d in (data, data + b"\xFF"):
assert utils._get_mime_type_for_image(d) == expected_mime
assert utils._get_extension_for_image(d) == expected_ext
assert utils._get_mime_type_for_data(d) == expected_mime
assert utils._get_extension_for_data(d) == expected_ext

prefixed = b"\xFF" + data
with pytest.raises(ValueError, match=r"Unsupported image type given"):
utils._get_mime_type_for_image(prefixed)
assert utils._get_extension_for_image(prefixed) is None
with pytest.raises(ValueError, match=r"Unsupported file type provided"):
utils._get_mime_type_for_data(prefixed)
assert utils._get_extension_for_data(prefixed) is None


@pytest.mark.parametrize(
Expand All @@ -291,9 +294,9 @@ def test_mime_type_valid(data, expected_mime, expected_ext) -> None:
],
)
def test_mime_type_invalid(data) -> None:
with pytest.raises(ValueError, match=r"Unsupported image type given"):
utils._get_mime_type_for_image(data)
assert utils._get_extension_for_image(data) is None
with pytest.raises(ValueError, match=r"Unsupported file type provided"):
utils._get_mime_type_for_data(data)
assert utils._get_extension_for_data(data) is None


@pytest.mark.asyncio
Expand All @@ -312,12 +315,6 @@ async def test_assetbytes_base64() -> None:

assert await utils._assetbytes_to_base64_data(mock_asset) == expected

# test mime override
assert (
await utils._assetbytes_to_base64_data(b"\x01\x02\x03", mime_type="application/json")
== "data:application/json;base64,AQID"
)


@pytest.mark.parametrize(
("after", "use_clock", "expected"),
Expand Down

0 comments on commit add73de

Please sign in to comment.