diff --git a/.github/workflows/pypi.yaml b/.github/workflows/pypi.yaml
index 592c1bd..3b67bba 100644
--- a/.github/workflows/pypi.yaml
+++ b/.github/workflows/pypi.yaml
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     if: |
       github.event.workflow_run.conclusion == 'success' &&
-      github.event.workflow_run.head_branch == 'main'
+      github.event.workflow_run.head_commit.ref == 'refs/heads/main'
     steps:
       - uses: actions/checkout@v4
      - name: Set up Python
diff --git a/src/dlomix/__init__.py b/src/dlomix/__init__.py
index ce201e8..3a3cdd2 100644
--- a/src/dlomix/__init__.py
+++ b/src/dlomix/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.5"
+__version__ = "0.1.6"
 
 META_DATA = {
     "author": "Omar Shouman",
diff --git a/src/dlomix/data/dataset.py b/src/dlomix/data/dataset.py
index 6f8d132..704a534 100644
--- a/src/dlomix/data/dataset.py
+++ b/src/dlomix/data/dataset.py
@@ -223,7 +223,7 @@ def _refresh_config(self):
             {
                 k: v
                 for k, v in self.__dict__.items()
-                if k.startswith("_") and k != "_config"
+                if k.startswith("_") and k not in ["_config", "_additional_data"]
             }
         )
 
@@ -627,7 +627,12 @@ def load_from_disk(cls, path: str):
 
     @classmethod
     def from_dataset_config(cls, config: DatasetConfig):
-        d = cls(**config.__dict__)
+        config_dict = config.__dict__.copy()
+
+        # remove the additional data from the config dict
+        config_dict.pop("_additional_data")
+
+        d = cls(**config_dict)
         # data_source=config.data_source,
         # val_data_source=config.val_data_source,
         # test_data_source=config.test_data_source,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 5a33473..575139d 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -3,11 +3,16 @@
 import zipfile
 from os import makedirs
 from os.path import exists, join
+from shutil import rmtree
 
 import pytest
 from datasets import Dataset, DatasetDict, load_dataset
 
-from dlomix.data import FragmentIonIntensityDataset, RetentionTimeDataset
+from dlomix.data import (
+    FragmentIonIntensityDataset,
+    RetentionTimeDataset,
+    load_processed_dataset,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -243,6 +248,41 @@ def test_nested_model_features():
     assert example[0]["nested_feature"].shape == [2, 1, 2]
 
 
+def test_save_dataset():
+    hfdata = Dataset.from_dict(RAW_GENERIC_NESTED_DATA)
+
+    intensity_dataset = FragmentIonIntensityDataset(
+        data_format="hf",
+        data_source=hfdata,
+        sequence_column="seq",
+        label_column="label",
+        model_features=["nested_feature"],
+    )
+
+    save_path = "./test_dataset"
+    intensity_dataset.save_to_disk(save_path)
+    rmtree(save_path)
+
+
+def test_load_dataset():
+    rtdataset = RetentionTimeDataset(
+        data_source=join(DOWNLOAD_PATH_FOR_ASSETS, "file_2.csv"),
+        data_format="csv",
+        sequence_column="sequence",
+        label_column="irt",
+        val_ratio=0.2,
+    )
+
+    save_path = "./test_dataset"
+    rtdataset.save_to_disk(save_path)
+    splits = rtdataset._data_files_available_splits
+
+    loaded_dataset = load_processed_dataset(save_path)
+    assert loaded_dataset._data_files_available_splits == splits
+    assert loaded_dataset.hf_dataset is not None
+    rmtree(save_path)
+
+
 def test_no_split_datasetDict_hf_inmemory():
     hfdata = Dataset.from_dict(RAW_GENERIC_NESTED_DATA)
     hf_dataset = DatasetDict({"train": hfdata})