Refactor model path search to just handle folders. Add special load_ function for the detector models

austinschneider committed Aug 30, 2024
1 parent f9abd93 commit 43e5520
Showing 1 changed file with 129 additions and 1 deletion.
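A minimal usage sketch of the entry points this commit touches, assuming the helpers remain importable from the package's _util module and using a hypothetical "IceCube" model name (not part of the commit):

    from siren import _util

    # Resolve a detector model folder by name and load it (model name is hypothetical).
    detector = _util.load_detector("IceCube")

    # Or resolve only the path to a versioned densities file.
    densities_path = _util.get_detector_model_file_path("IceCube-v1")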
130 changes: 129 additions & 1 deletion python/_util.py
@@ -464,6 +464,8 @@ def _find_model_folder_and_file(base_dir, model_name, must_exist, specific_file=
            specific_file_path = os.path.join(base_dir, name, specific_file)
            if os.path.isfile(specific_file_path):
                return name, True, specific_file_path
            else:
                return name, True, None
        else:
            return name, True, None
    else:
@@ -557,6 +559,108 @@ def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exi
    return os.path.join(base_dir, model_name, model_file_name)


def _get_model_folder(base_dir, model_name, must_exist):
    model_names = [
        f for f in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, f))
    ]

    exact_model_names = [f for f in model_names if f.lower() == model_name.lower()]

    if len(exact_model_names) == 0:
        model_names = [f for f in model_names if f.lower().startswith(model_name.lower())]
    else:
        model_names = exact_model_names

    if len(model_names) == 0 and must_exist:
        raise ValueError(f"No model folders found for {model_name}\nSearched in {base_dir}")
    elif len(model_names) == 0 and not must_exist:
        return model_name, False
    elif len(model_names) == 1:
        return model_names[0], True
    else:
        raise ValueError(f"Multiple directories found for {model_name}\nSearched in {base_dir}")

def _get_model_subfolders(base_dir, model_regex):
    model_subfolders = [
        f for f in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, f))
    ]
    model_subfolders = [
        f for f in model_subfolders if model_regex.match(f) is not None
    ]
    return model_subfolders


def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exist=True, specific_file=None):
    _model_regex = re.compile(
        r"^\s*" + _MODEL_PATTERN + ("" if suffix is None else r"(?:" + suffix + r")?") + r"\s*$",
        re.VERBOSE | re.IGNORECASE,
    )
    suffix = "" if suffix is None else suffix

    resources_dir = resource_package_dir()
    base_dir = _get_base_directory(resources_dir, prefix)

    d = _model_regex.match(model_name)
    if d is None:
        raise ValueError(f"Invalid model name: {model_name}")
    d = d.groupdict()
    model_search_name, version = d["model_name"], d["version"]

    if version is not None:
        version = normalize_version(version)

    found_model_name, folder_exists = _get_model_folder(base_dir, model_search_name, must_exist)

    model_dir = os.path.join(base_dir, found_model_name)

    if not must_exist and not folder_exists:
        if version is None:
            version = "v1"

        model_dir = os.path.join(model_dir, f"{found_model_name}-v{version}")
        return model_dir

    model_subfolders = _get_model_subfolders(model_dir, _model_regex)

    if len(model_subfolders) == 0:
        if must_exist:
            raise ValueError(f"No model folders found for {model_search_name}\nSearched in {model_dir}")
        else:
            if version is None:
                version = "v1"

            model_dir = os.path.join(model_dir, f"{found_model_name}-v{version}")
            return model_dir

    models_and_versions = []
    for f in model_subfolders:
        d = _model_regex.match(f).groupdict()
        if d["version"] is not None:
            models_and_versions.append((f, normalize_version(d["version"])))

    matching_models = [(m, v) for m, v in models_and_versions if v == version]

    if len(matching_models) == 1:
        model_dir = os.path.join(model_dir, matching_models[0][0])
        return model_dir
    elif len(matching_models) > 1:
        raise ValueError(f"Multiple directories found for {model_search_name} with version {version}\nSearched in {model_dir}")

    top_level_has_specific_file = specific_file is not None and os.path.isfile(os.path.join(model_dir, specific_file))

    if top_level_has_specific_file:
        return model_dir

    if len(matching_models) == 0:
        if must_exist and version is not None:
            raise ValueError(f"No model folders found for {model_search_name} with version {version}\nSearched in {model_dir}")

    found_model_subfolder, subfolder_version = max(models_and_versions, key=lambda x: tokenize_version(x[1]))

    return os.path.join(model_dir, found_model_subfolder)


def get_detector_model_file_path(model_name, must_exist=True):
    return _get_model_path(model_name, prefix="Detectors/densities", suffix=".dat", is_file=True, must_exist=must_exist)

@@ -604,7 +708,31 @@ def load_flux(model_name, *args, **kwargs):


def load_detector(model_name, *args, **kwargs):
    return load_resource("detector", model_name, *args, **kwargs)
    resource_type = "detector"
    resource_name = model_name
    folder = _resource_folder_by_name[resource_type]
    specific_file = f"{resource_type}.py"

    abs_dir = _get_model_path(resource_name, prefix=folder, is_file=False, must_exist=True, specific_file=specific_file)

    script_fname = os.path.join(abs_dir, f"{resource_type}.py")
    if os.path.isfile(script_fname):
        resource_module = load_module(f"siren-{resource_type}-{resource_name}", script_fname, persist=False)
        loader = getattr(resource_module, f"load_{resource_type}")
        resource = loader(*args, **kwargs)
        return resource

    densities_fname = os.path.join(abs_dir, "densities.dat")
    materials_fname = os.path.join(abs_dir, "materials.dat")

    if os.path.isfile(densities_fname) and os.path.isfile(materials_fname):
        from . import detector as _detector
        detector_model = _detector.DetectorModel()
        detector_model.LoadMaterialModel(materials_fname)
        detector_model.LoadDetectorModel(densities_fname)
        return detector_model
    raise ValueError(f"Could not find detector loading script \"{script_fname}\" or densities and materials files \"{densities_fname}\", \"{materials_fname}\"")


def load_processes(model_name, *args, **kwargs):
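For reference, a sketch of the folder layout the new folder-based search resolves against, and the manual equivalent of the densities/materials fallback in load_detector. The model and path names are hypothetical and the resource folder and siren bindings are assumptions taken from the diff; only detector.py, densities.dat, and materials.dat are fixed by the code above:

    resources/<detector-resource-folder>/IceCube/
        detector.py          # optional load_detector() script, used first if present
        IceCube-v1/
            densities.dat
            materials.dat
        IceCube-v2/          # highest version is chosen when no version is requested
            densities.dat
            materials.dat

    # Fallback used when no detector.py script exists in the resolved folder (hypothetical paths).
    from siren import detector as _detector

    detector_model = _detector.DetectorModel()
    detector_model.LoadMaterialModel("/path/to/IceCube-v1/materials.dat")
    detector_model.LoadDetectorModel("/path/to/IceCube-v1/densities.dat")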
