From caa8e12d430cab771e2e0715265cfd9054486462 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 4 Sep 2024 11:34:05 +0100 Subject: [PATCH] upgrade to visualization - add limit examples - add dir - prints straight to report --- ocf_datapipes/batch/visualise.py | 461 ++++++++++++++++--------------- 1 file changed, 244 insertions(+), 217 deletions(-) diff --git a/ocf_datapipes/batch/visualise.py b/ocf_datapipes/batch/visualise.py index 63a8aae4..578f8e7a 100644 --- a/ocf_datapipes/batch/visualise.py +++ b/ocf_datapipes/batch/visualise.py @@ -3,6 +3,7 @@ This is a bit of a working progress, but the idea is to visualize the batch in a markdown file. """ +import os import pandas as pd import plotly.graph_objects as go import torch @@ -10,155 +11,259 @@ from ocf_datapipes.batch import BatchKey, NumpyBatch, NWPBatchKey -def visualize_batch(batch: NumpyBatch): +def visualize_batch(batch: NumpyBatch, folder=".", output_file="report.md", limit_examples=None): """Visualize the batch in a markdown file""" - # Wind - print("# Batch visualization") - - print("## Wind \n") - keys = [ - BatchKey.wind, - BatchKey.wind_t0_idx, - BatchKey.wind_time_utc, - BatchKey.wind_id, - BatchKey.wind_observed_capacity_mwp, - BatchKey.wind_nominal_capacity_mwp, - BatchKey.wind_time_utc, - BatchKey.wind_latitude, - BatchKey.wind_longitude, - BatchKey.wind_solar_azimuth, - BatchKey.wind_solar_elevation, - ] - for key in keys: - if key in batch.keys(): - print("\n") - value = batch[key] - if isinstance(value, torch.Tensor): - print(f"{key} {value.shape=}") - print(f"Max {value.max()}") - print(f"Min {value.min()}") - elif isinstance(value, int): - print(f"{key} {value}") - else: - print(f"{key} {value}") - - print("## GSP \n") - keys = [ - BatchKey.gsp, - BatchKey.gsp_id, - BatchKey.gsp_time_utc, - BatchKey.gsp_time_utc_fourier, - BatchKey.gsp_x_osgb, - BatchKey.gsp_x_osgb_fourier, - BatchKey.gsp_y_osgb, - BatchKey.gsp_y_osgb_fourier, - BatchKey.gsp_t0_idx, - BatchKey.gsp_effective_capacity_mwp, - BatchKey.gsp_nominal_capacity_mwp, - BatchKey.gsp_solar_azimuth, - BatchKey.gsp_solar_elevation, - ] - for key in keys: - if key in batch.keys(): - print("\n") - print(f"### {key.name}") - value = batch[key] - if key.name == "gsp": - # plot gsp data - for b in range(value.shape[0]): - fig = go.Figure() - gsp_data = value[b, :, 0] - time = pd.to_datetime(batch[BatchKey.gsp_time_utc][b], unit="s") - fig.add_trace(go.Scatter(x=time, y=gsp_data, mode="lines", name="GSP")) - fig.update_layout( - title=f"GSP - example {b}", xaxis_title="Time", yaxis_title="Value" + # create dir if it does not exist + for d in [folder, f"{folder}/gsp", f"{folder}/nwp", f"{folder}/satellite"]: + if not os.path.exists(d): + os.makedirs(d) + + with open(f"{folder}/{output_file}", "a") as f: + # Wind + print("# Batch visualization", file=f) + + print("## Wind \n", file=f) + keys = [ + BatchKey.wind, + BatchKey.wind_t0_idx, + BatchKey.wind_time_utc, + BatchKey.wind_id, + BatchKey.wind_observed_capacity_mwp, + BatchKey.wind_nominal_capacity_mwp, + BatchKey.wind_time_utc, + BatchKey.wind_latitude, + BatchKey.wind_longitude, + BatchKey.wind_solar_azimuth, + BatchKey.wind_solar_elevation, + ] + for key in keys: + if key in batch.keys(): + print("\n", file=f) + value = batch[key] + if isinstance(value, torch.Tensor): + print(f"{key} {value.shape=}", file=f) + print(f"Max {value.max()}", file=f) + print(f"Min {value.min()}", file=f) + elif isinstance(value, int): + print(f"{key} {value}", file=f) + else: + print(f"{key} {value}", file=f) + + print("## GSP \n", file=f) + keys = [ + BatchKey.gsp, + BatchKey.gsp_id, + BatchKey.gsp_time_utc, + BatchKey.gsp_time_utc_fourier, + BatchKey.gsp_x_osgb, + BatchKey.gsp_x_osgb_fourier, + BatchKey.gsp_y_osgb, + BatchKey.gsp_y_osgb_fourier, + BatchKey.gsp_t0_idx, + BatchKey.gsp_effective_capacity_mwp, + BatchKey.gsp_nominal_capacity_mwp, + BatchKey.gsp_solar_azimuth, + BatchKey.gsp_solar_elevation, + ] + for key in keys: + if key in batch.keys(): + print("\n", file=f) + print(f"### {key.name}", file=f) + value = batch[key] + if key.name == "gsp": + # plot gsp data + n_examples = value.shape[0] + if limit_examples is not None: + n_examples = min(n_examples, limit_examples) + + for b in range(n_examples): + fig = go.Figure() + gsp_data = value[b, :, 0] + time = pd.to_datetime(batch[BatchKey.gsp_time_utc][b], unit="s") + fig.add_trace(go.Scatter(x=time, y=gsp_data, mode="lines", name="GSP")) + fig.update_layout( + title=f"GSP - example {b}", xaxis_title="Time", yaxis_title="Value" + ) + # fig.show(renderer='browser') + name = f"gsp/gsp_{b}.png" + fig.write_image(f"{folder}/{name}") + print(f"![](./{name})", file=f) + print("\n", file=f) + + elif isinstance(value, torch.Tensor): + print(f"shape {value.shape=}", file=f) + print(f"Max {value.max():.2f}", file=f) + print(f"Min {value.min():.2f}", file=f) + elif isinstance(value, int): + print(f"{value}", file=f) + else: + print(f"{value}", file=f) + + # TODO plot solar azimuth and elevation + + # NWP + print("## NWP \n", file=f) + + keys = [ + NWPBatchKey.nwp, + NWPBatchKey.nwp_target_time_utc, + NWPBatchKey.nwp_channel_names, + NWPBatchKey.nwp_step, + NWPBatchKey.nwp_t0_idx, + NWPBatchKey.nwp_init_time_utc, + ] + + nwp = batch[BatchKey.nwp] + + nwp_providers = nwp.keys() + for provider in nwp_providers: + print("\n", file=f) + print(f"### Provider {provider}", file=f) + nwp_provider = nwp[provider] + + # plot nwp main data + nwp_data = nwp_provider[NWPBatchKey.nwp] + # average of lat and lon + nwp_data = nwp_data.mean(dim=(3, 4)) + + n_examples = nwp_data.shape[0] + if limit_examples is not None: + n_examples = min(n_examples, limit_examples) + + for b in range(n_examples): + + fig = go.Figure() + for i in range(len(nwp_provider[NWPBatchKey.nwp_channel_names])): + channel = nwp_provider[NWPBatchKey.nwp_channel_names][i] + nwp_data_one_channel = nwp_data[b, :, i] + time = nwp_provider[NWPBatchKey.nwp_target_time_utc][b] + time = pd.to_datetime(time, unit="s") + fig.add_trace( + go.Scatter(x=time, y=nwp_data_one_channel, mode="lines", name=channel) ) - # fig.show(renderer='browser') - name = f"gsp_{b}.png" - fig.write_image(name) - print(f"![]({name})") - print("\n") - elif isinstance(value, torch.Tensor): - print(f"shape {value.shape=}") - print(f"Max {value.max():.2f}") - print(f"Min {value.min():.2f}") - elif isinstance(value, int): - print(f"{value}") - else: - print(f"{value}") - - # TODO plot solar azimuth and elevation - - # NWP - print("## NWP \n") - - keys = [ - NWPBatchKey.nwp, - NWPBatchKey.nwp_target_time_utc, - NWPBatchKey.nwp_channel_names, - NWPBatchKey.nwp_step, - NWPBatchKey.nwp_t0_idx, - NWPBatchKey.nwp_init_time_utc, - ] - - nwp = batch[BatchKey.nwp] - - nwp_providers = nwp.keys() - for provider in nwp_providers: - print("\n") - print(f"### Provider {provider}") - nwp_provider = nwp[provider] - - # plot nwp main data - nwp_data = nwp_provider[NWPBatchKey.nwp] - # average of lat and lon - nwp_data = nwp_data.mean(dim=(3, 4)) - - for b in range(nwp_data.shape[0]): - - fig = go.Figure() - for i in range(len(nwp_provider[NWPBatchKey.nwp_channel_names])): - channel = nwp_provider[NWPBatchKey.nwp_channel_names][i] - nwp_data_one_channel = nwp_data[b, :, i] - time = nwp_provider[NWPBatchKey.nwp_target_time_utc][b] - time = pd.to_datetime(time, unit="s") - fig.add_trace( - go.Scatter(x=time, y=nwp_data_one_channel, mode="lines", name=channel) + fig.update_layout( + title=f"{provider} NWP - example {b}", xaxis_title="Time", yaxis_title="Value" ) - - fig.update_layout( - title=f"{provider} NWP - example {b}", xaxis_title="Time", yaxis_title="Value" - ) - # fig.show(renderer='browser') - name = f"{provider}_nwp_{b}.png" - fig.write_image(name) - print(f"![]({name})") - print("\n") + # fig.show(renderer='browser') + name = f"nwp/{provider}_nwp_{b}.png" + fig.write_image(f"{folder}/{name}") + print(f"![](./{name})", file=f) + print("\n", file=f) + + for key in keys: + print("\n", file=f) + print(f"#### {key.name}", file=f) + value = nwp_provider[key] + + if "time" in key.name: + + # make a table with example, shape, max, min + print("| Example | Shape | Max | Min |", file=f) + print("| --- | --- | --- | --- |", file=f) + + for example_id in range(n_examples): + value_ts = pd.to_datetime(value[example_id], unit="s") + print( + f"| {example_id} | {len(value_ts)} | {value_ts.max()} | {value_ts.min()} |", + file=f, + ) + + elif "channel" in key.name: + + # create a table with the channel names with max, min, mean and std + print("| Channel | Max | Min | Mean | Std |", file=f) + print("| --- | --- | --- | --- | --- |", file=f) + for i in range(len(value)): + channel = value[i] + data = nwp_data[:, :, i] + print( + f"| {channel} " + f"| {data.max().item():.2f} " + f"| {data.min().item():.2f} " + f"| {data.mean().item():.2f} " + f"| {data.std().item():.2f} |", + file=f, + ) + + print(f"Shape={value.shape}", file=f) + + elif isinstance(value, torch.Tensor): + print(f"Shape {value.shape=}", file=f) + print(f"Max {value.max():.2f}", file=f) + print(f"Min {value.min():.2f}", file=f) + elif isinstance(value, int): + print(f"{value}", file=f) + else: + print(f"{value}", file=f) + + # Satellite + print("## Satellite \n", file=f) + keys = [ + BatchKey.satellite_actual, + BatchKey.satellite_t0_idx, + BatchKey.satellite_time_utc, + BatchKey.satellite_time_utc, + BatchKey.satellite_x_geostationary, + BatchKey.satellite_y_geostationary, + ] for key in keys: - print("\n") - print(f"#### {key.name}") - value = nwp_provider[key] - if "time" in key.name: + print("\n", file=f) + print(f"#### {key.name}", file=f) + value = batch[key] + + if "satellite_actual" in key.name: + + print(value.shape, file=f) + + # average of lat and lon + value = value.mean(dim=(3, 4)) + + n_examples = value.shape[0] + if limit_examples is not None: + n_examples = min(n_examples, limit_examples) + + for b in range(n_examples): + + fig = go.Figure() + for i in range(value.shape[2]): + satellite_data_one_channel = value[b, :, i] + time = batch[BatchKey.satellite_time_utc][b] + time = pd.to_datetime(time, unit="s") + fig.add_trace( + go.Scatter(x=time, y=satellite_data_one_channel, mode="lines") + ) + + fig.update_layout( + title=f"Satellite - example {b}", xaxis_title="Time", yaxis_title="Value" + ) + # fig.show(renderer='browser') + name = f"satellite/satellite_{b}.png" + fig.write_image(f"{folder}/{name}") + print(f"![](./{name})", file=f) + print("\n", file=f) + + elif "time" in key.name: # make a table with example, shape, max, min - print("| Example | Shape | Max | Min |") - print("| --- | --- | --- | --- |") + print("| Example | Shape | Max | Min |", file=f) + print("| --- | --- | --- | --- |", file=f) - for example_id in range(value.shape[0]): + for example_id in range(n_examples): value_ts = pd.to_datetime(value[example_id], unit="s") print( - f"| {example_id} | {len(value_ts)} | {value_ts.max()} | {value_ts.min()} |" + f"| {example_id} | {len(value_ts)} | {value_ts.max()} | {value_ts.min()} |", + file=f, ) elif "channel" in key.name: # create a table with the channel names with max, min, mean and std - print("| Channel | Max | Min | Mean | Std |") - print("| --- | --- | --- | --- | --- |") + print("| Channel | Max | Min | Mean | Std |", file=f) + print("| --- | --- | --- | --- | --- |", file=f) for i in range(len(value)): channel = value[i] data = nwp_data[:, :, i] @@ -167,98 +272,20 @@ def visualize_batch(batch: NumpyBatch): f"| {data.max().item():.2f} " f"| {data.min().item():.2f} " f"| {data.mean().item():.2f} " - f"| {data.std().item():.2f} |" + f"| {data.std().item():.2f} |", + file=f, ) - print(f"Shape={value.shape}") + print(f"Shape={value.shape}", file=f) elif isinstance(value, torch.Tensor): - print(f"Shape {value.shape=}") - print(f"Max {value.max():.2f}") - print(f"Min {value.min():.2f}") + print(f"Shape {value.shape=}", file=f) + print(f"Max {value.max():.2f}", file=f) + print(f"Min {value.min():.2f}", file=f) elif isinstance(value, int): - print(f"{value}") + print(f"{value}", file=f) else: - print(f"{value}") - - # Satellite - print("## Satellite \n") - keys = [ - BatchKey.satellite_actual, - BatchKey.satellite_t0_idx, - BatchKey.satellite_time_utc, - BatchKey.satellite_time_utc, - BatchKey.satellite_x_geostationary, - BatchKey.satellite_y_geostationary, - ] - - for key in keys: - - print("\n") - print(f"#### {key.name}") - value = batch[key] - - if "satellite_actual" in key.name: - - print(value.shape) - - # average of lat and lon - value = value.mean(dim=(3, 4)) - - for b in range(value.shape[0]): - - fig = go.Figure() - for i in range(value.shape[2]): - satellite_data_one_channel = value[b, :, i] - time = batch[BatchKey.satellite_time_utc][b] - time = pd.to_datetime(time, unit="s") - fig.add_trace(go.Scatter(x=time, y=satellite_data_one_channel, mode="lines")) - - fig.update_layout( - title=f"Satellite - example {b}", xaxis_title="Time", yaxis_title="Value" - ) - # fig.show(renderer='browser') - name = f"satellite_{b}.png" - fig.write_image(name) - print(f"![]({name})") - print("\n") - - elif "time" in key.name: - - # make a table with example, shape, max, min - print("| Example | Shape | Max | Min |") - print("| --- | --- | --- | --- |") - - for example_id in range(value.shape[0]): - value_ts = pd.to_datetime(value[example_id], unit="s") - print(f"| {example_id} | {len(value_ts)} | {value_ts.max()} | {value_ts.min()} |") - - elif "channel" in key.name: - - # create a table with the channel names with max, min, mean and std - print("| Channel | Max | Min | Mean | Std |") - print("| --- | --- | --- | --- | --- |") - for i in range(len(value)): - channel = value[i] - data = nwp_data[:, :, i] - print( - f"| {channel} " - f"| {data.max().item():.2f} " - f"| {data.min().item():.2f} " - f"| {data.mean().item():.2f} " - f"| {data.std().item():.2f} |" - ) - - print(f"Shape={value.shape}") - - elif isinstance(value, torch.Tensor): - print(f"Shape {value.shape=}") - print(f"Max {value.max():.2f}") - print(f"Min {value.min():.2f}") - elif isinstance(value, int): - print(f"{value}") - else: - print(f"{value}") + print(f"{value}", file=f) # For example you can run it like this