Skip to content

Commit

Permalink
Huggingface dataconfig (#89)
Browse files Browse the repository at this point in the history
* Upload data config alongside model to huggingface

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add function to download and return config path

* checkpoint clean data config to huggingface

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge #minor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfulu and pre-commit-ci[bot] authored Oct 31, 2023
1 parent e398906 commit b922c9d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 2 deletions.
73 changes: 71 additions & 2 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
import torch.nn.functional as F
import wandb
import yaml
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
Expand All @@ -32,13 +33,45 @@
from pvnet.optimizers import AbstractOptimizer
from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts

DATA_CONFIG_NAME = "data_config.yaml"


logger = logging.getLogger(__name__)

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)


def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):
"""Resave the data config and replace the filepaths with a placeholder.
Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
placeholder: String placeholder for data sources
"""
with open(input_path) as cfg:
config = yaml.load(cfg, Loader=yaml.FullLoader)

config["general"]["description"] = "Config for training the saved PVNet model"
config["general"]["name"] = "PVNet current"

for source in ["gsp", "nwp", "satellite", "hrvsatellite"]:
if source in config["input_data"]:
# If not empty - i.e. if used
if config["input_data"][source][f"{source}_zarr_path"] != "":
config["input_data"][source][f"{source}_zarr_path"] = f"{placeholder}.zarr"

if "pv" in config["input_data"]:
for d in config["input_data"]["pv"]["pv_files_groups"]:
d["pv_filename"] = f"{placeholder}.netcdf"
d["pv_metadata_filename"] = f"{placeholder}.csv"

with open(output_path, "w") as outfile:
yaml.dump(config, outfile, default_flow_style=False)


class PVNetModelHubMixin(PyTorchModelHubMixin):
"""
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
Expand Down Expand Up @@ -90,11 +123,42 @@ def _from_pretrained(

return model

@_deprecate_positional_args(version="0.16")
@classmethod
def get_data_config(
cls,
model_id: str,
revision: str,
cache_dir: Optional[Union[str, Path]] = None,
force_download: bool = False,
proxies: Optional[Dict] = None,
resume_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
):
"""Load data config file."""
if os.path.isdir(model_id):
print("Loading data config from local directory")
data_config_file = os.path.join(model_id, DATA_CONFIG_NAME)
else:
data_config_file = hf_hub_download(
repo_id=model_id,
filename=DATA_CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)

return data_config_file

def save_pretrained(
self,
save_directory: Union[str, Path],
config: dict,
data_config: Union[str, Path],
repo_id: Optional[str] = None,
push_to_hub: bool = False,
wandb_model_code: Optional[str] = None,
Expand All @@ -109,6 +173,8 @@ def save_pretrained(
Path to directory in which the model weights and configuration will be saved.
config (`dict`):
Model configuration specified as a key/value dictionary.
data_config (`str` or `Path`):
The path to the data config.
repo_id (`str`, *optional*):
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to
the folder name if not provided.
Expand All @@ -128,10 +194,13 @@ def save_pretrained(
# saving model weights/files
self._save_pretrained(save_directory)

# saving config
# saving model and data config
if isinstance(config, dict):
(save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=4))

# Save cleaned datapipes configuration file
make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME)

# Creating and saving model card.
card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
if card_template_path is None:
Expand Down
11 changes: 11 additions & 0 deletions pvnet/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Training"""
import os
import shutil
from typing import Optional

import hydra
Expand Down Expand Up @@ -108,6 +109,16 @@ def train(config: DictConfig) -> Optional[float]:
# Also save model config here - this makes for easy model push to huggingface
os.makedirs(callback.dirpath, exist_ok=True)
OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml")

# Similarly save the data config
data_config = config.datamodule.configuration
if data_config is None:
# Data config can be none if using presaved batches. We go to the presaved
# batches to get the data config
data_config = f"{config.datamodule.batch_dir}/data_configuration.yaml"

assert os.path.isfile(data_config), f"Data config file not found: {data_config}"
shutil.copyfile(data_config, f"{callback.dirpath}/data_config.yaml")
break

should_pretrain = False
Expand Down
5 changes: 5 additions & 0 deletions scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def push_to_huggingface(

model.load_state_dict(state_dict=checkpoint["state_dict"])

# Check for data config
data_config = f"{checkpoint_dir_path}/data_config.yaml"
assert os.path.isfile(data_config)

# Push to hub
if local_path is None:
temp_dir = tempfile.TemporaryDirectory()
Expand All @@ -69,6 +73,7 @@ def push_to_huggingface(
model.save_pretrained(
model_output_dir,
config=model_config,
data_config=data_config,
wandb_model_code=wandb_id,
push_to_hub=push_to_hub,
repo_id="openclimatefix/pvnet_v2" if push_to_hub else None,
Expand Down

0 comments on commit b922c9d

Please sign in to comment.