Splits for HSCDataSet (#105)
- HSCDataSet has been significantly modified to implement splits
- Constructors for fibad_data_sets now take a split argument
- New configs added in new "prepare" section to define splits
- Some configs in "model" reorganized into a "train" section that
  is similar to the "prepare", "predict", and "download" sections
  in that it configures the action more than anything else.
- Added "split" config to train and predict sections to select
  the split that will be used.
- Added tests for the new HSCDataSet split functionality.
mtauraso authored Oct 25, 2024
1 parent b504b62 commit ea42b17
Showing 7 changed files with 520 additions and 31 deletions.
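A minimal sketch (not part of the commit) of how the new pieces fit together, using a nested dict in place of the parsed TOML config. Key names match the diff; the values and the commented-out HSCDataSet calls are illustrative:

config = {
    "general": {"data_dir": "./data"},
    "prepare": {"train_size": 0.6, "validate_size": 0.2, "test_size": 0.2, "seed": 42},
    "train": {"split": "train"},
    "predict": {"split": "test"},
}

# Data set constructors now take the split name as a second argument.
# For HSCDataSet, split=None selects the whole underlying container.
# train_set = HSCDataSet(config, split=config["train"]["split"])   # as in train.py
# test_set = HSCDataSet(config, split=config["predict"]["split"])  # as in predict.py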
10 changes: 8 additions & 2 deletions src/fibad/data_sets/example_cifar_data_set.py
@@ -11,11 +11,17 @@ class CifarDataSet(CIFAR10):
FIBAD config with a transformation that works well for example code.
"""

def __init__(self, config):
def __init__(self, config, split: str):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
super().__init__(root=config["general"]["data_dir"], train=True, download=True, transform=transform)

if split not in ["train", "test"]:
    raise RuntimeError("CIFAR10 dataset only supports 'train' and 'test' splits.")

train = split == "train"

super().__init__(root=config["general"]["data_dir"], train=train, download=True, transform=transform)

def shape(self):
return (3, 32, 32)
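For reference, a self-contained sketch of the pattern the change above adopts: validate the split name, then map it onto torchvision's boolean `train` flag. The `make_cifar` helper is ours, not the commit's:

from torchvision.datasets import CIFAR10

def make_cifar(root: str, split: str) -> CIFAR10:
    # CIFAR10 only ships a train and a test partition.
    if split not in ("train", "test"):
        raise RuntimeError("CIFAR10 dataset only supports 'train' and 'test' splits.")
    # torchvision selects the partition with a boolean flag rather than a name.
    return CIFAR10(root=root, train=(split == "train"), download=True)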
227 changes: 225 additions & 2 deletions src/fibad/data_sets/hsc_data_set.py
@@ -2,8 +2,9 @@

import logging
import re
from copy import copy, deepcopy
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import numpy as np
import torch
@@ -18,6 +19,228 @@

@fibad_data_set
class HSCDataSet(Dataset):
"""Interface object to allow simple access to splits on a corpus of HSC data files
f/s operations and management are handled in HSCDatSetContainer
splits on the dataset and their generation are handled by HSCDataSetSplit
"""

def __init__(self, config, split: Union[str, None]):
# initialize the filesystem references
self.container = HSCDataSetContainer(config)

# initialize our splits from configuration
self._create_splits(config)

# Set the split to what was requested.
self._set_split(split)

def _create_splits(self, config):
seed = config["prepare"]["seed"] if config["prepare"]["seed"] else None

# Init the splits based on config values
train_size = config["prepare"]["train_size"] if config["prepare"]["train_size"] else None
test_size = config["prepare"]["test_size"] if config["prepare"]["test_size"] else None
validate_size = config["prepare"]["validate_size"] if config["prepare"]["validate_size"] else None

# Convert all values specified as counts into ratios of the underlying container
if isinstance(train_size, int):
train_size = train_size / len(self.container)
if isinstance(test_size, int):
test_size = test_size / len(self.container)
if isinstance(validate_size, int):
validate_size = validate_size / len(self.container)

# Fill in any values not provided
if test_size is None:
if train_size is None:
train_size = 0.25
test_size = 1.0 - train_size
elif train_size is None:
train_size = 1.0 - test_size
elif validate_size is None:
validate_size = 1.0 - (train_size + test_size)

# Generate splits
self.splits = {}
self.splits["test"] = HSCDataSetSplit(self.container, test_size, seed=seed)
rest = copy(self.splits["test"]).complement()
self.splits["train"] = HSCDataSetSplit(rest, train_size, seed=seed)

# Validate is only generated if it is provided, or if both test and train are provided.
if validate_size:
rest = rest.logical_and(copy(self.splits["train"]).complement())
self.splits["validate"] = HSCDataSetSplit(rest, validate_size, seed=seed)

logger.info("HSC Data Set Splits loaded are:")
for key, value in self.splits.items():
logger.info(f"{key} split contains {len(value)} items")

def _set_split(self, split: Union[str, None] = None):
self.current_split = self.splits.get(split, self.container)

if split is not None and self.current_split == self.container:
splits = list(self.splits.keys())
raise RuntimeError(f"Split {split} does not exist. valid split names are {splits}")

def shape(self) -> tuple[int, int, int]:
return self.container.shape()

def __getitem__(self, idx: int) -> torch.Tensor:
return self.current_split[idx]

def __len__(self) -> int:
return len(self.current_split)


class HSCDataSetSplit(Dataset):
def __init__(
self,
data: Union["HSCDataSetContainer", "HSCDataSetSplit"],
ratio: float,
seed: Union[int, None] = None,
):
"""
This class represents a split of an HSCDataset.
It should only get created by passing in an existing HSCDataSetContainer (or HSCDataSetSplit)
and splitting it according to train_test_split-like parameters. When you split a split,
all splits end up referring to the same underlying HSCDataSetContainer object.
Each encodes a subset of the underlying HSCDataSetContainer by keeping a list of boolean values.
Parameters
----------
data : Union[HSCDataSetContainer, "HSCDataSetSplit"]
The underlying HSCDataSet or split to operate on. Creating a split from an existing split ends up
referring to a subset of the data selected by the original split, but the new object only refers
to an underlying HSCDataSet object, not any other split object.
ratio : float
Ratio of the underlying data source to use for this split. This is expressed as a fraction of the
HSCDataSetContainer even when an HSCDataSetSplit is passed.
seed : Union[int, None] , optional
The seed value to provide to the random number generator, or None if you would like to use system
entropy to generate a seed. None by default.
"""
self.rng = np.random.default_rng(seed)

if ratio > 1.0 or ratio < 0.0:
msg = f"Split provided for HSCDatSetSplit as a ratio is {ratio}, which is not between 0.0 and 1.0"
raise RuntimeError(msg)

self.data = data.data if isinstance(data, HSCDataSetSplit) else data

# The length of this split once constructed
length = int(np.round(len(self.data) * ratio))

if isinstance(data, HSCDataSetSplit):
# If we're splitting a split we need to modify the existing mask of the prior split.
# Namely, we switch some true values to false to deselect more of the underlying dataset.
split = data
self.mask = copy(split.mask)
remove_count = len(split) - length
self._flip_mask_values(remove_count, "true_to_false")

else:
# If we're splitting a normal hscdataset we generate a single mask with the appropriate values
self.mask = np.zeros(len(data), dtype=bool)
self._flip_mask_values(length, "false_to_true")

self.indexes = np.nonzero(self.mask)[0]

def _flip_mask_values(self, num: int, mode: Literal["false_to_true", "true_to_false"]):
"""
Private helper that flips some values of self.mask. The direction to flip is controlled by the
mode parameter. Either the function randomly finds `num` true values to flip to false, or `num` false
values to flip to true.
This function is used during object construction to create a set number of randomly selected true
values in the mask.
Parameters
----------
num : int
The number of values to flip
mode : Literal[&quot;false_to_true&quot;, &quot;true_to_false&quot;]
The mode to work in, either flipping True values false or the reverse
Raises
------
RuntimeError
It is a RuntimeError to try to flip more values than the mask has of that type.
"""
mask_tmp = np.logical_not(self.mask) if mode == "false_to_true" else self.mask
target_val = mode == "false_to_true"
target_indexes = np.nonzero(mask_tmp)[0]

if num > len(target_indexes):
msg_mode = mode.replace("_", " ")
num_tgt = len(target_indexes)
msg = f"Cannot flip {num} values {msg_mode} when only {num_tgt} {target_val} values exist in mask"
raise RuntimeError(msg)

change_indexes = self.rng.permutation(target_indexes)[:num]
for i in change_indexes:
self.mask[i] = target_val

def complement(self) -> "HSCDataSetSplit":
"""Mutates the split by inverting it with respect to the underlying dataset.
e.g. if you have an underlying dataset with 5 members, and indexes 1, 2, and 4 are part of this split,
the complement would be a dataset selecting indexes 0 and 3.
"""
self.mask = np.logical_not(self.mask)
self.indexes = np.nonzero(self.mask)[0]
return self

def logical_and(self, obj: "HSCDataSetSplit") -> "HSCDataSetSplit":
"""Takes the logical and of this object and the passed in object. self is modified, the passed in
object is not
If the self object selects indicies 1,2 and 4 and the passed in object selects indicies 2, 4, and 0
the self object would be modified to select indicies 2, and 4 only.
It is a RuntimeError to and two split objects that do not reference the same underlying HSCDataSet
Parameters
----------
obj : HSCDataSetSplit
The split object to AND with this one.
"""
if self.data != obj.data:
msg = "Tried to take logical and of two HSCDataSetSplits with different HSCDataSet objects"
raise RuntimeError(msg)

self.mask = np.logical_and(self.mask, obj.mask)
self.indexes = np.nonzero(self.mask)[0]
return self

def __copy__(self) -> "HSCDataSetSplit":
# Create a HSCDataSetSplit with no data selected, but the same data source as self
copy_object = HSCDataSetSplit(self.data, 0.0)

# Copy mask and indexes over
copy_object.mask = self.mask.copy()
copy_object.indexes = self.indexes.copy()

# Copy RNG state over.
copy_object.rng = deepcopy(self.rng)

return copy_object

def __len__(self) -> int:
return len(self.indexes)

def __getitem__(self, idx: int) -> torch.Tensor:
return self.data[self.indexes[idx]]


class HSCDataSetContainer(Dataset):
def __init__(self, config):
# TODO: What will be a reasonable set of transformations?
# For now tanh all the values so they end up in [-1,1]
@@ -280,7 +503,7 @@ def __getitem__(self, idx: int) -> torch.Tensor:
return self._object_id_to_tensor(object_id)

def __contains__(self, object_id: str) -> bool:
"""Allows you to do `object_id in dataset` queries
"""Allows you to do `object_id in dataset` queries. Used by testing code.
Parameters
----------
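A toy restatement of the mask mechanics above in plain numpy, with no fibad imports. The 10-item container and the 2/6/2 split sizes are illustrative; the three steps mirror _create_splits generating test, train, and validate:

import numpy as np

rng = np.random.default_rng(seed=42)
n = 10

# "test" split: start from an all-false mask and flip 2 random entries true.
test_mask = np.zeros(n, dtype=bool)
test_mask[rng.permutation(np.nonzero(~test_mask)[0])[:2]] = True

# complement(): everything the test split did not select.
rest = ~test_mask

# Splitting `rest` for "train": flip true values to false until 6 remain.
train_mask = rest.copy()
drop = rng.permutation(np.nonzero(train_mask)[0])[: train_mask.sum() - 6]
train_mask[drop] = False

# "validate": logical_and of `rest` with the complement of "train".
validate_mask = np.logical_and(rest, ~train_mask)

print(test_mask.sum(), train_mask.sum(), validate_mask.sum())  # 2 6 2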
42 changes: 39 additions & 3 deletions src/fibad/fibad_default_config.toml
@@ -55,15 +55,16 @@ mask = false
# e.g. "user_package.submodule.ExternalModel" or "ExampleAutoencoder"
name = "ExampleAutoencoder"

weights_filepath = "example_model.pth"
epochs = 10

base_channel_size = 32
latent_dim = 64

[train]
weights_filepath = "example_model.pth"
epochs = 10
# Set this to the path of a checkpoint file to resume (i.e. continue) training
# from that checkpoint. Otherwise, set to false to start from scratch.
resume = false
split = "train"

[data_set]
# Name of the built-in data loader to use or the libpath to an external data loader
@@ -92,6 +93,41 @@ batch_size = 32
shuffle = true
num_workers = 2

[prepare]
# How to split the data between training and eval sets.
# The semantics are borrowed from scikit-learn's train_test_split and HF Dataset's train_test_split function.
# It is an error for these values to sum to more than 1.0 when given as ratios, or to more than the
# size of the dataset when given as integer counts.

# train_size: Size of the train split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# train split.
# If `int`, represents the absolute number of train samples.
# If `false`, the value is automatically set to the complement of the test size.
train_size = 0.6

# validate_size: Size of the validation split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# validation split.
# If `int`, represents the absolute number of validation samples.
# If `false`, and both train_size and test_size are defined, the value is automatically set to the complement
# of the other two sizes summed.
# If `false`, and only one of the other sizes is defined, no validate split is created.
validate_size = 0.2

# test_size: Size of the test split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# test split.
# If `int`, represents the absolute number of test samples.
# If `false`, the value is set to the complement of the train size.
# If `train_size` is also `false`, train_size will be set to `0.25` and test_size to the remaining `0.75`.
test_size = 0.2

# Number to seed the random number generator with when generating a random split. False means the
# seed will be drawn from a system entropy source at runtime.
seed = false

[predict]
model_weights_file = false
batch_size = 32
split = "test"
3 changes: 2 additions & 1 deletion src/fibad/predict.py
@@ -19,7 +19,8 @@ def run(config: ConfigDict):
The parsed config file as a nested dict
"""

model, data_set = setup_model_and_dataset(config)
model, data_set = setup_model_and_dataset(config, split=config["predict"]["split"])
logger.info(f"data set has length {len(data_set)}")
data_loader = dist_data_loader(data_set, config)

# Create a results directory and dump our config there
10 changes: 6 additions & 4 deletions src/fibad/pytorch_ignite.py
@@ -17,7 +17,7 @@
logger = logging.getLogger(__name__)


def setup_model_and_dataset(config: ConfigDict) -> tuple:
def setup_model_and_dataset(config: ConfigDict, split: str) -> tuple:
"""
Construct the dataset and the model according to configuration.
@@ -27,6 +27,8 @@ def setup_model_and_dataset(config: ConfigDict) -> tuple:
----------
config : ConfigDict
The entire runtime config
split : str
The name of the split we want to use from the data set.
Returns
-------
@@ -35,7 +37,7 @@
"""
# Fetch data loader class specified in config and create an instance of it
data_set_cls = fetch_data_set_class(config)
data_set = data_set_cls(config)
data_set = data_set_cls(config, split)

# Fetch model class specified in config and create an instance of it
model_cls = fetch_model_class(config)
@@ -212,8 +214,8 @@ def neg_loss_score(engine):
greater_or_equal=True,
)

if config["model"]["resume"]:
prev_checkpoint = torch.load(config["model"]["resume"], map_location=device)
if config["train"]["resume"]:
prev_checkpoint = torch.load(config["train"]["resume"], map_location=device)
Checkpoint.load_objects(to_load=to_save, checkpoint=prev_checkpoint)

@trainer.on(Events.STARTED)
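An isolated sketch of the resume path changed above. It assumes a `to_save` dict mapping names to stateful objects (e.g. {"model": model, "optimizer": optimizer}), which is what pytorch-ignite's Checkpoint API expects; the wrapper function is ours:

import torch
from ignite.handlers import Checkpoint

def maybe_resume(config, to_save, device):
    # config["train"]["resume"] is either false or a checkpoint path.
    if config["train"]["resume"]:
        prev_checkpoint = torch.load(config["train"]["resume"], map_location=device)
        Checkpoint.load_objects(to_load=to_save, checkpoint=prev_checkpoint)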
6 changes: 3 additions & 3 deletions src/fibad/train.py
@@ -19,16 +19,16 @@ def run(config):
results_dir = create_results_dir(config, "train")
log_runtime_config(config, results_dir)

model, data_set = setup_model_and_dataset(config)
model, data_set = setup_model_and_dataset(config, split=config["train"]["split"])
data_loader = dist_data_loader(data_set, config)

# Create trainer, a pytorch-ignite `Engine` object
trainer = create_trainer(model, config, results_dir)

# Run the training process
trainer.run(data_loader, max_epochs=config["model"]["epochs"])
trainer.run(data_loader, max_epochs=config["train"]["epochs"])

# Save the trained model
model.save(results_dir / config["model"]["weights_filepath"])
model.save(results_dir / config["train"]["weights_filepath"])

logger.info("Finished Training")