From 7a2cb9b2ab8f11d0a73e61fb86a91f4ee63f033d Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sat, 11 May 2024 17:04:38 +0800 Subject: [PATCH] add runner Signed-off-by: Zhiyuan Chen --- docs/docs/data/multitask.md | 9 + docs/docs/runners/config.md | 9 + docs/docs/runners/index.md | 9 + docs/docs/runners/runner.md | 9 + docs/mkdocs.yml | 5 + multimolecule/__init__.py | 2 + multimolecule/apis/__init__.py | 19 ++ multimolecule/apis/run.py | 115 +++++++++ multimolecule/apis/stat.py | 99 ++++++++ multimolecule/data/__init__.py | 4 + multimolecule/data/multitask.py | 246 ++++++++++++++++++ multimolecule/runners/README.md | 9 + multimolecule/runners/__init__.py | 20 ++ multimolecule/runners/base_runner.py | 357 +++++++++++++++++++++++++++ multimolecule/runners/config.py | 105 ++++++++ multimolecule/runners/metrics.py | 37 +++ multimolecule/runners/runner.py | 42 ++++ multimolecule/train.py | 20 ++ pyproject.toml | 3 +- 19 files changed, 1118 insertions(+), 1 deletion(-) create mode 100644 docs/docs/data/multitask.md create mode 100644 docs/docs/runners/config.md create mode 100644 docs/docs/runners/index.md create mode 100644 docs/docs/runners/runner.md create mode 100644 multimolecule/apis/__init__.py create mode 100644 multimolecule/apis/run.py create mode 100644 multimolecule/apis/stat.py create mode 100644 multimolecule/data/multitask.py create mode 100644 multimolecule/runners/README.md create mode 100644 multimolecule/runners/__init__.py create mode 100644 multimolecule/runners/base_runner.py create mode 100644 multimolecule/runners/config.py create mode 100644 multimolecule/runners/metrics.py create mode 100644 multimolecule/runners/runner.py create mode 100644 multimolecule/train.py diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md new file mode 100644 index 00000000..054c6205 --- /dev/null +++ b/docs/docs/data/multitask.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiTask + +::: multimolecule.data.multitask diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md new file mode 100644 index 00000000..8e188199 --- /dev/null +++ b/docs/docs/runners/config.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiMoleculeConfig + +::: multimolecule.runners.MultiMoleculeConfig diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md new file mode 100644 index 00000000..75a0f528 --- /dev/null +++ b/docs/docs/runners/index.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# runners + +--8<-- "multimolecule/runners/README.md:8:" diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md new file mode 100644 index 00000000..7f2f4c12 --- /dev/null +++ b/docs/docs/runners/runner.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiMoleculeRunner + +::: multimolecule.runners.base_runner.BaseRunner diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 67a94e72..4eb120dc 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule nav: - index.md + - runners: + - runners/index.md + - MultiMoleculeRunner: runners/runner.md + - MultiMoleculeConfig: runners/config.md - data: - data/index.md - Dataset: data/dataset.md + - multitask: data/multitask.md - datasets: - datasets/index.md - DNA: diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py index e9018441..ac7637fb 100644 --- a/multimolecule/__init__.py +++ b/multimolecule/__init__.py @@ -20,6 +20,7 @@ # . +from .apis import evaluate, infer, train from .data import Dataset from .models import ( AutoModelForContactPrediction, @@ -130,6 +131,7 @@ TokenKMerHead, TokenPredictionHead, ) +from .runners import MultiMoleculeConfig, MultiMoleculeRunner from .tasks import Task, TaskLevel, TaskType from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer from .utils import count_parameters diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py new file mode 100644 index 00000000..8e3e5b3c --- /dev/null +++ b/multimolecule/apis/__init__.py @@ -0,0 +1,19 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .run import evaluate, infer, train + +__all__ = ["train", "evaluate", "infer"] diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py new file mode 100644 index 00000000..1fdb7666 --- /dev/null +++ b/multimolecule/apis/run.py @@ -0,0 +1,115 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +# mypy: disable-error-code="attr-defined" + +import atexit +import os +import warnings +from typing import Type + +import danling as dl +import torch + +from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner + +try: + import nni +except ImportError: + nni = None + + +def train( + config: MultiMoleculeConfig = None, # type: ignore + runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, +): + if config is None: + config = MultiMoleculeConfig() + config = config.parse(default_config="config", no_default_config_action="warn") + config.interpolate(unsafe_eval=True) + config.training = True + if config.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + if config.reduced_precision_reduction: + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + if config.get("nni", False): + if nni is None: + raise ValueError("Unable to retrieve nni parameters, since nni is not installed.") + config.merge(nni.get_next_parameter()) + with dl.debug(config.get("debug", False)): + runner = runner_cls(config) + atexit.register(runner.print_result) + atexit.register(runner.save_result) + atexit.register(runner.save_checkpoint) + result = runner.train() + return result + + +def evaluate( + config: MultiMoleculeConfig = None, # type: ignore + runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, +): + if config is None: + config = MultiMoleculeConfig.empty() + config = config.parse(default_config="config", no_default_config_action="warn") + config.interpolate(unsafe_eval=True) + config.training = False + if config.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + if config.reduced_precision_reduction: + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + if "checkpoint" not in config or not isinstance(config.checkpoint, str): + raise RuntimeError("Please specify `checkpoint` to run evaluate") + for name, data in config.datas.items(): + if "evaluation" not in data or not isinstance(data.evaluate, str): + raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}") + runner = runner_cls(config) + result = runner.evaluate_epoch("evaluation") + print(result) + return result + + +def infer( + config: MultiMoleculeConfig = None, # type: ignore + runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, +): + if config is None: + config = MultiMoleculeConfig.empty() + config = config.parse(default_config="config", no_default_config_action="warn") + config.interpolate(unsafe_eval=True) + config.training = False + if config.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + if config.reduced_precision_reduction: + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + if "checkpoint" not in config or not isinstance(config.checkpoint, str): + raise RuntimeError("Please specify `checkpoint` to run infer.") + for name, data in config.datas.items(): + if "inference" not in data or not isinstance(data.inference, str): + raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}") + if "result_path" not in config or not isinstance(config.result_path, str): + config.result_path = os.path.join(os.getcwd(), "result.json") + warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2) + runner = runner_cls(config) + result = runner.infer() + runner.save(result, config.result_path) + return result diff --git a/multimolecule/apis/stat.py b/multimolecule/apis/stat.py new file mode 100644 index 00000000..3f7b7168 --- /dev/null +++ b/multimolecule/apis/stat.py @@ -0,0 +1,99 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import os +import shutil +from statistics import mean +from typing import List + +import chanfig +import pandas as pd +from chanfig import NestedDict +from tqdm import tqdm + + +class Result(NestedDict): + pretrained: str + id: str + seed: int + epoch: int + validation: NestedDict + test: NestedDict + + +def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]: + results = [] + for root, _, files in tqdm(os.walk(experiment_root)): + if "run.log" in files: + if "best.json" not in files: + if remove_empty: + shutil.rmtree(root) + continue + best = NestedDict.from_json(os.path.join(root, "best.json")) + if "index" not in best: + if remove_empty: + shutil.rmtree(root) + continue + config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml")) + pretrained = config.pretrained.split("/")[-1] + seed = config.seed + result = Result(id=best.id, pretrained=pretrained, seed=seed) + result.validation = NestedDict( + {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.all_items()} + ) + result.test = NestedDict( + {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.all_items()} + ) + result.epoch = best.index + result.pop("validation.time", None) + result.pop("test.time", None) + result.pop("validation.loss", None) + result.pop("test.loss", None) + result.pop("validation.lr", None) + result.pop("test.lr", None) + results.append(result) + # Remove empty directories, perform twice to remove all empty directories + if remove_empty: + for root, dirs, files in os.walk(experiment_root): + if not files and not dirs: + os.rmdir(root) + for root, dirs, files in os.walk(experiment_root): + if not files and not dirs: + os.rmdir(root) + results.sort(key=lambda x: (x.pretrained, x.seed, x.id)) + return results + + +def write_result_stat(results: List[Result], path: str): + results = [dict(result.all_items()) for result in results] # type: ignore[misc] + df = pd.DataFrame.from_dict(results) + df.insert(len(df.keys()) - 1, "comment", "") + df.fillna("") + df.to_csv(path, index=False) + + +class Config(chanfig.Config): + experiment_root: str = "experiments" + out_path: str = "result.csv" + remove_empty: bool = True + + +if __name__ == "__main__": + config = Config().parse() + result_stat = get_result_stat(config.experiment_root, config.remove_empty) + if not len(result_stat) > 0: + raise ValueError("No results found") + write_result_stat(result_stat, config.out_path) diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py index b24fd6cf..d5e9ac2e 100644 --- a/multimolecule/data/__init__.py +++ b/multimolecule/data/__init__.py @@ -20,9 +20,13 @@ # https://multimolecule.danling.org/about/license-faq from .dataset import Dataset +from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler from .utils import no_collate __all__ = [ "Dataset", + "MultiTaskDataset", + "MultiTaskSampler", + "DistributedMultiTaskSampler", "no_collate", ] diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py new file mode 100644 index 00000000..7c20e829 --- /dev/null +++ b/multimolecule/data/multitask.py @@ -0,0 +1,246 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from bisect import bisect_right +from collections.abc import Iterator, Mapping, Sequence +from copy import deepcopy +from random import choices + +import torch +from chanfig import NestedDict +from torch import distributed as dist +from torch.utils import data + +from .dataset import Dataset + + +class MultiTaskDataset(data.ConcatDataset): + + datasets: Mapping + dataset_keys: Sequence[str] + dataset_values: Sequence[Dataset] + + def __init__(self, datasets: Mapping) -> None: + for key, dataset in datasets.items(): + if not isinstance(dataset, Dataset): + raise TypeError(f"Dataset {key} should be an instance of Dataset") + self.datasets = datasets + if not len(self.datasets) > 0: + raise ValueError("MultiTaskDataset should contain at least one dataset") + self.dataset_keys, self.dataset_values = zip(*self.datasets.items()) + self.cumulative_sizes = self.cumsum(self.dataset_values) + + def __getitems__(self, key: Sequence[int]) -> Mapping: + dataset_idx = bisect_right(self.cumulative_sizes, key[0]) + if dataset_idx == 0: + sample_idx = key + else: + sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key] + batch = self.dataset_values[dataset_idx][sample_idx] + batch["dataset"] = self.dataset_keys[dataset_idx] + return batch + + @property + def tasks(self) -> NestedDict: + tasks = NestedDict() + for dataset in self.dataset_values: + for n, t in dataset.tasks.items(): + if n not in tasks: + tasks[n] = t + elif tasks[n] != t: + raise ValueError(f"Task {n} has different configurations across datasets") + return tasks + + @property + def dataset_tasks(self) -> NestedDict: + return NestedDict({k: v.tasks for k, v in self.datasets.items()}) + + def __repr__(self) -> str: + return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})" + + +class MultiTaskSampler(data.BatchSampler): + r""" + Ensure all items in a batch comes from the same dataset. + + Arguments: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + """ + + datasets: Sequence[Dataset] + + def __init__( # pylint: disable=super-init-not-called + self, + dataset: MultiTaskDataset, + batch_size: int, + shuffle: bool = True, + drop_last: bool = False, + sampler_cls: type[data.Sampler] | None = None, + weights: list[int] | None = None, + ) -> None: + self.datasets = dataset.dataset_values + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle = shuffle + if sampler_cls is None: + sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler + self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore + self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore + self.cumulative_sizes = dataset.cumulative_sizes + self.num_datasets = len(self.datasets) + self.weights = weights if weights is not None else self.dataset_sizes + + def __iter__(self): + sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)] + sampler_weights = deepcopy(self.weights) + sampler_idx = 0 + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + if self.drop_last: + while sampler_iters: + if self.shuffle: + sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0] + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + try: + batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)] + yield batch + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + else: + while sampler_iters: + if self.shuffle: + sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0] + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + batch = [0] * self.batch_size + idx_in_batch = 0 + try: + for _ in range(self.batch_size): + batch[idx_in_batch] = next(sampler_iter) + cumulative_size + idx_in_batch += 1 + yield batch + idx_in_batch = 0 # noqa: SIM113 + batch = [0] * self.batch_size + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self): + batch_size = self.batch_size + if self.drop_last: + return sum(len(d) // batch_size for d in self.datasets) + return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets) + + +class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods + r""" + Distributed version of MultiTaskSampler, which ensures that all GPUs sample data from the + same sub-dataset in each step without requiring additional communication. + The dataset selection is based on a random seed mechanism that is synchronized across epochs. + + See Also: + [MultiTaskSampler][MultiTaskSampler] + """ + + def __init__( + self, + dataset: MultiTaskDataset, + batch_size: int, + shuffle: bool = True, + drop_last: bool = False, + sampler_cls: type[data.Sampler] = data.RandomSampler, + weights: list[int] | None = None, + seed: int = 0, + ) -> None: + super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights) + self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets] + self.seed = seed + self.epoch = 0 + + def set_epoch(self, epoch: int): + """ + Sets the epoch for deterministic shuffling. + """ + self.epoch = epoch + for sampler in self.samplers: + sampler.set_epoch(epoch) + + def _get_sampler_idx(self, high: int) -> int: + """ + Determines which sampler (i.e., sub-dataset) to use based on the seed and epoch. + """ + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + sampler_idx = torch.randint(low=0, high=high, size=(1,), generator=g).item() + return sampler_idx + + def __iter__(self) -> Iterator: + sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)] + sampler_weights = deepcopy(self.weights) + + if self.drop_last: + while sampler_iters: + # Sample the same sub-dataset across all GPUs using the seeded index + sampler_idx = self._get_sampler_idx(len(sampler_iters)) + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + try: + batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)] + yield batch + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + else: + while sampler_iters: + # Sample the same sub-dataset across all GPUs using the seeded index + sampler_idx = self._get_sampler_idx(len(sampler_iters)) + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + batch = [0] * self.batch_size + idx_in_batch = 0 + try: + for _ in range(self.batch_size): + batch[idx_in_batch] = next(sampler_iter) + cumulative_size + idx_in_batch += 1 + yield batch + idx_in_batch = 0 # noqa: SIM113 + batch = [0] * self.batch_size + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self) -> int: + batch_size = self.batch_size * self.world_size + if self.drop_last: + return sum(len(d) // batch_size for d in self.datasets) + return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets) + + @property + def world_size(self) -> int: + r"""Return the number of processes in the current process group.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return 1 diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md new file mode 100644 index 00000000..bb1000ad --- /dev/null +++ b/multimolecule/runners/README.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# runners + +`runners` provide an easy-to-use interface for running experiments. diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py new file mode 100644 index 00000000..70fa4076 --- /dev/null +++ b/multimolecule/runners/__init__.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .config import MultiMoleculeConfig +from .runner import MultiMoleculeRunner + +__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"] diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py new file mode 100644 index 00000000..ccf18ce6 --- /dev/null +++ b/multimolecule/runners/base_runner.py @@ -0,0 +1,357 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import math +import os +from functools import cached_property, partial +from typing import Any, Tuple +from warnings import warn + +import danling as dl +import torch +from art import text2art +from chanfig import FlatDict, NestedDict +from danling import MultiTaskMetricMeters, MultiTaskMetrics +from datasets import disable_progress_bars, get_dataset_split_names +from lazy_imports import try_import +from torch import nn +from torch.nn import functional as F +from torch.utils import data +from tqdm import tqdm +from transformers import AutoTokenizer + +from multimolecule import defaults +from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler +from multimolecule.module import HeadConfig, ModelRegistry, MultiMoleculeModel + +from .config import MultiMoleculeConfig +from .metrics import MetricRegistry + +with try_import() as ema: + from ema_pytorch import EMA + + +disable_progress_bars() + + +class BaseRunner(dl.BaseRunner): + + config: MultiMoleculeConfig + model: MultiMoleculeModel + all_datasets: NestedDict + + def __init__(self, config: MultiMoleculeConfig): + if config.art: + print(text2art("MultiMolecule", "rand-large")) + super().__init__(config) + # must build tokenizer before datasets + self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained) + self.build_datasets() + self.build_dataloaders() + self.model = ModelRegistry.build(**self.network) + self.model = self.model.to(self.device) + ema_enabled = self.config.ema.pop("enabled", False) + if ema_enabled: + ema.check() + self.ema = EMA(self.model, coerce_dtype=True) + self.config.ema.enabled = ema_enabled + if self.config.training: + optim_name = self.config.optim.pop("name", "AdamW") + pretrained_ratio = self.config.optim.pop("pretrained_ratio", 1e-2) + self.optimizer = self.get_optimizer(optim_name)( + params=self.model.trainable_parameters(pretrained_ratio=pretrained_ratio, **self.config.optim), + **self.config.optim, + ) + self.config.optim.name = optim_name + self.config.optim.pretrained_ratio = pretrained_ratio + if self.config.sched: + self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched) + self.train_metrics = self.build_train_metrics() + self.evaluate_metrics = self.build_evaluate_metrics() + + def __post_init__(self): + if self.config.platform != "deepspeed" and "checkpoint" in self.config: + self.load_checkpoint(self.config.checkpoint) + if self.distributed: + self.model = nn.parallel.DistributedDataParallel( + self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True + ) + super().__post_init__() + if self.config.platform == "deepspeed" and "checkpoint" in self.config: + self.load_checkpoint(self.config.checkpoint) + self.yaml(os.path.join(self.dir, "trainer.yaml")) + print(self) + print(self.get_dataset_lengths()) + + def train_step(self, data) -> Tuple[Any, torch.Tensor]: + with self.autocast(), self.accumulate(): + pred = self.model(**data) + loss = self.loss_fn(pred, data) + self.advance(loss) + self.metric_fn(pred, data) + return pred, loss + + def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]: + model = self.ema or self.model + pred = model(**data) + loss = self.loss_fn(pred, data) + self.metric_fn(pred, data) + return pred, loss + + @torch.inference_mode() + def infer(self, split: str = "inf") -> NestedDict | FlatDict | list: + r""" + Perform inference on `split`. + + Args: + split (str): split to run inference + + Return: + Inference outputs. + + - If the model has single output: + + - If labels are available, a [`FlatDict`][chanfig.FlatDict] with keys `predict` and `label` is + returned. + - If labels are not available, a list of predictions is returned. + + - If the model has multiple outputs: + - If labels are available, a [`NestedDict`][chanfig.NestedDict] with keys as task names and values + as dictionaries with keys `predict` and `label` is returned. + - If labels are not available, a [`FlatDict`][chanfig.FlatDict] with keys as task names and values + as lists of predictions is returned. + """ + + self.mode = "inf" # type: ignore + loader = self.dataloaders[split] + preds = FlatDict() + labels = FlatDict() + model = self.ema or self.model + for _, data in tqdm(enumerate(loader), total=len(loader)): # noqa: F402 + pred = model(**data) + for task, p in pred.items(): + preds[task].extend(p["logits"].squeeze(-1).tolist()) + if task in data: + labels[task].extend(data[task].squeeze(-1).tolist()) + + if self.distributed: + torch.cuda.synchronize() + for task in preds.keys(): + preds[task] = self.gather_for_metrics(preds[task]) + for task in labels.keys(): + labels[task] = self.gather_for_metrics(labels[task]) + if labels: + if len(preds) == 1: + return FlatDict(predict=next(iter(preds.values())), label=next(iter(labels.values()))) + return NestedDict({task: {"predict": preds[task], "label": labels[task]} for task in preds}) + if len(preds) == 1: + return next(iter(preds.values())) + return preds + + def loss_fn(self, pred, data): + if self.balance == "rlw": + loss = torch.stack([p["loss"] for p in pred.values()]) + weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1) + return loss.T @ weight + if self.balance == "gls": + return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred)) + if self.balance != "ew": + warn(f"Unknown balance method {self.balance}, using equal weighting.") + return sum(p["loss"] for p in pred.values()) / len(pred) + + def metric_fn(self, pred, data): + metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics + metric.update({t: (p["logits"], data[t]) for t, p in pred.items()}) + + @cached_property + def tasks(self): + if not self.datasets: + raise ValueError("No datasets found") + if "train" in self.datasets: + return self.datasets.train.tasks + return next(iter(self.datasets.values())).tasks + + @cached_property + def dataset_tasks(self): + if not self.datasets: + raise ValueError("No datasets found") + dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values())) + tasks = self.tasks + dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks + for dataset in self.datasets.values(): + if isinstance(dataset, MultiTaskDataset): + for dataset_name, tasks_ in dataset.dataset_tasks.items(): + for task_name, task_ in tasks_.items(): + if task_name not in tasks: + raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined") + task = tasks[task_name] + if task != task_: + warn( + f"Task {task_name} of dataset {dataset_name} has different configurations " + "compared to training data, using training configuration.\n" + "This may lead to unexpected behavior.", + ) + if dataset_name not in dataset_tasks: + dataset_tasks[dataset_name] = NestedDict() + if task_name not in dataset_tasks[dataset_name]: + dataset_tasks[dataset_name][task_name] = task + else: + for task_name, task_ in dataset.tasks.items(): + if task_name not in tasks: + raise ValueError(f"Task {task_name} is not defined") + task = tasks[task_name] + if task != task_: + warn( + f"Task {task_name} has different configurations " + "compared to training data, using training configuration.\n" + "This may lead to unexpected behavior.", + ) + if task_name not in dataset_tasks: + dataset_tasks[task_name] = task + return dataset_tasks + + @cached_property + def network(self): + heads = { + name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level) + for name, task in self.tasks.items() + } + if "heads" not in self.config.network: + self.config.network.heads = NestedDict(heads) + else: + self.config.network.heads.merge(heads, overwrite=False) + return self.config.network + + def build_datasets(self): + if "data" in self.config: + self.datasets = self.all_datasets = self._build_dataset(self.config.data) + return + if "datas" in self.config: + self.all_datasets = NestedDict( + {name: self._build_dataset(config, name) for name, config in self.config.datas.items()} + ) + datasets = { + subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict} + for subkey in {k for v in self.all_datasets.values() for k in v} + } + self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()}) + return + raise ValueError("No data configuration found") + + def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict: + name = name or config.root + print(f"Building dataset {name}") + dataset = NestedDict() + train_splits = [key for key in config.keys() if key.startswith(defaults.TRAIN_SPLITS)] + validation_splits = [key for key in config.keys() if key.startswith(defaults.VALIDATION_SPLITS)] + test_splits = [key for key in config.keys() if key.startswith(defaults.TEST_SPLITS)] + inference_splits = [key for key in config.keys() if key.startswith(defaults.INFERENCE_SPLITS)] + all_splits = train_splits + validation_splits + test_splits + inference_splits + ignored_keys = all_splits + ["root"] + dataset_factory = partial( + Dataset, + tokenizer=self.tokenizer, + **{k: v for k, v in config.items() if k not in ignored_keys}, + ) + if os.path.isdir(config.root): + for split in train_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="train") + for split in validation_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="validation") + for split in test_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="test") + for split in inference_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split=config[split]) + else: + splits = [k for k in defaults.DATASET_SPLITS if config.get(k) is not None] + if not splits: + existing_splits = get_dataset_split_names(config.root) + if "train" in existing_splits: + config.train = "train" + splits.append("train") + if "validation" in existing_splits: + config.validation = "validation" + splits.append("validation") + if "test" in existing_splits: + config.test = "test" + splits.append("test") + for split in splits: + dataset[split] = dataset_factory(config.root, split=split) + if not dataset: + raise ValueError(f"No datasets built. This is likely due to missing data paths in {config}.") + return dataset + + def build_dataloaders(self): + datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders} + default_kwargs = self.config.get("dataloader", NestedDict()) + dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs}) + for k, d in datasets.items(): + dataloader_kwargs.setdefault(k, NestedDict()) + dataloader_kwargs[k].merge(default_kwargs, overwrite=False) + batch_size = dataloader_kwargs[k].pop("batch_size") + shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True)) + drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True)) + if isinstance(d, MultiTaskDataset): + batch_sampler = ( + DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last) + if self.distributed + else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last) + ) + else: + sampler = ( + data.distributed.DistributedSampler(d, shuffle=shuffle) + if self.distributed + else data.RandomSampler(d) if shuffle else data.SequentialSampler(d) + ) + batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last) + self.dataloaders[k] = data.DataLoader( + d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k] + ) + + def build_train_metrics(self) -> MultiTaskMetricMeters: + return MultiTaskMetricMeters( + { + name: MetricRegistry.build(type=task.type, num_labels=task.num_labels) + for name, task in self.dataset_tasks.all_items() + } + ) + + def build_evaluate_metrics(self) -> MultiTaskMetrics: + return MultiTaskMetrics( + { + name: MetricRegistry.build(type=task.type, num_labels=task.num_labels) + for name, task in self.dataset_tasks.all_items() + } + ) + + def collate_fn(self, batch): + return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()} + + def get_dataset_lengths(self) -> str: + repr = "dataset lengths:\n" + longest_name = max(len(name) for name in self.all_datasets.keys()) + for name, dataset in self.all_datasets.items(): + if isinstance(dataset, NestedDict): + repr += f"{name}:" + if len(name) < longest_name: + repr += " " * (longest_name - len(name)) + repr += "\t\t" + for split, d in dataset.items(): + repr += f" {split}: {len(d)}\t" + else: + repr += f"{name}: {len(dataset)}\t" + repr += "\n" + return repr diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py new file mode 100644 index 00000000..f4be8532 --- /dev/null +++ b/multimolecule/runners/config.py @@ -0,0 +1,105 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import os +from pathlib import Path +from typing import List + +from chanfig import Config +from transformers import PretrainedConfig + + +class DataConfig(Config): + root: str = "." + train: str | None + validation: str | None + test: str | None + feature_cols: List | None = None + label_cols: List | None = None + truncation: bool = True + + +class OptimConfig(Config): + name: str = "AdamW" + lr: float = 1e-3 + weight_decay: float = 1e-2 + pretrained_ratio: float = 1e-2 + + +class EmaConfig(Config): + enabled: bool = False + beta: float = 0.999 + update_after_step: int = 0 + update_every: int = 10 + + +class MultiMoleculeConfig(Config): + name: str + seed: int = 1016 + + balance: str = "ew" + platform: str = "torch" + training: bool = True + + pretrained: str | None + use_pretrained: bool = True + transformers: PretrainedConfig + epoch_end: int = 20 + + data: DataConfig + + tensorboard: bool = True + save_interval: int = 10 + + art: bool = True + allow_tf32: bool = True + reduced_precision_reduction: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.datas = Config(default_factory=DataConfig) + self.dataloader.batch_size = 32 + self.optim = OptimConfig() + self.ema = EmaConfig() + self.sched.final_lr = 0 + + def post(self): + if "pretrained" not in self and "checkpoint" not in self: + raise ValueError("Either one of `pretrained` or `checkpoint` must be specified") + if "data" in self: + if self.datas: + raise ValueError("Only one of `data` or `datas` can be specified, but not both") + del self.datas + if "pretrained" in self: + self["network.backbone.sequence.name"] = self.get("pretrained") + self.name = self.get_name() + self["network.backbone.sequence.use_pretrained"] = self.use_pretrained + + def get_name(self) -> str: + pretrained = self.get("pretrained") + if os.path.exists(pretrained): + path = Path(pretrained) + if os.path.isfile(pretrained): + pretrained = str(path.relative_to(path.parents[1]).with_suffix("")) + else: + pretrained = path.stem + name = pretrained.replace("/", "--") + if "optim" in self: + optim_name = self.optim.get("name", "no") + name += f"-{self.optim.lr}@{optim_name}" + return name + f"-{self.seed}" diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py new file mode 100644 index 00000000..da584cbc --- /dev/null +++ b/multimolecule/runners/metrics.py @@ -0,0 +1,37 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from chanfig import Registry as Registry_ +from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics + + +class Registry(Registry_): + + def build(self, type, num_labels: int | None = None, **kwargs): + if type == "multilabel": + return self.init(self.lookup(type), num_labels=num_labels, **kwargs) + if type == "multiclass": + return self.init(self.lookup(type), num_classes=num_labels, **kwargs) + if type == "regression": + return self.init(self.lookup(type), num_outputs=num_labels, **kwargs) + return self.init(self.lookup(type), **kwargs) + + +MetricRegistry = Registry(key="type") +MetricRegistry.register(binary_metrics, "binary") +MetricRegistry.register(multiclass_metrics, "multiclass") +MetricRegistry.register(multilabel_metrics, "multilabel") +MetricRegistry.register(regression_metrics, "regression") diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py new file mode 100644 index 00000000..97cab4e7 --- /dev/null +++ b/multimolecule/runners/runner.py @@ -0,0 +1,42 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import danling as dl + +from .base_runner import BaseRunner + + +class MultiMoleculeRunner(type): + def __new__(cls, config): + if config.get("platform", "torch") == "torch": + return TorchRunner(config) + if config.platform == "deepspeed": + return DeepSpeedRunner(config) + if config.platform == "accelerate": + return AccelerateRunner(config) + raise ValueError(f"Unsupported platform: {config.platform}") + + +class TorchRunner(BaseRunner, dl.TorchRunner): + pass + + +class DeepSpeedRunner(BaseRunner, dl.DeepSpeedRunner): + pass + + +class AccelerateRunner(BaseRunner, dl.AccelerateRunner): + pass diff --git a/multimolecule/train.py b/multimolecule/train.py new file mode 100644 index 00000000..f146bc97 --- /dev/null +++ b/multimolecule/train.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .apis import train + +if __name__ == "__main__": + train() diff --git a/pyproject.toml b/pyproject.toml index e83e5258..f37b58aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,9 @@ dynamic = [ ] dependencies = [ "accelerate", + "art", "chanfig>=0.0.105", - "danling[torch]>=0.3.11", + "danling[torch]>=0.4.0b1", "datasets", 'StrEnum; python_version < "3.11"', "torch",