diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md
new file mode 100644
index 00000000..054c6205
--- /dev/null
+++ b/docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiTask
+
+::: multimolecule.data.multitask
diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md
new file mode 100644
index 00000000..8e188199
--- /dev/null
+++ b/docs/docs/runners/config.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeConfig
+
+::: multimolecule.runners.MultiMoleculeConfig
diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md
new file mode 100644
index 00000000..75a0f528
--- /dev/null
+++ b/docs/docs/runners/index.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+--8<-- "multimolecule/runners/README.md:8:"
diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md
new file mode 100644
index 00000000..7f2f4c12
--- /dev/null
+++ b/docs/docs/runners/runner.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeRunner
+
+::: multimolecule.runners.base_runner.BaseRunner
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 93d53a45..9324d9c3 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule
nav:
- index.md
+ - runners:
+ - runners/index.md
+ - MultiMoleculeRunner: runners/runner.md
+ - MultiMoleculeConfig: runners/config.md
- data:
- data/index.md
- Dataset: data/dataset.md
+      - MultiTask: data/multitask.md
- datasets:
- datasets/index.md
- DNA:
diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py
index 63f937c5..c9ee3a87 100644
--- a/multimolecule/__init__.py
+++ b/multimolecule/__init__.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
+from .apis import evaluate, infer, train
from .data import Dataset
from .models import (
AutoModelForContactPrediction,
@@ -124,6 +125,7 @@
TokenKMerHead,
TokenPredictionHead,
)
+from .runners import MultiMoleculeConfig, MultiMoleculeRunner
from .tasks import Task, TaskLevel, TaskType
from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
from .utils import count_parameters
diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py
new file mode 100644
index 00000000..8e3e5b3c
--- /dev/null
+++ b/multimolecule/apis/__init__.py
@@ -0,0 +1,19 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from .run import evaluate, infer, train
+
+__all__ = ["train", "evaluate", "infer"]
diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py
new file mode 100644
index 00000000..1fdb7666
--- /dev/null
+++ b/multimolecule/apis/run.py
@@ -0,0 +1,115 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+# mypy: disable-error-code="attr-defined"
+
+import atexit
+import os
+import warnings
+from typing import Type
+
+import danling as dl
+import torch
+
+from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner
+
+try:
+ import nni
+except ImportError:
+ nni = None
+
+
+def train(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ config.training = True
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if config.get("nni", False):
+ if nni is None:
+ raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
+ config.merge(nni.get_next_parameter())
+ with dl.debug(config.get("debug", False)):
+ runner = runner_cls(config)
+ atexit.register(runner.print_result)
+ atexit.register(runner.save_result)
+ atexit.register(runner.save_checkpoint)
+ result = runner.train()
+ return result
+
+
+def evaluate(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig.empty()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ config.training = False
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+ raise RuntimeError("Please specify `checkpoint` to run evaluate")
+ for name, data in config.datas.items():
+ if "evaluation" not in data or not isinstance(data.evaluate, str):
+ raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}")
+ runner = runner_cls(config)
+ result = runner.evaluate_epoch("evaluation")
+ print(result)
+ return result
+
+
+def infer(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig.empty()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ config.training = False
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+ raise RuntimeError("Please specify `checkpoint` to run infer.")
+ for name, data in config.datas.items():
+ if "inference" not in data or not isinstance(data.inference, str):
+ raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}")
+ if "result_path" not in config or not isinstance(config.result_path, str):
+ config.result_path = os.path.join(os.getcwd(), "result.json")
+ warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
+ runner = runner_cls(config)
+ result = runner.infer()
+ runner.save(result, config.result_path)
+ return result
diff --git a/multimolecule/apis/stat.py b/multimolecule/apis/stat.py
new file mode 100644
index 00000000..5e525d55
--- /dev/null
+++ b/multimolecule/apis/stat.py
@@ -0,0 +1,99 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import os
+import shutil
+from statistics import mean
+from typing import List
+
+import chanfig
+import pandas as pd
+from chanfig import NestedDict
+from tqdm import tqdm
+
+
+class Result(NestedDict):
+ pretrained: str
+ id: str
+ seed: int
+ epoch: int
+ validation: NestedDict
+ test: NestedDict
+
+
+def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]:
+ results = []
+ for root, _, files in tqdm(os.walk(experiment_root)):
+ if "run.log" in files:
+ if "best.json" not in files:
+ if remove_empty:
+ shutil.rmtree(root)
+ continue
+ best = NestedDict.from_json(os.path.join(root, "best.json"))
+ if "index" not in best:
+ if remove_empty:
+ shutil.rmtree(root)
+ continue
+            try:
+                config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
+                pretrained = config.pretrained.split("/")[-1]
+                seed = config.seed
+            except (FileNotFoundError, KeyError, AttributeError):
+                # fall back to placeholders when the trainer config is missing or incomplete
+                pretrained, seed = "", 1
+            result = Result(id=best.id, pretrained=pretrained, seed=seed)
+ result.validation = NestedDict(
+ {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()}
+ )
+ result.test = NestedDict(
+ {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()}
+ )
+ result.epoch = best.index
+ result.pop("validation.time", None)
+ result.pop("test.time", None)
+ result.pop("validation.loss", None)
+ result.pop("test.loss", None)
+ result.pop("validation.lr", None)
+ result.pop("test.lr", None)
+ results.append(result)
+    # Remove empty directories bottom-up so that directories which become empty
+    # once their children are removed are cleaned up as well
+    if remove_empty:
+        for root, _, _ in os.walk(experiment_root, topdown=False):
+            if not os.listdir(root):
+                os.rmdir(root)
+ results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
+ return results
+
+
+def write_result_stat(results: List[Result], path: str):
+ results = [dict(result.all_items()) for result in results] # type: ignore[misc]
+ df = pd.DataFrame.from_dict(results)
+ df.insert(len(df.keys()) - 1, "comment", "")
+ df.fillna("")
+ df.to_csv(path, index=False)
+
+
+class Config(chanfig.Config):
+ experiment_root: str = "experiments"
+ out_path: str = "result.csv"
+
+
+if __name__ == "__main__":
+ config = Config().parse()
+ result_stat = get_result_stat(config.experiment_root)
+    if not result_stat:
+ raise ValueError("No results found")
+ write_result_stat(result_stat, config.out_path)
diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py
index 20d1bd5f..d6563e18 100644
--- a/multimolecule/data/__init__.py
+++ b/multimolecule/data/__init__.py
@@ -15,9 +15,13 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
from .utils import no_collate
__all__ = [
"Dataset",
+ "MultiTaskDataset",
+ "MultiTaskSampler",
+ "DistributedMultiTaskSampler",
"no_collate",
]
diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py
new file mode 100644
index 00000000..7c20e829
--- /dev/null
+++ b/multimolecule/data/multitask.py
@@ -0,0 +1,246 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from bisect import bisect_right
+from collections.abc import Iterator, Mapping, Sequence
+from copy import deepcopy
+from random import choices
+
+import torch
+from chanfig import NestedDict
+from torch import distributed as dist
+from torch.utils import data
+
+from .dataset import Dataset
+
+
+class MultiTaskDataset(data.ConcatDataset):
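+    r"""
+    Dataset wrapping multiple named [`Dataset`][multimolecule.data.Dataset]s for multi-task training.
+
+    Sub-datasets are concatenated in mapping order; every batch is drawn from a single
+    sub-dataset, and the originating dataset name is attached to the batch under the
+    ``dataset`` key.
+    """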
+
+ datasets: Mapping
+ dataset_keys: Sequence[str]
+ dataset_values: Sequence[Dataset]
+
+ def __init__(self, datasets: Mapping) -> None:
+ for key, dataset in datasets.items():
+ if not isinstance(dataset, Dataset):
+ raise TypeError(f"Dataset {key} should be an instance of Dataset")
+ self.datasets = datasets
+        if not self.datasets:
+ raise ValueError("MultiTaskDataset should contain at least one dataset")
+ self.dataset_keys, self.dataset_values = zip(*self.datasets.items())
+ self.cumulative_sizes = self.cumsum(self.dataset_values)
+
+ def __getitems__(self, key: Sequence[int]) -> Mapping:
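+        # The batch sampler guarantees that all indices in `key` come from the same
+        # sub-dataset, so the first index suffices to locate it.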
+ dataset_idx = bisect_right(self.cumulative_sizes, key[0])
+ if dataset_idx == 0:
+ sample_idx = key
+ else:
+ sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key]
+ batch = self.dataset_values[dataset_idx][sample_idx]
+ batch["dataset"] = self.dataset_keys[dataset_idx]
+ return batch
+
+ @property
+ def tasks(self) -> NestedDict:
+ tasks = NestedDict()
+ for dataset in self.dataset_values:
+ for n, t in dataset.tasks.items():
+ if n not in tasks:
+ tasks[n] = t
+ elif tasks[n] != t:
+ raise ValueError(f"Task {n} has different configurations across datasets")
+ return tasks
+
+ @property
+ def dataset_tasks(self) -> NestedDict:
+ return NestedDict({k: v.tasks for k, v in self.datasets.items()})
+
+ def __repr__(self) -> str:
+ return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"
+
+
+class MultiTaskSampler(data.BatchSampler):
+ r"""
+    Ensure all items in a batch come from the same dataset.
+
+    Arguments:
+        dataset (MultiTaskDataset): The multi-task dataset to sample from.
+        batch_size (int): Size of mini-batch.
+        shuffle (bool): If ``True``, the next batch is drawn from a randomly chosen dataset,
+            weighted by ``weights``.
+        drop_last (bool): If ``True``, the sampler will drop the last batch of a dataset if
+            its size would be less than ``batch_size``.
+        sampler_cls (type[Sampler] | None): Sampler class used within each dataset.
+            Defaults to ``RandomSampler`` when shuffling and ``SequentialSampler`` otherwise.
+        weights (list[int] | None): Sampling weight of each dataset. Defaults to dataset sizes.
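+
+    Examples:
+        A minimal sketch; ``dataset_a`` and ``dataset_b`` stand in for pre-built
+        [`Dataset`][multimolecule.data.Dataset] instances:
+
+        >>> dataset = MultiTaskDataset({"a": dataset_a, "b": dataset_b})  # doctest: +SKIP
+        >>> sampler = MultiTaskSampler(dataset, batch_size=32, shuffle=True)  # doctest: +SKIP
+        >>> loader = data.DataLoader(dataset, batch_sampler=sampler)  # doctest: +SKIP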
+ """
+
+ datasets: Sequence[Dataset]
+
+ def __init__( # pylint: disable=super-init-not-called
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] | None = None,
+ weights: list[int] | None = None,
+ ) -> None:
+ self.datasets = dataset.dataset_values
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+ if sampler_cls is None:
+ sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
+ self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore
+ self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore
+ self.cumulative_sizes = dataset.cumulative_sizes
+ self.num_datasets = len(self.datasets)
+ self.weights = weights if weights is not None else self.dataset_sizes
+
+ def __iter__(self):
+ sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+ sampler_weights = deepcopy(self.weights)
+ sampler_idx = 0
+ # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
+ if self.drop_last:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ try:
+ batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+ yield batch
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ else:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ batch = [0] * self.batch_size
+ idx_in_batch = 0
+ try:
+ for _ in range(self.batch_size):
+ batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+ idx_in_batch += 1
+ yield batch
+ idx_in_batch = 0 # noqa: SIM113
+ batch = [0] * self.batch_size
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ if idx_in_batch > 0:
+ yield batch[:idx_in_batch]
+
+ def __len__(self):
+ batch_size = self.batch_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+
+class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods
+ r"""
+ Distributed version of MultiTaskSampler, which ensures that all GPUs sample data from the
+ same sub-dataset in each step without requiring additional communication.
+ The dataset selection is based on a random seed mechanism that is synchronized across epochs.
+
+ See Also:
+ [MultiTaskSampler][MultiTaskSampler]
+ """
+
+ def __init__(
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] = data.RandomSampler,
+ weights: list[int] | None = None,
+ seed: int = 0,
+ ) -> None:
+ super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
+ self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]
+ self.seed = seed
+ self.epoch = 0
+
+ def set_epoch(self, epoch: int):
+ """
+ Sets the epoch for deterministic shuffling.
+ """
+ self.epoch = epoch
+ for sampler in self.samplers:
+ sampler.set_epoch(epoch)
+
+    def _get_sampler_idx(self, high: int) -> int:
+        """
+        Determines which sampler (i.e., sub-dataset) to use next.
+
+        The generator is re-seeded once per epoch in `__iter__`, so successive calls within an
+        epoch yield different, yet rank-synchronised, choices.
+        """
+        return int(torch.randint(low=0, high=high, size=(1,), generator=self._generator).item())
+
+    def __iter__(self) -> Iterator:
+        # Re-seed once per epoch; every process builds the identical random sequence, so all
+        # GPUs select the same sub-dataset at each step without communication.
+        self._generator = torch.Generator()
+        self._generator.manual_seed(self.seed + self.epoch)
+        sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+        sampler_weights = deepcopy(self.weights)
+
+ if self.drop_last:
+ while sampler_iters:
+ # Sample the same sub-dataset across all GPUs using the seeded index
+ sampler_idx = self._get_sampler_idx(len(sampler_iters))
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ try:
+ batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+ yield batch
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ else:
+ while sampler_iters:
+ # Sample the same sub-dataset across all GPUs using the seeded index
+ sampler_idx = self._get_sampler_idx(len(sampler_iters))
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ batch = [0] * self.batch_size
+ idx_in_batch = 0
+ try:
+ for _ in range(self.batch_size):
+ batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+ idx_in_batch += 1
+ yield batch
+ idx_in_batch = 0 # noqa: SIM113
+ batch = [0] * self.batch_size
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ if idx_in_batch > 0:
+ yield batch[:idx_in_batch]
+
+ def __len__(self) -> int:
+ batch_size = self.batch_size * self.world_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+ @property
+ def world_size(self) -> int:
+ r"""Return the number of processes in the current process group."""
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size()
+ return 1
diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md
new file mode 100644
index 00000000..bb1000ad
--- /dev/null
+++ b/multimolecule/runners/README.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+`runners` provides an easy-to-use interface for running experiments.
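+
+A minimal sketch of the intended workflow (the model identifier and data path below are
+illustrative, not defaults):
+
+```python
+from multimolecule import MultiMoleculeConfig, train
+
+config = MultiMoleculeConfig()
+config.pretrained = "multimolecule/rnafm"  # hypothetical pretrained model identifier
+config.data.root = "path/to/dataset"  # hypothetical dataset location
+train(config)
+```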
diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py
new file mode 100644
index 00000000..70fa4076
--- /dev/null
+++ b/multimolecule/runners/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from .config import MultiMoleculeConfig
+from .runner import MultiMoleculeRunner
+
+__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"]
diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py
new file mode 100644
index 00000000..4b62bdb8
--- /dev/null
+++ b/multimolecule/runners/base_runner.py
@@ -0,0 +1,343 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import math
+import os
+from functools import cached_property, partial
+from typing import Any, Tuple
+from warnings import warn
+
+import danling as dl
+import torch
+from art import text2art
+from chanfig import FlatDict, NestedDict
+from danling import MultiTaskMetrics
+from datasets import disable_progress_bars, get_dataset_split_names
+from lazy_imports import try_import
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import data
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from multimolecule import defaults
+from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
+from multimolecule.module import HeadConfig, ModelRegistry, MultiMoleculeModel
+
+from .config import MultiMoleculeConfig
+from .metrics import MetricRegistry
+
+with try_import() as ema:
+ from ema_pytorch import EMA
+
+
+disable_progress_bars()
+
+
+class BaseRunner(dl.BaseRunner):
+
+ config: MultiMoleculeConfig
+ model: MultiMoleculeModel
+ all_datasets: NestedDict
+
+ def __init__(self, config: MultiMoleculeConfig):
+ if config.art:
+ print(text2art("MultiMolecule", "rand-large"))
+ super().__init__(config)
+ # must build tokenizer before datasets
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
+ self.build_datasets()
+ self.build_dataloaders()
+ self.model = ModelRegistry.build(**self.network)
+ self.model = self.model.to(self.device)
+ if self.config.ema.pop("enabled"):
+ ema.check()
+            self.ema = EMA(self.model, **self.config.ema)
+ if self.config.training:
+ pretrained_ratio = self.config.optim.pop("pretrained_ratio", 1e-2)
+ self.optimizer = self.get_optimizer(self.config.optim.pop("name"))(
+ params=self.model.trainable_parameters(pretrained_ratio=pretrained_ratio, **self.config.optim),
+ **self.config.optim,
+ )
+ if self.config.sched:
+ self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
+ self.metrics = self.build_metrics()
+
+ def __post_init__(self):
+ if self.config.platform != "deepspeed" and "checkpoint" in self.config:
+ self.load_checkpoint(self.config.checkpoint)
+ if self.distributed:
+ self.model = nn.parallel.DistributedDataParallel(
+ self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
+ )
+ super().__post_init__()
+ if self.config.platform == "deepspeed" and "checkpoint" in self.config:
+ self.load_checkpoint(self.config.checkpoint)
+ self.yaml(os.path.join(self.dir, "trainer.yaml"))
+ print(self)
+ print(self.get_dataset_lengths())
+
+ def train_step(self, data) -> Tuple[Any, torch.Tensor]:
+ with self.autocast(), self.accumulate():
+ pred = self.model(**data)
+ loss = self.loss_fn(pred, data)
+ self.advance(loss)
+ self.metric_fn(pred, data)
+ return pred, loss
+
+ def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
+ model = self.ema or self.model
+ pred = model(**data)
+ loss = self.loss_fn(pred, data)
+ self.metric_fn(pred, data)
+ return pred, loss
+
+ @torch.inference_mode()
+ def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
+ r"""
+ Perform inference on `split`.
+
+        Args:
+            split (str): The split to run inference on.
+
+        Returns:
+            Inference outputs.
+
+        - If the model has a single output:
+
+            - If labels are available, a [`FlatDict`][chanfig.FlatDict] with keys `predict` and `label`
+              is returned.
+            - If labels are not available, a list of predictions is returned.
+
+        - If the model has multiple outputs:
+
+            - If labels are available, a [`NestedDict`][chanfig.NestedDict] with keys as task names and
+              values as dictionaries with keys `predict` and `label` is returned.
+            - If labels are not available, a [`FlatDict`][chanfig.FlatDict] with keys as task names and
+              values as lists of predictions is returned.
+ """
+
+ self.mode = "inf" # type: ignore
+ loader = self.dataloaders[split]
+ preds = FlatDict()
+ labels = FlatDict()
+ model = self.ema or self.model
+ for _, data in tqdm(enumerate(loader), total=len(loader)): # noqa: F402
+ pred = model(**data)
+            for task, p in pred.items():
+                preds.setdefault(task, []).extend(p["logits"].squeeze(-1).tolist())
+                if task in data:
+                    labels.setdefault(task, []).extend(data[task].squeeze(-1).tolist())
+
+ if self.distributed:
+ torch.cuda.synchronize()
+ for task in preds.keys():
+ preds[task] = self.gather_for_metrics(preds[task])
+ for task in labels.keys():
+ labels[task] = self.gather_for_metrics(labels[task])
+ if labels:
+ if len(preds) == 1:
+ return FlatDict(predict=next(iter(preds.values())), label=next(iter(labels.values())))
+ return NestedDict({task: {"predict": preds[task], "label": labels[task]} for task in preds})
+ if len(preds) == 1:
+ return next(iter(preds.values()))
+ return preds
+
+ def loss_fn(self, pred, data):
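+        # Multi-task loss balancing strategies:
+        #   rlw: Random Loss Weighting -- combine task losses with random softmax weights each step.
+        #   gls: Geometric Loss Strategy -- geometric mean of the task losses.
+        #   ew:  Equal Weighting (default) -- arithmetic mean of the task losses.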
+ if self.balance == "rlw":
+ loss = torch.stack([p["loss"] for p in pred.values()])
+ weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1)
+ return loss.T @ weight
+ if self.balance == "gls":
+ return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred))
+ if self.balance != "ew":
+ warn(f"Unknown balance method {self.balance}, using equal weighting.")
+ return sum(p["loss"] for p in pred.values()) / len(pred)
+
+ def metric_fn(self, pred, data):
+ metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
+ metric.update({t: (p["logits"], data[t]) for t, p in pred.items()})
+
+ @cached_property
+ def tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+ if "train" in self.datasets:
+ return self.datasets.train.tasks
+ return next(iter(self.datasets.values())).tasks
+
+ @cached_property
+ def dataset_tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+ dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values()))
+ tasks = self.tasks
+ dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks
+ for dataset in self.datasets.values():
+ if isinstance(dataset, MultiTaskDataset):
+ for dataset_name, tasks_ in dataset.dataset_tasks.items():
+ for task_name, task_ in tasks_.items():
+ if task_name not in tasks:
+ raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined")
+ task = tasks[task_name]
+ if task != task_:
+ warn(
+ f"Task {task_name} of dataset {dataset_name} has different configurations "
+ "compared to training data, using training configuration.\n"
+ "This may lead to unexpected behavior.",
+ )
+ if dataset_name not in dataset_tasks:
+ dataset_tasks[dataset_name] = NestedDict()
+ if task_name not in dataset_tasks[dataset_name]:
+ dataset_tasks[dataset_name][task_name] = task
+ else:
+ for task_name, task_ in dataset.tasks.items():
+ if task_name not in tasks:
+ raise ValueError(f"Task {task_name} is not defined")
+ task = tasks[task_name]
+ if task != task_:
+ warn(
+ f"Task {task_name} has different configurations "
+ "compared to training data, using training configuration.\n"
+ "This may lead to unexpected behavior.",
+ )
+ if task_name not in dataset_tasks:
+ dataset_tasks[task_name] = task
+ return dataset_tasks
+
+ @cached_property
+ def network(self):
+ heads = {
+ name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level)
+ for name, task in self.tasks.items()
+ }
+ if "heads" not in self.config.network:
+ self.config.network.heads = NestedDict(heads)
+ else:
+ self.config.network.heads.merge(heads, overwrite=False)
+ return self.config.network
+
+ def build_datasets(self):
+ if "data" in self.config:
+ self.datasets = self.all_datasets = self._build_dataset(self.config.data)
+ return
+ if "datas" in self.config:
+ self.all_datasets = NestedDict(
+ {name: self._build_dataset(config, name) for name, config in self.config.datas.items()}
+ )
+ datasets = {
+ subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict}
+ for subkey in {k for v in self.all_datasets.values() for k in v}
+ }
+ self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+ return
+ raise ValueError("No data configuration found")
+
+ def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
+ name = name or config.root
+ print(f"Building dataset {name}")
+ dataset = NestedDict()
+ train_splits = [key for key in config.keys() if key.startswith(defaults.TRAIN_SPLITS)]
+ validation_splits = [key for key in config.keys() if key.startswith(defaults.VALIDATION_SPLITS)]
+ test_splits = [key for key in config.keys() if key.startswith(defaults.TEST_SPLITS)]
+ inference_splits = [key for key in config.keys() if key.startswith(defaults.INFERENCE_SPLITS)]
+ all_splits = train_splits + validation_splits + test_splits + inference_splits
+ ignored_keys = all_splits + ["root"]
+ dataset_factory = partial(
+ Dataset,
+ tokenizer=self.tokenizer,
+ **{k: v for k, v in config.items() if k not in ignored_keys},
+ )
+ if os.path.isdir(config.root):
+ for split in train_splits:
+ dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="train")
+ for split in validation_splits:
+ dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="validation")
+ for split in test_splits:
+ dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="test")
+ for split in inference_splits:
+ dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split=config[split])
+ else:
+ splits = [k for k in defaults.DATASET_SPLITS if config.get(k) is not None]
+ if not splits:
+ existing_splits = get_dataset_split_names(config.root)
+ if "train" in existing_splits:
+ config.train = "train"
+ splits.append("train")
+ if "validation" in existing_splits:
+ config.validation = "validation"
+ splits.append("validation")
+ if "test" in existing_splits:
+ config.test = "test"
+ splits.append("test")
+ for split in splits:
+ dataset[split] = dataset_factory(config.root, split=split)
+ if not dataset:
+ raise ValueError(f"No datasets built. This is likely due to missing data paths in {config}.")
+ return dataset
+
+ def build_dataloaders(self):
+ datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+ default_kwargs = self.config.get("dataloader", NestedDict())
+ dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+ for k, d in datasets.items():
+ dataloader_kwargs.setdefault(k, NestedDict())
+ dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+ batch_size = dataloader_kwargs[k].pop("batch_size")
+ shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+ drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True))
+ if isinstance(d, MultiTaskDataset):
+ batch_sampler = (
+ DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ if self.distributed
+ else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ )
+ else:
+ sampler = (
+ data.distributed.DistributedSampler(d, shuffle=shuffle)
+ if self.distributed
+ else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+ )
+ batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last)
+ self.dataloaders[k] = data.DataLoader(
+ d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+ )
+
+ def build_metrics(self) -> MultiTaskMetrics:
+ return MultiTaskMetrics(
+ {
+ name: MetricRegistry.build(type=task.type, num_labels=task.num_labels)
+ for name, task in self.dataset_tasks.all_items()
+ }
+ )
+
+ def collate_fn(self, batch):
+ return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()}
+
+    def get_dataset_lengths(self) -> str:
+        report = "dataset lengths:\n"
+        longest_name = max(len(name) for name in self.all_datasets.keys())
+        for name, dataset in self.all_datasets.items():
+            if isinstance(dataset, NestedDict):
+                report += f"{name}:"
+                if len(name) < longest_name:
+                    report += " " * (longest_name - len(name))
+                report += "\t\t"
+                for split, d in dataset.items():
+                    report += f" {split}: {len(d)}\t"
+            else:
+                report += f"{name}: {len(dataset)}\t"
+            report += "\n"
+        return report
diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py
new file mode 100644
index 00000000..f4be8532
--- /dev/null
+++ b/multimolecule/runners/config.py
@@ -0,0 +1,105 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import List
+
+from chanfig import Config
+from transformers import PretrainedConfig
+
+
+class DataConfig(Config):
+ root: str = "."
+ train: str | None
+ validation: str | None
+ test: str | None
+ feature_cols: List | None = None
+ label_cols: List | None = None
+ truncation: bool = True
+
+
+class OptimConfig(Config):
+ name: str = "AdamW"
+ lr: float = 1e-3
+ weight_decay: float = 1e-2
+ pretrained_ratio: float = 1e-2
+
+
+class EmaConfig(Config):
+ enabled: bool = False
+ beta: float = 0.999
+ update_after_step: int = 0
+ update_every: int = 10
+
+
+class MultiMoleculeConfig(Config):
+ name: str
+ seed: int = 1016
+
+ balance: str = "ew"
+ platform: str = "torch"
+ training: bool = True
+
+ pretrained: str | None
+ use_pretrained: bool = True
+ transformers: PretrainedConfig
+ epoch_end: int = 20
+
+ data: DataConfig
+
+ tensorboard: bool = True
+ save_interval: int = 10
+
+ art: bool = True
+ allow_tf32: bool = True
+ reduced_precision_reduction: bool = False
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.datas = Config(default_factory=DataConfig)
+ self.dataloader.batch_size = 32
+ self.optim = OptimConfig()
+ self.ema = EmaConfig()
+ self.sched.final_lr = 0
+
+ def post(self):
+ if "pretrained" not in self and "checkpoint" not in self:
+ raise ValueError("Either one of `pretrained` or `checkpoint` must be specified")
+ if "data" in self:
+ if self.datas:
+ raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+ del self.datas
+ if "pretrained" in self:
+ self["network.backbone.sequence.name"] = self.get("pretrained")
+ self.name = self.get_name()
+ self["network.backbone.sequence.use_pretrained"] = self.use_pretrained
+
+ def get_name(self) -> str:
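+        # Produces a run name of the form "<pretrained>-<lr>@<optimizer>-<seed>",
+        # e.g. "rnafm-0.001@AdamW-1016" (values shown are illustrative).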
+ pretrained = self.get("pretrained")
+ if os.path.exists(pretrained):
+ path = Path(pretrained)
+ if os.path.isfile(pretrained):
+ pretrained = str(path.relative_to(path.parents[1]).with_suffix(""))
+ else:
+ pretrained = path.stem
+ name = pretrained.replace("/", "--")
+ if "optim" in self:
+ optim_name = self.optim.get("name", "no")
+ name += f"-{self.optim.lr}@{optim_name}"
+ return name + f"-{self.seed}"
diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py
new file mode 100644
index 00000000..da584cbc
--- /dev/null
+++ b/multimolecule/runners/metrics.py
@@ -0,0 +1,37 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from chanfig import Registry as Registry_
+from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics
+
+
+class Registry(Registry_):
+
+ def build(self, type, num_labels: int | None = None, **kwargs):
+ if type == "multilabel":
+ return self.init(self.lookup(type), num_labels=num_labels, **kwargs)
+ if type == "multiclass":
+ return self.init(self.lookup(type), num_classes=num_labels, **kwargs)
+ if type == "regression":
+ return self.init(self.lookup(type), num_outputs=num_labels, **kwargs)
+ return self.init(self.lookup(type), **kwargs)
+
+
+MetricRegistry = Registry(key="type")
+MetricRegistry.register(binary_metrics, "binary")
+MetricRegistry.register(multiclass_metrics, "multiclass")
+MetricRegistry.register(multilabel_metrics, "multilabel")
+MetricRegistry.register(regression_metrics, "regression")
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
new file mode 100644
index 00000000..97cab4e7
--- /dev/null
+++ b/multimolecule/runners/runner.py
@@ -0,0 +1,42 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import danling as dl
+
+from .base_runner import BaseRunner
+
+
+class MultiMoleculeRunner(type):
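+    # Despite inheriting from `type`, this class is used as a factory rather than a
+    # metaclass: `__new__` returns a platform-specific runner instance, so `__init__`
+    # of this class is never invoked.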
+ def __new__(cls, config):
+ if config.get("platform", "torch") == "torch":
+ return TorchRunner(config)
+ if config.platform == "deepspeed":
+ return DeepSpeedRunner(config)
+ if config.platform == "accelerate":
+ return AccelerateRunner(config)
+ raise ValueError(f"Unsupported platform: {config.platform}")
+
+
+class TorchRunner(BaseRunner, dl.TorchRunner):
+ pass
+
+
+class DeepSpeedRunner(BaseRunner, dl.DeepSpeedRunner):
+ pass
+
+
+class AccelerateRunner(BaseRunner, dl.AccelerateRunner):
+ pass
diff --git a/multimolecule/train.py b/multimolecule/train.py
new file mode 100644
index 00000000..f146bc97
--- /dev/null
+++ b/multimolecule/train.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from .apis import train
+
+if __name__ == "__main__":
+ train()
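+
+# Usage sketch (chanfig reads ``--config`` from the command line; the file name is illustrative):
+#   python -m multimolecule.train --config config.yaml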
diff --git a/pyproject.toml b/pyproject.toml
index ba202944..4260bd3b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,8 +45,9 @@ dynamic = [
]
dependencies = [
"accelerate",
+ "art",
"chanfig>=0.0.105",
- "danling[torch]>=0.3.11",
+ "danling[torch]>=0.4.0b1",
"datasets",
'StrEnum; python_version < "3.11"',
"torch",