Skip to content

Commit

Permalink
Add support for passing model args via hf hub config
Browse files (browse the repository at this point in the history)
  • Loading branch information
rwightman committed Nov 19, 2023
1 parent 23e7f17 commit a604011
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
5 changes: 4 additions & 1 deletion timm/models/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def create_model(
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
if model_args:
for k, v in model_args.items():
kwargs.setdefault(k, v)
else:
model_name, pretrained_tag = split_model_name_tag(model_name)
if pretrained_tag and not pretrained_cfg:
Expand Down
37 changes: 29 additions & 8 deletions timm/models/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,9 @@ def load_model_config_from_hf(model_id: str):
if 'label_descriptions' in hf_config:
pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')

model_args = hf_config.get('model_args', {})
model_name = hf_config['architecture']
return pretrained_cfg, model_name
return pretrained_cfg, model_name, model_args


def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
Expand Down Expand Up @@ -193,19 +194,23 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
model_config: Optional[dict] = None,
model_args: Optional[dict] = None
):
model_config = model_config or {}
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
# set some values at root config level
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features)
global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)

# NOTE: these attributes are saved for informational purposes only; they do not impact the model build
hf_config['num_features'] = model_config.pop('num_features', model.num_features)
global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
if isinstance(global_pool_type, str) and global_pool_type:
hf_config['global_pool'] = global_pool_type

# Save class label info
if 'labels' in model_config:
_logger.warning(
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
Expand All @@ -225,6 +230,9 @@ def save_config_for_hf(
# maps label names -> descriptions
hf_config['label_descriptions'] = label_descriptions

if model_args:
hf_config['model_args'] = model_args

hf_config['pretrained_cfg'] = pretrained_cfg
hf_config.update(model_config)

Expand All @@ -236,6 +244,7 @@ def save_for_hf(
model,
save_directory: str,
model_config: Optional[dict] = None,
model_args: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
):
assert has_hf_hub(True)
Expand All @@ -251,11 +260,16 @@ def save_for_hf(
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

config_path = save_directory / 'config.json'
save_config_for_hf(model, config_path, model_config=model_config)
save_config_for_hf(
model,
config_path,
model_config=model_config,
model_args=model_args,
)


def push_to_hf_hub(
model,
model: torch.nn.Module,
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
Expand All @@ -264,6 +278,7 @@ def push_to_hf_hub(
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
model_args: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
):
"""
Expand Down Expand Up @@ -291,7 +306,13 @@ def push_to_hf_hub(
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
save_for_hf(
model,
tmpdir,
model_config=model_config,
model_args=model_args,
safe_serialization=safe_serialization,
)

# Add readme if it does not exist
if not has_readme:
Expand Down

0 comments on commit a604011

Please sign in to comment.