diff --git a/timm/models/_factory.py b/timm/models/_factory.py
index 72b4b16cc7..bff15b9a64 100644
--- a/timm/models/_factory.py
+++ b/timm/models/_factory.py
@@ -99,7 +99,10 @@ def create_model(
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
         # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name = load_model_config_from_hf(model_name)
+        pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+        if model_args:
+            for k, v in model_args.items():
+                kwargs.setdefault(k, v)
     else:
         model_name, pretrained_tag = split_model_name_tag(model_name)
         if pretrained_tag and not pretrained_cfg:
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 720a50914d..53f8e85521 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -164,8 +164,9 @@
     if 'label_descriptions' in hf_config:
         pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
 
+    model_args = hf_config.get('model_args', {})
     model_name = hf_config['architecture']
-    return pretrained_cfg, model_name
+    return pretrained_cfg, model_name, model_args
 
 
 def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
@@ -193,19 +194,23 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
 def save_config_for_hf(
         model,
         config_path: str,
-        model_config: Optional[dict] = None
+        model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None
 ):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
     # set some values at root config level
     hf_config['architecture'] = pretrained_cfg.pop('architecture')
-    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
-    hf_config['num_features'] = model_config.get('num_features', model.num_features)
-    global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
+
+    # NOTE these attr saved for informational purposes, do not impact model build
+    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
+    global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
     if isinstance(global_pool_type, str) and global_pool_type:
         hf_config['global_pool'] = global_pool_type
 
+    # Save class label info
     if 'labels' in model_config:
         _logger.warning(
             "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
@@ -225,6 +230,9 @@
         # maps label names -> descriptions
         hf_config['label_descriptions'] = label_descriptions
 
+    if model_args:
+        hf_config['model_args'] = model_args
+
     hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)
 
@@ -236,6 +244,7 @@ def save_for_hf(
         model,
         save_directory: str,
         model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None,
         safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     assert has_hf_hub(True)
@@ -251,11 +260,16 @@
         torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
 
     config_path = save_directory / 'config.json'
-    save_config_for_hf(model, config_path, model_config=model_config)
+    save_config_for_hf(
+        model,
+        config_path,
+        model_config=model_config,
+        model_args=model_args,
+    )
 
 
 def push_to_hf_hub(
-    model,
+    model: torch.nn.Module,
     repo_id: str,
     commit_message: str = 'Add model',
     token: Optional[str] = None,
@@ -264,6 +278,7 @@
     create_pr: bool = False,
     model_config: Optional[dict] = None,
     model_card: Optional[dict] = None,
+    model_args: Optional[dict] = None,
    safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     """
@@ -291,7 +306,13 @@
     # Dump model and push to Hub
     with TemporaryDirectory() as tmpdir:
         # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
+        save_for_hf(
+            model,
+            tmpdir,
+            model_config=model_config,
+            model_args=model_args,
+            safe_serialization=safe_serialization,
+        )
 
         # Add readme if it does not exist
         if not has_readme: