Fixed nested datatypes casting in Dataset columns (#49)
omsh authored Dec 13, 2024
2 parents d5578e8 + 7f963cf commit 4fef16a
Showing 4 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yaml
@@ -12,7 +12,7 @@ jobs:
    runs-on: ubuntu-latest
    if: |
      github.event.workflow_run.conclusion == 'success' &&
-      github.event.workflow_run.head_branch == 'main'
+      github.event.workflow_run.head_commit.ref == 'refs/heads/main'
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
2 changes: 1 addition & 1 deletion src/dlomix/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.1.5"
__version__ = "0.1.6"

META_DATA = {
    "author": "Omar Shouman",
9 changes: 7 additions & 2 deletions src/dlomix/data/dataset.py
Expand Up @@ -223,7 +223,7 @@ def _refresh_config(self):
{
k: v
for k, v in self.__dict__.items()
if k.startswith("_") and k != "_config"
if k.startswith("_") and k not in ["_config", "_additional_data"]
}
)
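The filter above now excludes both the config object and the additional data when the config is refreshed. A minimal, self-contained sketch of that behavior (ToyDataset and its attribute values are illustrative assumptions, not the real dlomix API):

# A minimal sketch of the updated _refresh_config filter; ToyDataset and
# its attribute values are illustrative assumptions, not the dlomix API.
EXCLUDED_KEYS = ["_config", "_additional_data"]

class ToyDataset:
    def __init__(self):
        self._config = {}
        self._additional_data = {"raw": [1, 2, 3]}  # must not round-trip
        self._sequence_column = "seq"

    def _refresh_config(self):
        # Copy private attributes into the config, skipping the config
        # itself and the additional data.
        self._config.update(
            {
                k: v
                for k, v in self.__dict__.items()
                if k.startswith("_") and k not in EXCLUDED_KEYS
            }
        )

ds = ToyDataset()
ds._refresh_config()
assert "_additional_data" not in ds._config
assert ds._config["_sequence_column"] == "seq"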

@@ -627,7 +627,12 @@ def load_from_disk(cls, path: str):

    @classmethod
    def from_dataset_config(cls, config: DatasetConfig):
-        d = cls(**config.__dict__)
+        config_dict = config.__dict__.copy()
+
+        # remove the additional data from the config dict
+        config_dict.pop("_additional_data")
+
+        d = cls(**config_dict)
        # data_source=config.data_source,
        # val_data_source=config.val_data_source,
        # test_data_source=config.test_data_source,
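This change drops _additional_data before splatting the config's attributes into the constructor, since it is an internal slot rather than a constructor argument. A hedged sketch of that round-trip, with toy stand-ins for DatasetConfig and the dataset class:

# A hedged sketch of the config round-trip enabled by the change above;
# ToyConfig and ToyDataset are stand-ins for DatasetConfig and the real
# dataset classes, with a reduced set of constructor arguments.
class ToyConfig:
    def __init__(self, data_source=None, sequence_column=None):
        self.data_source = data_source
        self.sequence_column = sequence_column
        self._additional_data = {}  # filled in later, not a ctor argument

class ToyDataset:
    def __init__(self, data_source=None, sequence_column=None):
        self.data_source = data_source
        self.sequence_column = sequence_column

    @classmethod
    def from_dataset_config(cls, config):
        config_dict = config.__dict__.copy()
        # _additional_data is not a constructor argument, so drop it
        # before splatting the remaining keys into the constructor
        config_dict.pop("_additional_data")
        return cls(**config_dict)

d = ToyDataset.from_dataset_config(ToyConfig("data.csv", "seq"))
assert d.sequence_column == "seq"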
42 changes: 41 additions & 1 deletion tests/test_datasets.py
@@ -3,11 +3,16 @@
import zipfile
from os import makedirs
from os.path import exists, join
+from shutil import rmtree

import pytest
from datasets import Dataset, DatasetDict, load_dataset

-from dlomix.data import FragmentIonIntensityDataset, RetentionTimeDataset
+from dlomix.data import (
+    FragmentIonIntensityDataset,
+    RetentionTimeDataset,
+    load_processed_dataset,
+)

logger = logging.getLogger(__name__)

@@ -243,6 +248,41 @@ def test_nested_model_features():
    assert example[0]["nested_feature"].shape == [2, 1, 2]


+def test_save_dataset():
+    hfdata = Dataset.from_dict(RAW_GENERIC_NESTED_DATA)
+
+    intensity_dataset = FragmentIonIntensityDataset(
+        data_format="hf",
+        data_source=hfdata,
+        sequence_column="seq",
+        label_column="label",
+        model_features=["nested_feature"],
+    )
+
+    save_path = "./test_dataset"
+    intensity_dataset.save_to_disk(save_path)
+    rmtree(save_path)
+
+
+def test_load_dataset():
+    rtdataset = RetentionTimeDataset(
+        data_source=join(DOWNLOAD_PATH_FOR_ASSETS, "file_2.csv"),
+        data_format="csv",
+        sequence_column="sequence",
+        label_column="irt",
+        val_ratio=0.2,
+    )
+
+    save_path = "./test_dataset"
+    rtdataset.save_to_disk(save_path)
+    splits = rtdataset._data_files_available_splits
+
+    loaded_dataset = load_processed_dataset(save_path)
+    assert loaded_dataset._data_files_available_splits == splits
+    assert loaded_dataset.hf_dataset is not None
+    rmtree(save_path)
+
+
def test_no_split_datasetDict_hf_inmemory():
    hfdata = Dataset.from_dict(RAW_GENERIC_NESTED_DATA)
    hf_dataset = DatasetDict({"train": hfdata})
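For reference, the round-trip the new tests exercise looks roughly like this from user code (a sketch based on test_load_dataset above; the CSV path and output directory are illustrative):

# A hedged end-user sketch of the save/load round-trip exercised by the
# new tests; the CSV path and save directory here are illustrative.
from dlomix.data import RetentionTimeDataset, load_processed_dataset

rtdataset = RetentionTimeDataset(
    data_source="file_2.csv",
    data_format="csv",
    sequence_column="sequence",
    label_column="irt",
    val_ratio=0.2,
)
rtdataset.save_to_disk("./saved_rt_dataset")

loaded = load_processed_dataset("./saved_rt_dataset")
assert loaded.hf_dataset is not None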
