Skip to content

Commit

Permalink
add clone method to datasets (#2625)
Browse files Browse the repository at this point in the history
Summary:

This makes it far easier to obtain slices of different kinds of datasets (Supervised, MultiTask, Contextual), which will be helpful for things like doing LOOCV MBM in Ax.

Reviewed By: saitcakmak

Differential Revision: D65616941
  • Loading branch information
sdaulton authored and facebook-github-bot committed Nov 18, 2024
1 parent 3c2ce15 commit 7fd1cda
Show file tree
Hide file tree
Showing 4 changed files with 352 additions and 57 deletions.
14 changes: 14 additions & 0 deletions botorch/utils/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@

from __future__ import annotations

import dataclasses

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any

import torch

from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor


Expand Down Expand Up @@ -102,6 +106,9 @@ def _validate(self) -> None:
f"`event shape` {self.event_shape}."
)

def clone(self) -> DenseContainer:
return dataclasses.replace(self)


@dataclass(eq=False)
class SliceContainer(BotorchContainer):
Expand Down Expand Up @@ -149,3 +156,10 @@ def _validate(self) -> None:
f"Shapes of `values` {values.shape} and `indices` "
f"{indices.shape} incompatible with `event_shape` {event_shape}."
)

def clone(self) -> SliceContainer:
return type(self)(
values=self.values.clone(),
indices=self.indices.clone(),
event_shape=torch.Size(self.event_shape),
)
112 changes: 111 additions & 1 deletion botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import copy

from typing import Any

import torch
Expand Down Expand Up @@ -70,6 +72,7 @@ def __init__(
self._Yvar = Yvar
self.feature_names = feature_names
self.outcome_names = outcome_names
self.validate_init = validate_init
if validate_init:
self._validate()

Expand Down Expand Up @@ -147,6 +150,52 @@ def __eq__(self, other: Any) -> bool:
and self.outcome_names == other.outcome_names
)

def clone(
self, deepcopy: bool = False, mask: Tensor | None = None
) -> SupervisedDataset:
"""Return a copy of the dataset.
Args:
deepcopy: If True, perform a deep copy. Otherwise, use the same
tensors/lists.
mask: A `n`-dim boolean mask indicating which rows to keep. This is used
along the -2 dimension.
Returns:
The new dataset.
"""
new_X = self._X
new_Y = self._Y
new_Yvar = self._Yvar
feature_names = self.feature_names
outcome_names = self.outcome_names
if mask is not None:
if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]):
raise NotImplementedError(
"Masking is not supported for BotorchContainers."
)
new_X = new_X[..., mask, :]
new_Y = new_Y[..., mask, :]
if new_Yvar is not None:
new_Yvar = new_Yvar[..., mask, :]
if deepcopy:
new_X = new_X.clone()
new_Y = new_Y.clone()
new_Yvar = new_Yvar.clone() if new_Yvar is not None else None
feature_names = copy.copy(self.feature_names)
outcome_names = copy.copy(self.outcome_names)
kwargs = {}
if new_Yvar is not None:
kwargs = {"Yvar": new_Yvar}
return type(self)(
X=new_X,
Y=new_Y,
feature_names=feature_names,
outcome_names=outcome_names,
validate_init=self.validate_init,
**kwargs,
)


class RankingDataset(SupervisedDataset):
r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
Expand Down Expand Up @@ -339,7 +388,7 @@ def from_joint_dataset(
outcome_names=[outcome_name],
)
datasets.append(new_dataset)
# Return the new
# Return the new dataset
return cls(
datasets=datasets,
target_outcome_name=outcome_names_per_task.get(
Expand Down Expand Up @@ -466,6 +515,37 @@ def __eq__(self, other: Any) -> bool:
and self.task_feature_index == other.task_feature_index
)

def clone(
self, deepcopy: bool = False, mask: Tensor | None = None
) -> MultiTaskDataset:
"""Return a copy of the dataset.
Args:
deepcopy: If True, perform a deep copy. Otherwise, use the same
tensors/lists/datasets.
mask: A `n`-dim boolean mask indicating which rows to keep from the target
dataset. This is used along the -2 dimension.
Returns:
The new dataset.
"""
datasets = list(self.datasets.values())
if mask is not None or deepcopy:
new_datasets = []
for outcome, ds in self.datasets.items():
new_datasets.append(
ds.clone(
deepcopy=deepcopy,
mask=mask if outcome == self.target_outcome_name else None,
)
)
datasets = new_datasets
return MultiTaskDataset(
datasets=datasets,
target_outcome_name=self.target_outcome_name,
task_feature_index=self.task_feature_index,
)


class ContextualDataset(SupervisedDataset):
"""This is a contextual dataset that is constructed from either a single
Expand Down Expand Up @@ -627,3 +707,33 @@ def _validate_decompositions(self) -> None:
raise InputDataError(
f"{outcome} is missing in metric_decomposition."
)

def clone(
self, deepcopy: bool = False, mask: Tensor | None = None
) -> ContextualDataset:
"""Return a copy of the dataset.
Args:
deepcopy: If True, perform a deep copy. Otherwise, use the same
tensors/lists/datasets.
mask: A `n`-dim boolean mask indicating which rows to keep. This is used
along the -2 dimension. `n` here corresponds to the number of rows in
an individual dataset.
Returns:
The new dataset.
"""
datasets = list(self.datasets.values())
if mask is not None or deepcopy:
datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets]
if deepcopy:
parameter_decomposition = copy.deepcopy(self.parameter_decomposition)
metric_decomposition = copy.deepcopy(self.metric_decomposition)
else:
parameter_decomposition = self.parameter_decomposition
metric_decomposition = self.metric_decomposition
return ContextualDataset(
datasets=datasets,
parameter_decomposition=parameter_decomposition,
metric_decomposition=metric_decomposition,
)
3 changes: 3 additions & 0 deletions test/utils/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def test_dense(self):
# Test `__call__`
self.assertTrue(X().equal(values))

# Test `clone`
self.assertEqual(X.clone(), X)

def test_slice(self):
for arity in (2, 4):
for vals in (
Expand Down
Loading

0 comments on commit 7fd1cda

Please sign in to comment.