From fa4a1e597f61e05883e4e783c141c0d6dc9755be Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 26 Aug 2024 22:28:01 -0700 Subject: [PATCH] Better all res resolution for bulk runner --- bulk_runner.py | 54 +++++++++++++++++++++++++++------------- timm/models/__init__.py | 3 ++- timm/models/_registry.py | 11 +++++++- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/bulk_runner.py b/bulk_runner.py index e078b8efca..286059c2f3 100755 --- a/bulk_runner.py +++ b/bulk_runner.py @@ -21,7 +21,7 @@ from typing import Callable, List, Tuple, Union -from timm.models import is_model, list_models, get_pretrained_cfg +from timm.models import is_model, list_models, get_pretrained_cfg, get_arch_pretrained_cfgs parser = argparse.ArgumentParser(description='Per-model process launcher') @@ -98,23 +98,44 @@ def _get_model_cfgs( num_classes=None, expand_train_test=False, include_crop=True, + expand_arch=False, ): - model_cfgs = [] - for n in model_names: - pt_cfg = get_pretrained_cfg(n) - if num_classes is not None and getattr(pt_cfg, 'num_classes', 0) != num_classes: - continue - model_cfgs.append((n, pt_cfg.input_size[-1], pt_cfg.crop_pct)) - if expand_train_test and pt_cfg.test_input_size is not None: - if pt_cfg.test_crop_pct is not None: - model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.test_crop_pct)) + model_cfgs = set() + + for name in model_names: + if expand_arch: + pt_cfgs = get_arch_pretrained_cfgs(name).values() + else: + pt_cfg = get_pretrained_cfg(name) + pt_cfgs = [pt_cfg] if pt_cfg is not None else [] + + for cfg in pt_cfgs: + if cfg.input_size is None: + continue + if num_classes is not None and getattr(cfg, 'num_classes', 0) != num_classes: + continue + + # Add main configuration + size = cfg.input_size[-1] + if include_crop: + model_cfgs.add((name, size, cfg.crop_pct)) else: - model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.crop_pct)) + model_cfgs.add((name, size)) + + # Add test configuration if required + if expand_train_test and cfg.test_input_size is not None: + test_size = cfg.test_input_size[-1] + if include_crop: + test_crop = cfg.test_crop_pct or cfg.crop_pct + model_cfgs.add((name, test_size, test_crop)) + else: + model_cfgs.add((name, test_size)) + + # Format the output if include_crop: - model_cfgs = [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)] + return [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)] else: - model_cfgs = [(n, {'img-size': r}) for n, r, cp in sorted(model_cfgs)] - return model_cfgs + return [(n, {'img-size': r}) for n, r in sorted(model_cfgs)] def main(): @@ -132,7 +153,7 @@ def main(): model_cfgs = _get_model_cfgs(model_names, num_classes=1000, expand_train_test=True) elif args.model_list == 'all_res': model_names = list_models() - model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False) + model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False, expand_arch=True) elif not is_model(args.model_list): # model name doesn't exist, try as wildcard filter model_names = list_models(args.model_list) @@ -140,9 +161,8 @@ def main(): if not model_cfgs and os.path.exists(args.model_list): with open(args.model_list) as f: - model_cfgs = [] model_names = [line.rstrip() for line in f] - _get_model_cfgs( + model_cfgs = _get_model_cfgs( model_names, #num_classes=1000, expand_train_test=True, diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 5e723724cc..60fd483c44 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -95,4 +95,5 @@ from ._prune import adapt_model_from_string from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \ register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \ - is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value + is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value, \ + get_arch_pretrained_cfgs diff --git a/timm/models/_registry.py b/timm/models/_registry.py index fde8bac787..09d73b551b 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -16,7 +16,7 @@ __all__ = [ 'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs', 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', - 'get_pretrained_cfg_value', 'is_model_pretrained' + 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_pretrained_cfgs_for_arch' ] _module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module @@ -341,3 +341,12 @@ def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]: """ cfg = get_pretrained_cfg(model_name, allow_unregistered=False) return getattr(cfg, cfg_key, None) + + +def get_arch_pretrained_cfgs(model_name: str) -> Dict[str, PretrainedCfg]: + """ Get all pretrained cfgs for a given architecture. + """ + arch_name, _ = split_model_name_tag(model_name) + model_names = _model_with_tags[arch_name] + cfgs = {m: _model_pretrained_cfgs[m] for m in model_names} + return cfgs