diff --git a/train/run.py b/train/run.py
index ac4e92c..657db5b 100644
--- a/train/run.py
+++ b/train/run.py
@@ -1,3 +1,4 @@
+import torch.utils.data.dataset
 from dgmr import DGMR
 from datasets import load_dataset
 from torch.utils.data import DataLoader
@@ -7,8 +8,9 @@ from pytorch_lightning.callbacks import ModelCheckpoint
 #import wandb
 #wandb.init(project="dgmr")
+import os
 from pathlib import Path
-
+import tensorflow as tf
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import LoggerCollection, WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
@@ -104,6 +106,136 @@ def on_train_epoch_end(self, trainer, pl_module):
         experiment.log_artifact(ckpts)
 
 
+_FEATURES = {name: tf.io.FixedLenFeature([], dtype)
+             for name, dtype in [
+                 ("radar", tf.string), ("sample_prob", tf.float32),
+                 ("osgb_extent_top", tf.int64), ("osgb_extent_left", tf.int64),
+                 ("osgb_extent_bottom", tf.int64), ("osgb_extent_right", tf.int64),
+                 ("end_time_timestamp", tf.int64),
+             ]}
+
+_SHAPE_BY_SPLIT_VARIANT = {
+    ("train", "random_crops_256"): (24, 256, 256, 1),
+    ("valid", "subsampled_tiles_256_20min_stride"): (24, 256, 256, 1),
+    ("test", "full_frame_20min_stride"): (24, 1536, 1280, 1),
+    ("test", "subsampled_overlapping_padded_tiles_512_20min_stride"): (24, 512, 512, 1),
+}
+
+_MM_PER_HOUR_INCREMENT = 1/32.
+_MAX_MM_PER_HOUR = 128.
+_INT16_MASK_VALUE = -1
+
+
+def parse_and_preprocess_row(row, split, variant):
+    result = tf.io.parse_example(row, _FEATURES)
+    shape = _SHAPE_BY_SPLIT_VARIANT[(split, variant)]
+    radar_bytes = result.pop("radar")
+    radar_int16 = tf.reshape(tf.io.decode_raw(radar_bytes, tf.int16), shape)
+    mask = tf.not_equal(radar_int16, _INT16_MASK_VALUE)
+    radar = tf.cast(radar_int16, tf.float32) * _MM_PER_HOUR_INCREMENT
+    radar = tf.clip_by_value(
+        radar, _INT16_MASK_VALUE * _MM_PER_HOUR_INCREMENT, _MAX_MM_PER_HOUR)
+    result["radar_frames"] = radar
+    result["radar_mask"] = mask
+    return result
+
+
+DATASET_ROOT_DIR = "gs://dm-nowcasting-example-data/datasets/nowcasting_open_source_osgb/nimrod_osgb_1000m_yearly_splits/radar/20200718"
+
+
+def reader(split="train", variant="random_crops_256", shuffle_files=False):
+    """Reader for open-source nowcasting datasets.
+
+    Args:
+      split: Which yearly split of the dataset to use:
+        "train": Data from 2016 - 2018, excluding the first day of each month.
+        "valid": Data from 2016 - 2018, only the first day of the month.
+        "test": Data from 2019.
+      variant: Which variant to use. The available variants depend on the split:
+        "random_crops_256": Available for the training split. 24x256x256 pixel
+          crops, sampled with a bias towards crops containing rainfall. Crops
+          may be sampled at any spatial and temporal offset, so some crops
+          overlap.
+        "subsampled_tiles_256_20min_stride": Available for the validation set.
+          Non-spatially-overlapping 24x256x256 pixel crops, subsampled from a
+          regular spatial grid with stride 256x256 pixels, and a temporal
+          stride of 20mins (4 timesteps at 5 minute resolution). Sampling
+          favours crops containing rainfall.
+        "subsampled_overlapping_padded_tiles_512_20min_stride": Available for
+          the test set. Overlapping 24x512x512 pixel crops, subsampled from a
+          regular spatial grid with stride 64x64 pixels, and a temporal stride
+          of 20mins (4 timesteps at 5 minute resolution). Subsampling favours
+          crops containing rainfall. These crops include extra spatial context
+          for a fairer evaluation of the PySTEPS baseline, which benefits from
+          this extra context. Our other models only use the central 256x256
+          pixels of these crops.
+        "full_frame_20min_stride": Available for the test set. Includes full
+          frames at 24x1536x1280 pixels, every 20 minutes with no additional
+          subsampling.
+      shuffle_files: Whether to shuffle the shard files of the dataset
+        non-deterministically before interleaving them. Recommended for the
+        training set to improve mixing and read performance (since
+        non-deterministic parallel interleave is then enabled).
+
+    Returns:
+      A tf.data.Dataset whose rows are dicts with the following keys:
+
+      "radar_frames": Shape TxHxWx1, float32. Radar-based estimates of
+        ground-level precipitation, in units of mm/hr. Pixels which are masked
+        will take on a value of -1/32 and should be excluded from use as
+        evaluation targets. The coordinate reference system used is OSGB36,
+        with a spatial resolution of 1000 OSGB36 coordinate units
+        (approximately equal to 1km). The temporal resolution is 5 minutes.
+      "radar_mask": Shape TxHxWx1, bool. A binary mask which is False for
+        pixels that are unobserved / unable to be inferred from radar
+        measurements (e.g. due to being too far from a radar site). This mask
+        is usually static over time, but occasionally a whole radar site will
+        drop in or out, resulting in large changes to the mask; more localised
+        changes can happen too.
+      "sample_prob": Scalar float. The probability with which the row was
+        sampled from the overall pool available for sampling, as described
+        above under 'variant'. We use importance weights proportional to
+        1/sample_prob when computing metrics on the validation and test sets,
+        to reduce bias due to the subsampling.
+      "end_time_timestamp": Scalar int64. A timestamp for the final frame in
+        the example, in seconds since the UNIX epoch (1970-01-01 00:00:00 UTC).
+      "osgb_extent_left", "osgb_extent_right", "osgb_extent_top",
+      "osgb_extent_bottom": Scalar int64s. Spatial extent for the crop in the
+        OSGB36 coordinate reference system.
+    """
+    shards_glob = os.path.join(DATASET_ROOT_DIR, split, variant, "*.tfrecord.gz")
+    shard_paths = tf.io.gfile.glob(shards_glob)
+    shards_dataset = tf.data.Dataset.from_tensor_slices(shard_paths)
+    if shuffle_files:
+        shards_dataset = shards_dataset.shuffle(buffer_size=len(shard_paths))
+    return (
+        shards_dataset
+        .interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"),
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                    deterministic=not shuffle_files)
+        .map(lambda row: parse_and_preprocess_row(row, split, variant),
+             num_parallel_calls=tf.data.AUTOTUNE)
+        # Do your own subsequent repeat, shuffle, batch, prefetch etc. as required.
+    )
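+
+
+# Illustrative sketch only, not used by the training run below: one plausible
+# way to act on the comment above and build a batched tf.data pipeline from
+# reader(). The buffer and batch sizes here are assumptions, not tuned values.
+def _example_batched_pipeline(batch_size=16):
+    ds = reader(split="train", variant="random_crops_256", shuffle_files=True)
+    return ds.repeat().shuffle(buffer_size=64).batch(batch_size).prefetch(tf.data.AUTOTUNE)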
+
+
+NUM_INPUT_FRAMES = 4
+NUM_TARGET_FRAMES = 18
+
+
+def extract_input_and_target_frames(radar_frames):
+    """Extract input and target frames from a dataset row's radar_frames."""
+    # We align our targets to the end of the window, and inputs precede targets.
+    input_frames = radar_frames[-NUM_TARGET_FRAMES - NUM_INPUT_FRAMES:-NUM_TARGET_FRAMES]
+    target_frames = radar_frames[-NUM_TARGET_FRAMES:]
+    return input_frames, target_frames
+
+
+class TFDataset(torch.utils.data.dataset.IterableDataset):
+    def __init__(self, split, variant):
+        super().__init__()
+        # Keep the tf.data.Dataset itself rather than a one-shot iterator, so
+        # each epoch re-iterates the split from the start.
+        self.reader = reader(split, variant)
+
+    def __iter__(self):
+        # Yield every row in the split, not just the first one, and convert
+        # the TF tensors to numpy so the DataLoader's default collate_fn can
+        # stack them into torch tensors.
+        for row in self.reader:
+            input_frames, target_frames = extract_input_and_target_frames(
+                row["radar_frames"])
+            yield input_frames.numpy(), target_frames.numpy()
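+
+
+# Worked example (illustrative): with the 24-frame windows above, the slicing
+# in extract_input_and_target_frames takes frames [2:6] as the 4 inputs and
+# frames [6:24] as the 18 targets; the first two frames of each window go
+# unused. A hedged shape check, assuming GCS access to the dataset:
+#   ds = TFDataset(split="valid", variant="subsampled_tiles_256_20min_stride")
+#   x, y = next(iter(DataLoader(ds, batch_size=1)))
+#   assert x.shape == (1, 4, 256, 256, 1) and y.shape == (1, 18, 256, 256, 1)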
+
+
 class DGMRDataModule(LightningDataModule):
     """
@@ -144,18 +276,20 @@ def __init__(
     )
 
     def train_dataloader(self):
-        train_dataset = load_dataset("openclimatefix/nimrod-uk-1km", "sample", split="train", streaming=False)
-        train_dataset.set_format(
-            type="torch", columns=["radar_frames", "radar_mask", "sample_prob"]
-        )
+        train_dataset = TFDataset(split="train", variant="random_crops_256")
+        #train_dataset = load_dataset("openclimatefix/nimrod-uk-1km", "sample", split="train", streaming=False)
+        #train_dataset.set_format(
+        #    type="torch", columns=["radar_frames", "radar_mask", "sample_prob"]
+        #)
         dataloader = DataLoader(train_dataset, batch_size=1)
         return dataloader
 
     def val_dataloader(self):
-        train_dataset = load_dataset("openclimatefix/nimrod-uk-1km", "sample", split="val", streaming=False)
-        train_dataset.set_format(
-            type="torch", columns=["radar_frames", "radar_mask", "sample_prob"]
-        )
-        dataloader = DataLoader(train_dataset, batch_size=1)
+        #train_dataset = load_dataset("openclimatefix/nimrod-uk-1km", "sample", split="val", streaming=False)
+        #train_dataset.set_format(
+        #    type="torch", columns=["radar_frames", "radar_mask", "sample_prob"]
+        #)
+        val_dataset = TFDataset(split="valid", variant="subsampled_tiles_256_20min_stride")
+        dataloader = DataLoader(val_dataset, batch_size=1)
         return dataloader
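+
+    def test_dataloader(self):
+        # Hedged sketch added for illustration, not exercised by the current
+        # Trainer config: a test-split loader using one of the test variants
+        # documented in reader()'s docstring.
+        test_dataset = TFDataset(split="test", variant="subsampled_overlapping_padded_tiles_512_20min_stride")
+        return DataLoader(test_dataset, batch_size=1)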