Skip to content

Commit

Permalink
remove huggingface mixing
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Jul 24, 2024
1 parent 56d4762 commit 85e8ee9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
42 changes: 28 additions & 14 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,33 +141,35 @@ def minimize_data_config(input_path, output_path, model):
yaml.dump(config, outfile, default_flow_style=False)


class PVNetModelHubMixin(PyTorchModelHubMixin):
class PVNetModelHubMixin:
"""
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
"""

@classmethod
@_deprecate_positional_args(version="0.16")
def _from_pretrained(
def from_pretrained(
cls,
*,
model_id: str,
revision: str,
cache_dir: str,
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
cache_dir: Optional[Union[str, Path]] = None,
force_download: bool = False,
proxies: Optional[Dict] = None,
resume_download: Optional[bool] = None,
local_files_only: bool = False,
token: Union[str, bool, None] = None,
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""

if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
config_file = os.path.join(model_id, CONFIG_NAME)
else:

# load model file
model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
Expand All @@ -180,11 +182,23 @@ def _from_pretrained(
local_files_only=local_files_only,
)

if "config" in model_kwargs:
logger.debug("Removing config from model_kwargs to avoid conflicts with model init.")
model_kwargs.update(model_kwargs.pop("config"))
# load config file
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)

with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)

model = hydra.utils.instantiate(model_kwargs)
model = hydra.utils.instantiate(config)

state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions tests/models/multimodal/test_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

def test_from_pretrained():
model_name = "openclimatefix/pvnet_uk_region"
model_version = "aa73cdafd1db8df3c8b7f5ecfdb160989e7639ac"
model_version = "92266cd9040c590a9e90ee33eafd0e7b92548be8"

_ = BaseModel.from_pretrained(
model_name,
model_id=model_name,
revision=model_version,
)

0 comments on commit 85e8ee9

Please sign in to comment.