
Commit

Dev icenet-ai#186: mapping of data through the entire chain to plotting of outputs - some work still required
JimCircadian committed Jul 25, 2024
1 parent 95d0616 commit b46dd0e
Showing 9 changed files with 177 additions and 111 deletions.
11 changes: 5 additions & 6 deletions icenet/data/datasets/utils.py
@@ -58,7 +58,7 @@ class SplittingMixin:
>>> split_dataset = SplittingMixin()
# Add file paths to the train, validation, and test datasets
>>> split_dataset.add_records(base_path="./network_datasets/notebook_data/", hemi="south")
>>> split_dataset.add_records(base_path="./network_datasets/notebook_data/")
"""
_batch_size: int
_dtype: object
@@ -71,22 +71,21 @@ class SplittingMixin:
test_fns = []
val_fns = []

def add_records(self, base_path: str, hemi: str) -> None:
def add_records(self, base_path: str) -> None:
"""Add list of paths to train, val, test *.tfrecord(s) to relevant instance attributes.
Add sorted list of file paths to train, validation, and test datasets in SplittingMixin.
Args:
base_path (str): The base path where the datasets are located.
hemi (str): The hemisphere the datasets correspond to.
Returns:
None. Updates `self.train_fns`, `self.val_fns`, `self.test_fns` with list
of *.tfrecord files.
"""
train_path = os.path.join(base_path, hemi, "train")
val_path = os.path.join(base_path, hemi, "val")
test_path = os.path.join(base_path, hemi, "test")
train_path = os.path.join(base_path, "train")
val_path = os.path.join(base_path, "val")
test_path = os.path.join(base_path, "test")

logging.info("Training dataset path: {}".format(train_path))
self.train_fns += sorted(glob.glob("{}/*.tfrecord".format(train_path)))
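Since the hemisphere is now encoded in base_path by the caller, the lookup reduces to simple joins. A minimal sketch of the new behaviour (standalone illustration, not repository code):

    import glob
    import os

    base_path = "./network_datasets/notebook_data/"  # hemisphere already part of the path
    train_path = os.path.join(base_path, "train")    # was: os.path.join(base_path, hemi, "train")
    train_fns = sorted(glob.glob("{}/*.tfrecord".format(train_path)))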
7 changes: 4 additions & 3 deletions icenet/data/loader.py
@@ -48,15 +48,15 @@ def create_get_args() -> object:
ap.add_argument("-fl",
"--forecast-length",
dest="forecast_length",
default=6,
default=None,
type=int)

ap.add_argument("-i",
"--implementation",
type=str,
choices=implementations,
default=implementations[0])
ap.add_argument("-l", "--lag", type=int, default=2)
ap.add_argument("-l", "--lag", type=int, default=None)

ap.add_argument("-ob",
"--output-batch-size",
@@ -99,7 +99,8 @@ def create_network_dataset():
args.loader_configuration,
args.network_dataset_name,
dry=args.dry,
n_forecast_days=args.forecast_length,
lag_time=args.lag,
lead_time=args.forecast_length,
output_batch_size=args.batch_size,
pickup=args.pickup,
generate_workers=args.workers,
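Changing these defaults from fixed values (6 and 2) to None makes the flags optional overrides: when left unset, the loader falls back to the lag and lead recorded at preprocessing time. A sketch of that resolution, assuming the processor attributes shown in the base.py diff below:

    # None means "not given on the command line", so defer to preprocessing
    lag_time = args.lag if args.lag is not None else processor.lag_time
    lead_time = args.forecast_length if args.forecast_length is not None else processor.lead_time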
141 changes: 92 additions & 49 deletions icenet/data/loaders/base.py
@@ -1,14 +1,16 @@
import datetime as dt
import json
import logging
import os
from abc import abstractmethod

from pprint import pformat

import orjson
import numpy as np

from download_toolbox.interface import DataCollection
from download_toolbox.interface import DataCollection, get_dataset_config_implementation
from preprocess_toolbox.interface import get_processor_from_source

"""
"""
@@ -19,7 +21,7 @@
class IceNetBaseDataLoader(DataCollection):
"""
:param configuration_path,
:param loader_configuration,
:param identifier,
:param var_lag,
:param dataset_config_path:
@@ -32,16 +34,16 @@ class IceNetBaseDataLoader(DataCollection):
"""

def __init__(self,
configuration_path: str,
loader_configuration: str,
identifier: str,
var_lag: int,
*args,
dataset_config_path: str = ".",
dates_override: object = None,
dry: bool = False,
generate_workers: int = 8,
lag_time: int = None,
lead_time: int = None,
loss_weight_days: bool = True,
n_forecast_days: int = 93,
output_batch_size: int = 32,
path: str = os.path.join(".", "network_datasets"),
pickup: bool = False,
@@ -52,35 +54,78 @@ def __init__(self,
self._channels = dict()
self._channel_files = dict()

self._configuration_path = configuration_path
self._configuration_path = loader_configuration
self._dataset_config_path = dataset_config_path
self._dates_override = dates_override
self._config = dict()
self._dry = dry
self._loss_weight_days = loss_weight_days
self._meta_channels = []
self._missing_dates = []
self._n_forecast_days = n_forecast_days
self._output_batch_size = output_batch_size
self._pickup = pickup
self._trend_steps = dict()
self._workers = generate_workers

self._var_lag = var_lag
self._load_configuration(loader_configuration)

# TODO: we assume that ground truth is the first dataset in the ordering
ground_truth_id, ground_truth_cfg = list(self._config["sources"].items())[0]
processor = get_processor_from_source(ground_truth_id, ground_truth_cfg)
ds_config = get_dataset_config_implementation(processor.dataset_config)
# TODO: this is smelly, it suggests there is missing logic between Processor and
# NormalisingChannelProcessor to handle suffixes
ref_ds = processor.get_dataset(["{}_abs".format(el) for el in processor.abs_vars])
ref_da = getattr(ref_ds.isel(time=0), list(ref_ds.data_vars)[0])

# Things that come from preprocessing by default
self._dtype = ref_da.dtype
# TODO: we shouldn't ideally need this but we do need a concept of location for masks
self._ds_config = processor.dataset_config
self._frequency_attr = ds_config.frequency.attribute
self._lag_time = lag_time if lag_time is not None else processor.lag_time
self._lead_time = lead_time if lead_time is not None else processor.lead_time
self._north = ds_config.location.north
self._shape = ref_da.shape
self._south = ds_config.location.south
self._var_lag_override = dict() \
if not var_lag_override else var_lag_override

self._load_configuration(configuration_path)
self._construct_channels()

self._dtype = getattr(np, self._config["dtype"])
self._shape = tuple(self._config["shape"])
self._missing_dates = []
# # TODO: format needs to be picked up from dataset frequencies
# dt.datetime.strptime(s, DATE_FORMAT)
# for s in self._config["missing_dates"]
#]

self._missing_dates = [
# TODO: format needs to be picked up from dataset frequencies
dt.datetime.strptime(s, DATE_FORMAT)
for s in self._config["missing_dates"]
]
def get_data_var_folder(self,
var_name: str,
append: object = None,
missing_error: bool = False) -> os.PathLike:
"""Returns the path for a specific data variable.
Appends additional folders to the path if specified in the `append` parameter.
:param var_name: The data variable.
:param append: Additional folders to append to the path.
:param missing_error: Flag to specify if missing directories should be treated as an error.
:return str: The path for the specific data variable.
"""
if not append:
append = []

data_var_path = os.path.join(self.path, *[var_name, *append])

if not os.path.exists(data_var_path):
if not missing_error:
os.makedirs(data_var_path, exist_ok=True)
else:
raise OSError("Directory {} is missing and this is "
"flagged as an error!".format(data_var_path))

return data_var_path

def write_dataset_config_only(self):
"""
@@ -167,23 +212,26 @@ def _construct_channels(self):
"""
# As of Python 3.7 dict guarantees the order of keys based on
# original insertion order, which is great for this method
attr_map = dict(
abs="absolute_vars",
anom="anomoly_vars",
linear_trend="linear_trends"
)
lag_vars = [
(identity, var, data_format)
for data_format in ("abs", "anom")
for identity in sorted(self._config["sources"].keys())
for var in sorted(self._config["sources"][identity][data_format])
for var in sorted(self._config["sources"][identity][attr_map[data_format]])
]

for identity, var_name, data_format in lag_vars:
var_prefix = "{}_{}".format(var_name, data_format)
var_lag = (self._var_lag if var_name not in self._var_lag_override
var_lag = (self._lag_time
if var_name not in self._var_lag_override
else self._var_lag_override[var_name])

self._channels[var_prefix] = int(var_lag)
self._add_channel_files(var_prefix, [
el for el in self._config["sources"][identity]["var_files"]
[var_name] if var_prefix in os.path.split(el)[1]
])
self._channels[var_prefix] = int(var_lag) + 1
self._add_channel_files(var_prefix, self._config["sources"][identity]["processed_files"][var_prefix])

trend_names = [(identity, var,
self._config["sources"][identity]["linear_trend_steps"])
@@ -196,26 +244,16 @@ def _construct_channels(self):

self._channels[var_prefix] = len(trend_steps)
self._trend_steps[var_prefix] = trend_steps
filelist = [
el for el in self._config["sources"][identity]["var_files"]
[var_name] if "linear_trend" in os.path.split(el)[1]
]
self._add_channel_files(var_prefix,
self._config["sources"][identity]["processed_files"][var_prefix])

self._add_channel_files(var_prefix, filelist)

# Metadata input variables that don't span time
meta_names = [
(identity, var)
for identity in sorted(self._config["sources"].keys())
for var in sorted(self._config["sources"][identity]["meta"])
]

for identity, var_name in meta_names:
# Meta channels
for var_name, meta_channel in self._config["channels"].items():
self._meta_channels.append(var_name)
self._channels[var_name] = 1
self._add_channel_files(
var_name,
self._config["sources"][identity]["var_files"][var_name])
meta_channel["files"])

logging.debug(
"Channel quantities deduced:\n{}\n\nTotal channels: {}".format(
@@ -249,9 +287,9 @@ def _load_configuration(self, path: str):
logging.info("Loading configuration {}".format(path))

with open(path, "r") as fh:
obj = json.load(fh)
obj = orjson.loads(fh.read())

self._config.update(obj)
self._config.update(obj)
else:
raise OSError("{} not found".format(path))
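orjson.loads accepts str or bytes but offers no file-handle variant, hence the explicit fh.read(); symmetrically, orjson.dumps returns bytes, which is why the writer further down decodes before writing. The round trip, sketched:

    import orjson

    with open(path, "r") as fh:
        config = orjson.loads(fh.read())  # parses the str contents

    payload = orjson.dumps(config, option=orjson.OPT_INDENT_2)  # returns bytes
    with open(path, "w") as fh:
        fh.write(payload.decode())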

@@ -281,13 +319,11 @@ def _serialize(x):
for i in range(1, s + 1)
],
"counts": counts,
"dtype": self._dtype.__name__,
"dtype": str(self._dtype),
"loader_config": os.path.abspath(self._configuration_path),
"missing_dates": [
date.strftime(DATE_FORMAT)
for date in self._missing_dates
],
"n_forecast_days": self._n_forecast_days,
"missing_dates": self._missing_dates,
"lag_time": self._lag_time,
"lead_time": self._lead_time,
"north": self.north,
"num_channels": self.num_channels,
# FIXME: this naming is inconsistent, sort it out!!! ;)
@@ -300,7 +336,6 @@ def _serialize(x):
"generate_workers": self.workers,
"loss_weight_days": self._loss_weight_days,
"output_batch_size": self._output_batch_size,
"var_lag": self._var_lag,
"var_lag_override": self._var_lag_override,
}
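The written dataset configuration changes shape accordingly (a hand-sketched excerpt with illustrative values): n_forecast_days and var_lag give way to lead_time and lag_time, dtype is serialised via str(), and missing_dates is written raw (currently an empty list, pending the frequency-aware date formatting flagged in the TODO above).

    configuration = {
        "dtype": "float32",     # str(self._dtype)
        "missing_dates": [],    # raw values; per-frequency formatting still TODO
        "lag_time": 2,
        "lead_time": 93,
        # ... north, south, num_channels, workers, etc. as before
    }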

@@ -311,7 +346,7 @@ def _serialize(x):
logging.info("Writing configuration to {}".format(output_path))

with open(output_path, "w") as fh:
json.dump(configuration, fh, indent=4, default=_serialize)
fh.write(orjson.dumps(configuration, option=orjson.OPT_INDENT_2).decode())

@property
def channel_names(self):
@@ -329,6 +364,10 @@ def config(self):
def dates_override(self):
return self._dates_override

@property
def north(self):
return self._north

@property
def num_channels(self):
return sum(self._channels.values())
@@ -337,6 +376,10 @@ def num_channels(self):
def pickup(self):
return self._pickup

@property
def south(self):
return self._south

@property
def workers(self):
return self._workers
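With north and south now read off the source dataset's location in __init__, the hemisphere no longer has to be threaded through as an argument. A hypothetical consumer:

    # ds_config.location supplies these flags during construction
    hemi = "south" if loader.south else "north"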
