Skip to content

Commit

Permalink
run blacks
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Oct 16, 2024
1 parent ab4e84d commit 54616d5
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 114 deletions.
34 changes: 17 additions & 17 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@

# sentry
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
environment=os.getenv("ENVIRONMENT", "local"),
traces_sample_rate=1
dsn=os.getenv("SENTRY_DSN"), environment=os.getenv("ENVIRONMENT", "local"), traces_sample_rate=1
)

sentry_sdk.set_tag("app_name", "pvnet_app")
Expand Down Expand Up @@ -129,7 +127,7 @@ def app(
use_ecmwf_only = os.getenv("USE_ECMWF_ONLY", "false").lower() == "true"
run_extra_models = os.getenv("RUN_EXTRA_MODELS", "false").lower() == "true"
use_ocf_data_sampler = os.getenv("USE_OCF_DATA_SAMPLER", "true").lower() == "true"

logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}")
logger.info(f"Using {num_workers} workers")
Expand All @@ -138,10 +136,12 @@ def app(
logger.info(f"Running extra models: {run_extra_models}")

# load models
model_configs = get_all_models(get_ecmwf_only=use_ecmwf_only,
get_day_ahead_only=use_day_ahead_model,
run_extra_models=run_extra_models,
use_ocf_data_sampler=use_ocf_data_sampler)
model_configs = get_all_models(
get_ecmwf_only=use_ecmwf_only,
get_day_ahead_only=use_day_ahead_model,
run_extra_models=run_extra_models,
use_ocf_data_sampler=use_ocf_data_sampler,
)

logger.info(f"Using adjuster: {model_configs[0].use_adjuster}")
logger.info(f"Saving GSP sum: {model_configs[0].save_gsp_sum}")
Expand All @@ -166,12 +166,12 @@ def app(

# Get capacities from the database
logger.info("Loading capacities from the database")

db_connection = DatabaseConnection(url=os.getenv("DB_URL"), base=Base_Forecast, echo=False)
with db_connection.get_session() as session:
#  Pandas series of most recent GSP capacities
gsp_capacities = get_latest_gsp_capacities(
session=session, gsp_ids=gsp_ids, datetime_utc=t0-timedelta(days=2)
session=session, gsp_ids=gsp_ids, datetime_utc=t0 - timedelta(days=2)
)

# National capacity is needed if using summation model
Expand Down Expand Up @@ -241,21 +241,21 @@ def app(
logger.info("Creating DataLoader")

if not use_ocf_data_sampler:
logger.info('Making OCF datapipes dataloader')
logger.info("Making OCF datapipes dataloader")
# The current day ahead model uses the legacy dataloader
dataloader = get_legacy_dataloader(
config_filename=common_config_path,
t0=t0,
config_filename=common_config_path,
t0=t0,
gsp_ids=gsp_ids,
batch_size=batch_size,
num_workers=num_workers,
)

else:
logger.info('Making OCF Data Sampler dataloader')
logger.info("Making OCF Data Sampler dataloader")
dataloader = get_dataloader(
config_filename=common_config_path,
t0=t0,
config_filename=common_config_path,
t0=t0,
gsp_ids=gsp_ids,
batch_size=batch_size,
num_workers=num_workers,
Expand Down
81 changes: 39 additions & 42 deletions pvnet_app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def load_yaml_config(path: str) -> dict:

def save_yaml_config(config: dict, path: str) -> None:
"""Save config file to path"""
with open(path, 'w') as file:
with open(path, "w") as file:
yaml.dump(config, file, default_flow_style=False)


Expand All @@ -23,29 +23,29 @@ def populate_config_with_data_data_filepaths(config: dict, gsp_path: str = "") -
config: The data config
gsp_path: For lagacy usage only
"""

production_paths = {
"gsp": gsp_path,
"nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path},
"satellite": sat_path,
}
}

# Replace data sources
for source in ["gsp", "satellite"]:
if source in config["input_data"] :
if config["input_data"][source][f"{source}_zarr_path"]!="":
if source in config["input_data"]:
if config["input_data"][source][f"{source}_zarr_path"] != "":
assert source in production_paths, f"Missing production path: {source}"
config["input_data"][source][f"{source}_zarr_path"] = production_paths[source]

# NWP is nested so much be treated separately
if "nwp" in config["input_data"]:
nwp_config = config["input_data"]["nwp"]
for nwp_source in nwp_config.keys():
if nwp_config[nwp_source]["nwp_zarr_path"]!="":
if nwp_config[nwp_source]["nwp_zarr_path"] != "":
assert "nwp" in production_paths, "Missing production path: nwp"
assert nwp_source in production_paths["nwp"], f"Missing NWP path: {nwp_source}"
nwp_config[nwp_source]["nwp_zarr_path"] = production_paths["nwp"][nwp_source]

return config


Expand All @@ -55,27 +55,25 @@ def overwrite_config_dropouts(config: dict) -> dict:
Args:
config: The data config
"""

# Replace data sources
for source in ["satellite"]:
if source in config["input_data"] :
if config["input_data"][source][f"{source}_zarr_path"]!="":
if source in config["input_data"]:
if config["input_data"][source][f"{source}_zarr_path"] != "":
config["input_data"][source][f"dropout_timedeltas_minutes"] = None

# NWP is nested so much be treated separately
if "nwp" in config["input_data"]:
nwp_config = config["input_data"]["nwp"]
for nwp_source in nwp_config.keys():
if nwp_config[nwp_source]["nwp_zarr_path"]!="":
if nwp_config[nwp_source]["nwp_zarr_path"] != "":
nwp_config[nwp_source]["dropout_timedeltas_minutes"] = None

return config


def modify_data_config_for_production(
input_path: str,
output_path: str,
gsp_path: str = ""
input_path: str, output_path: str, gsp_path: str = ""
) -> None:
"""Resave the data config with the data source filepaths and dropouts overwritten
Expand All @@ -85,16 +83,16 @@ def modify_data_config_for_production(
gsp_path: For lagacy usage only
"""
config = load_yaml_config(input_path)

config = populate_config_with_data_data_filepaths(config, gsp_path=gsp_path)
config = overwrite_config_dropouts(config)

save_yaml_config(config, output_path)


def get_union_of_configs(config_paths: list[str]) -> dict:
"""Find the config which is able to run all models from a list of config paths
Note that this implementation is very limited and will not work in general unless all models
have been trained on the same batches. We do not chck example if the satellite and NWP channels
are the same in the different configs, or whether the NWP time slices are the same. Many more
Expand All @@ -103,37 +101,36 @@ def get_union_of_configs(config_paths: list[str]) -> dict:

# Load all the configs
configs = [load_yaml_config(config_path) for config_path in config_paths]

# We will ammend this config according to the entries in the other configs
common_config = configs[0]
common_config = configs[0]

for config in configs[1:]:

if "satellite" in config["input_data"]:

if "satellite" in common_config["input_data"]:
# Find the minimum satellite delay across configs

# Find the minimum satellite delay across configs
common_config["input_data"]["satellite"]["live_delay_minutes"] = min(
common_config["input_data"]["satellite"]["live_delay_minutes"],
config["input_data"]["satellite"]["live_delay_minutes"]
config["input_data"]["satellite"]["live_delay_minutes"],
)


else:
# Add satellite to common config if not there already
# Add satellite to common config if not there already
common_config["input_data"]["satellite"] = config["input_data"]["satellite"]

if "nwp" in config["input_data"]:
# Add NWP to common config if not there already

# Add NWP to common config if not there already
if "nwp" not in common_config["input_data"]:
common_config["input_data"]["nwp"] = config["input_data"]["nwp"]

else:
for nwp_key, nwp_conf in config["input_data"]["nwp"].items():
# Add different NWP sources to common config if not there already
# Add different NWP sources to common config if not there already
if nwp_key not in common_config["input_data"]["nwp"]:
common_config["input_data"]["nwp"][nwp_key] = nwp_conf
return common_config

return common_config
2 changes: 1 addition & 1 deletion pvnet_app/consts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
sat_path = "sat.zarr"
nwp_ukv_path = "nwp_ukv.zarr"
nwp_ecmwf_path = "nwp_ecmwf.zarr"
nwp_ecmwf_path = "nwp_ecmwf.zarr"
5 changes: 3 additions & 2 deletions pvnet_app/data/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def _download_nwp_data(source, destination):
fs.get(source, destination, recursive=True)


def download_all_nwp_data(download_ukv: Optional[bool] = True, download_ecmwf: Optional[bool] = True):
def download_all_nwp_data(
download_ukv: Optional[bool] = True, download_ecmwf: Optional[bool] = True
):
"""Download the NWP data"""
if download_ukv:
_download_nwp_data(os.environ["NWP_UKV_ZARR_PATH"], nwp_ukv_path)
Expand Down Expand Up @@ -163,4 +165,3 @@ def preprocess_nwp_data(use_ukv: Optional[bool] = True, use_ecmwf: Optional[bool
fix_ecmwf_data()
else:
logger.info(f"Skipping ECMWF data preprocessing")

Loading

0 comments on commit 54616d5

Please sign in to comment.