Skip to content

Commit

Permalink
[WIP] add training code
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Apr 24, 2024
1 parent 16e002e commit 497aa54
Show file tree
Hide file tree
Showing 24 changed files with 1,485 additions and 0 deletions.
91 changes: 91 additions & 0 deletions multimolecule/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
from warnings import warn

from chanfig import Config, Variable


class MultiMoleculeConfig(Config): # pylint: disable=R0902, R0903
def __init__(self, *args, **kwargs): # pylint: disable=R0915
super().__init__(*args, **kwargs)
self.sequence.name = "Auto"
self.sequence.pretrained = Variable("multimolecule/rnafm")

self.project_root = "experiments"
self.project_name = "multimolecule"
self.experiment_name = "multimolecule"
self.score_set = "val"
self.score_name = "loss"

self.epoch_end = 30
self.tolerance = 0.05
self.patience = 10

self.data.root = "ncrna"
self.data.train.feature = "train.csv"
self.data.train.label = None
self.dataset.feature_cols = None
self.dataset.label_cols = ["label"]
# self.data.features = None
# self.data.labels = None
self.dataset.truncation = True
self.dataset.max_length = 1024
self.dataset.tokenizer.name = self.sequence.name + "Tokenizer"
self.dataset.tokenizer.pretrained = self.sequence.pretrained
self.dataloader.batch_size = 32
self.dataloader.num_workers = 2
# self.batch_size_base = 32

self.network.name = "BaseModel"
self.network.dropout = 0.2
self.network.backbone.name = "fusion"
self.network.backbone.sequence.name = self.sequence.name + "Model"
self.network.backbone.sequence.pretrained = self.sequence.pretrained
self.network.backbone.sequence.auto_fix = False
# self.network.neck.name = "cat"

self.optim.name = "AdamW"
self.optim.lr = 3e-5
self.optim.weight_decay = 0.01
self.optim.momentum = 0.9
self.lrs.warmup_steps = 50
self.lrs.cooldown_steps = 0
self.lrs.final_lr = 0
self.lrs.strategy = "linear"
self.lrs.method = "numerical"
self.pretrained_lr_ratio = 1
self.loss = {}
self.metric = {}
# self.loss.pos_weight = 1

self.log_interval = None
self.save_interval = 10
self.seed = 2022

self.checkpoint = None
self.auto_resume = False
self.log = True
self.tensorboard = True
self.tracking = False

self.allow_tf32 = True
self.reduced_precision_reduction = False

self.add_argument("--batch_size", type=int, dest="dataloader.batch_size")
self.add_argument("--num_workers", type=int, dest="dataloader.num_workers")
self.add_argument("-lr", type=int, dest="optim.lr")
self.add_argument("-wd", type=int, dest="optim.weight_decay")
self.add_argument(
"-gc", "--gradient_checkpoint", type=bool, dest="network.backbone.sequence.gradient_checkpoint"
)
self.max_grad_norm = 1.0

def post(self):
# pylint: disable=W0201
if not isinstance(self.dataset.label_cols, list) and self.dataset.label_cols:
self.dataset.label_cols = [self.dataset.label_cols]
if not isinstance(self.dataset.feature_cols, list) and self.dataset.feature_cols:
self.dataset.feature_cols = [self.dataset.feature_cols]
data = os.path.basename(os.path.normpath(self.data.root))
sequence = f"{self.sequence.name}_{self.sequence.pretrained}".replace("/", "_")
neck = self.network.neck.name if self.network.neck else "null"
self.run_name = f"{data}-{sequence}-{neck}"
3 changes: 3 additions & 0 deletions multimolecule/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dataset import DATASET_TYPE, BaseDataset

__all__ = ["BaseDataset", "DATASET_TYPE"]
Loading

0 comments on commit 497aa54

Please sign in to comment.