diff --git a/src/python/roms/__init__.py b/src/python/roms/__init__.py index b435517b3..38c694752 100644 --- a/src/python/roms/__init__.py +++ b/src/python/roms/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import functools import hashlib import json import tarfile @@ -10,10 +11,16 @@ import requests -def _download_roms(): - """Downloads all roms in the form of a base64 file, then decodes it into a tar.gz, then extracts all the roms by matching it to the expected md5.""" +@functools.cache +def _get_all_rom_hashes() -> dict[str, str]: # this is a map of {rom.bin : md5 checksum} - all_roms = json.load(open(Path(__file__).parent / "md5.json")) + with open(Path(__file__).parent / "md5.json") as f: + return json.load(f) + + +def _download_roms() -> None: + """Downloads all roms in the form of a base64 file, then decodes it into a tar.gz, then extracts all the roms by matching it to the expected md5.""" + all_roms = _get_all_rom_hashes() # use requests to download the base64 file url = "https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64" @@ -51,9 +58,10 @@ def _download_roms(): md5_hash = md5.hexdigest() # assert that the hash matches - assert ( - md5_hash == all_roms[rom_name] - ), f"Rom {rom_name}'s hash does not match what was expected. Please report this to a dev." + assert md5_hash == all_roms[rom_name], ( + f"Rom {rom_name}'s hash ({md5_hash}) does not match what was expected. " + "Please report this to a dev." + ) # save this rom rom_path = Path(__file__).parent / rom_name @@ -68,8 +76,8 @@ def get_rom_path(name: str) -> Path | None: bin_file = f"{name}.bin" bin_path = Path(__file__).parent / bin_file - # check if it exists within the md5.json - all_roms = json.load(open(Path(__file__).parent / "md5.json")) + # check if it exists within the the hash dictionary + all_roms = _get_all_rom_hashes() if bin_file not in all_roms: warnings.warn(f"Rom {name} not supported.") return None @@ -89,5 +97,5 @@ def get_rom_path(name: str) -> Path | None: def get_all_rom_ids() -> list[str]: """Returns a list of all available rom_ids, ie: ['tetris', 'pong', 'zaxxon', ...].""" - all_roms = json.load(open(Path(__file__).parent / "md5.json")) + all_roms = _get_all_rom_hashes() return [key.split(".")[0] for key in all_roms.keys()]