From 14c98e0b3f0a5e3141b8b4da724a0bf5f1cde978 Mon Sep 17 00:00:00 2001
From: Martin Hwang
Date: Thu, 13 Aug 2020 13:16:25 +0000
Subject: [PATCH] feat: refactor train and model code

1. Wrote a rough skeleton
2. The train code does not work yet

Refs: #2
---
 .gitignore                      |   5 ++
 conf/dataset/dataset.yml        |   5 ++
 conf/dataset/your_dataset.yml   |   2 -
 conf/model/model.yml            |   3 +
 conf/model/your_model.yml       |   2 -
 conf/runner/runner.yml          |  29 ++++++
 main.py                         |  27 ++++++
 requirements.txt                |   7 +-
 src/data.py                     |   4 -
 src/dataset/dataset_registry.py |   3 +
 src/metric.py                   |   4 -
 src/model/__init__.py           |   0
 src/model/net.py                |  35 +++++++-
 src/model/ops.py                |  16 ----
 src/runner/__init__.py          |   0
 src/runner/metric.py            |   3 +
 src/runner/runner.py            |  79 +++++++++++++++++
 src/utils.py                    |  24 ++++-
 train.py                        | 150 ++++++++++++++++++++++++++++++--
 utils.py                        |  20 +++++
 utils/collect_env.py            |   3 +-
 21 files changed, 376 insertions(+), 45 deletions(-)
 create mode 100644 conf/dataset/dataset.yml
 delete mode 100644 conf/dataset/your_dataset.yml
 create mode 100644 conf/model/model.yml
 delete mode 100644 conf/model/your_model.yml
 create mode 100644 conf/runner/runner.yml
 create mode 100644 main.py
 delete mode 100644 src/data.py
 create mode 100644 src/dataset/dataset_registry.py
 delete mode 100644 src/metric.py
 create mode 100644 src/model/__init__.py
 delete mode 100644 src/model/ops.py
 create mode 100644 src/runner/__init__.py
 create mode 100644 src/runner/metric.py
 create mode 100644 src/runner/runner.py
 create mode 100644 utils.py

diff --git a/.gitignore b/.gitignore
index 26962cb..564ade4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,8 @@
+.vscode
+output/
+wandb/
+data/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/conf/dataset/dataset.yml b/conf/dataset/dataset.yml
new file mode 100644
index 0000000..f605c58
--- /dev/null
+++ b/conf/dataset/dataset.yml
@@ -0,0 +1,5 @@
+type: CIFAR10
+params:
+  path:
+    train: data/cifar10/
+    test: data/cifar10/
\ No newline at end of file
diff --git a/conf/dataset/your_dataset.yml b/conf/dataset/your_dataset.yml
deleted file mode 100644
index 72d77e0..0000000
--- a/conf/dataset/your_dataset.yml
+++ /dev/null
@@ -1,2 +0,0 @@
-# This your_mode.yml was made by Nick at 19/07/20.
-# writing configuration of your dataset.
\ No newline at end of file
diff --git a/conf/model/model.yml b/conf/model/model.yml
new file mode 100644
index 0000000..2faf126
--- /dev/null
+++ b/conf/model/model.yml
@@ -0,0 +1,3 @@
+type: DeepHash
+params:
+  hash_bits: 48
\ No newline at end of file
diff --git a/conf/model/your_model.yml b/conf/model/your_model.yml
deleted file mode 100644
index 3e61d59..0000000
--- a/conf/model/your_model.yml
+++ /dev/null
@@ -1,2 +0,0 @@
-# This your_mode.yml was made by Nick at 19/07/20.
-# writing configuration of your model
\ No newline at end of file
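For orientation, these conf/ fragments are loaded with OmegaConf in train.py's get_config (see below); a minimal sketch of the access paths they expose, not part of the patch itself:

# Sketch, not in this commit: access paths exposed by the YAML files above.
from omegaconf import OmegaConf

model_config = OmegaConf.load("conf/model/model.yml")
dataset_config = OmegaConf.load("conf/dataset/dataset.yml")
print(model_config.type)                 # DeepHash
print(model_config.params.hash_bits)     # 48
print(dataset_config.params.path.train)  # data/cifar10/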
diff --git a/conf/runner/runner.yml b/conf/runner/runner.yml
new file mode 100644
index 0000000..64ba0ed
--- /dev/null
+++ b/conf/runner/runner.yml
@@ -0,0 +1,29 @@
+type: Runner
+
+dataloader:
+  type: DataLoader
+  params:
+    num_workers: 48
+    batch_size: 256
+
+optimizer:
+  type: SGD
+  params:
+    learning_rate: 0.01
+    momentum: .9
+
+scheduler:
+  type: MultiStepLR
+  params:
+    gamma: .1
+
+trainer:
+  type: Trainer
+  params:
+    max_epochs: 128
+    gpus: -1
+    distributed_backend: "ddp"
+
+experiments:
+  name: martin_deephash
+  output_dir: output/runs
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..b0ea5f4
--- /dev/null
+++ b/main.py
@@ -0,0 +1,27 @@
+"""DeepHash: Deep Learning of Binary Hash Codes for Fast Image Retrieval
+Usage:
+    main.py <command> [<args>...]
+    main.py (-h | --help)
+Available commands:
+    train
+    predict
+    test
+Options:
+    -h --help    Show this.
+See 'python main.py <command> --help' for more information on a specific command.
+"""
+from pathlib import Path
+
+from type_docopt import docopt
+
+
+if __name__ == "__main__":
+    args = docopt(__doc__, options_first=True)
+    argv = [args["<command>"]] + args["<args>"]
+
+    if args["<command>"] == "train":
+        from train import __doc__, train
+
+        train(docopt(__doc__, argv=argv, types={"path": Path}))
+
+    else:
+        raise NotImplementedError(f"Command does not exist: {args['<command>']}")
diff --git a/requirements.txt b/requirements.txt
index ad4c783..24df858 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,8 @@
 torch==1.5.0
 torchvision==0.6.0
-pytorch-lightning==0.8.5
\ No newline at end of file
+pytorch-lightning==0.8.5
+omegaconf==2.0.1rc11
+type-docopt==0.1.0
+isort==5.4.0
+black==19.10b0
+wandb==0.9.4
\ No newline at end of file
diff --git a/src/data.py b/src/data.py
deleted file mode 100644
index 593174f..0000000
--- a/src/data.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-    This script was made by Nick at 19/07/20.
-    To implement code for data pipeline. (e.g. custom class subclassing torch.utils.data.Dataset)
-"""
diff --git a/src/dataset/dataset_registry.py b/src/dataset/dataset_registry.py
new file mode 100644
index 0000000..b65660f
--- /dev/null
+++ b/src/dataset/dataset_registry.py
@@ -0,0 +1,3 @@
+from torchvision import datasets
+
+DataSetRegistry = {"CIFAR10": datasets.CIFAR10}
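DataSetRegistry is a plain dict keyed by the `type` string from conf/dataset/dataset.yml; train.py (below) resolves it with `DataSetRegistry.get(config.dataset.type)`. A hedged sketch of extending it, with CIFAR100 as an illustrative addition that is not part of this commit:

# Hypothetical extension, not in this commit: registering another
# torchvision dataset so `type: CIFAR100` in dataset.yml would resolve.
from torchvision import datasets

DataSetRegistry = {
    "CIFAR10": datasets.CIFAR10,
    "CIFAR100": datasets.CIFAR100,  # assumed addition for illustration
}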
-""" +import os + +import torch +import torch.nn as nn +from torchvision import models + +alexnet_model = models.alexnet(pretrained=True) + + +class DeepHash(nn.Module): + def __init__(self, hash_bits: int): + """ + Args: + bits (int): lenght of encoded binary bits + """ + super(DeepHash, self).__init__() + self.hash_bits = hash_bits + self.features = nn.Sequential(*list(alexnet_model.features.children())) + self.remain = nn.Sequential(*list(alexnet_model.classifier.children())[:-1]) + self.Linear1 = nn.Linear(4096, self.hash_bits) + self.sigmoid = nn.Sigmoid() + self.Linear2 = nn.Linear(self.hash_bits, 10) + + def forward(self, x: torch.Tensor): + x = self.features(x) + x = x.view(x.size(0), 256 * 6 * 6) + x = self.remain(x) + x = self.Linear1(x) + features = self.sigmoid(x) + result = self.Linear2(features) + return features, result diff --git a/src/model/ops.py b/src/model/ops.py deleted file mode 100644 index 0c82149..0000000 --- a/src/model/ops.py +++ /dev/null @@ -1,16 +0,0 @@ -""" - This script was made by Nick at 19/07/20. - To implement code of your operation to be being used your network. -""" - - -def multiply(a, b): - return a * b - - -def add(a, b): - return a + b - - -def subtract(a, b): - return a - b diff --git a/src/runner/__init__.py b/src/runner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/runner/metric.py b/src/runner/metric.py new file mode 100644 index 0000000..af4f95a --- /dev/null +++ b/src/runner/metric.py @@ -0,0 +1,3 @@ +import torch.nn as nn + +cross_entropy = nn.CrossEntropyLoss() diff --git a/src/runner/runner.py b/src/runner/runner.py new file mode 100644 index 0000000..f785626 --- /dev/null +++ b/src/runner/runner.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +from omegaconf import DictConfig +from pytorch_lightning.core import LightningModule +from torch.optim import SGD +from torch.optim.lr_scheduler import MultiStepLR + + +class Runner(LightningModule): + def __init__(self, model: nn.Module, config: DictConfig): + super().__init__() + self.model = model + self.hparams.update({"dataset": f"{config.dataset.type}"}) + self.hparams.update({"model": f"{config.model.type}"}) + self.hparams.update(config.model.params) + self.hparams.update(config.runner.dataloader.params) + self.hparams.update({"optimizer": f"{config.runner.optimizer.params.type}"}) + self.hparams.update(config.runner.optimizer.params) + self.hparams.update({"scheduler": f"{config.runner.scheduler.type}"}) + self.hparams.update({"scheduler_gamma": f"{config.runner.scheduler.params.gamma}"}) + self.hparams.update(config.runner.trainer.params) + print(self.hparams) + + def forward(self, x): + return self.model(x) + + def configure_optimizers(self): + opt = SGD(params=self.model.parameters(), lr=self.hparams.learning_rate) + scheduler = MultiStepLR(opt, milestones=[self.hparams.max_epochs], gamma=self.hparams.scheduler_gamma) + return [opt], [scheduler] + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = cross_entropy(y_hat, y) + + prediction = torch.argmax(y_hat, dim=1) + acc = (y == prediction).float().mean() + + return {"loss": loss, "train_acc": acc} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x["loss"] for x in outputs]).mean() + avg_acc = torch.stack([x["train_acc"] for x in outputs]).mean() + tqdm_dict = {"train_acc": avg_acc, "train_loss": avg_loss} + return {**tqdm_dict, "progress_bar": tqdm_dict} + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = 
diff --git a/src/utils.py b/src/utils.py
index 9ebfbcb..647100a 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,4 +1,20 @@
-"""
-    This script was made by Nick at 19/07/20.
-    To implement code for utility.
-"""
+from pathlib import Path
+
+
+def get_next_version(root_dir: Path):
+    version_prefix = "v"
+    if not root_dir.exists():
+        next_version = 0
+    else:
+        existing_versions = []
+        for child_path in root_dir.iterdir():
+            if child_path.is_dir() and child_path.name.startswith(version_prefix):
+                existing_versions.append(int(child_path.name[len(version_prefix) :]))
+
+        if len(existing_versions) == 0:
+            last_version = -1
+        else:
+            last_version = max(existing_versions)
+
+        next_version = last_version + 1
+    return f"{version_prefix}{next_version:0>3}"
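get_next_version scans the experiment output directory for v-prefixed run folders and returns the next zero-padded tag; a short usage example under a hypothetical directory state:

# Hypothetical usage, not in this commit:
from pathlib import Path

from src.utils import get_next_version

# If output/runs/martin_deephash contains v000 and v001 ...
print(get_next_version(Path("output/runs/martin_deephash")))  # -> "v002"
# If the directory does not exist yet, the first tag is returned:
print(get_next_version(Path("output/does_not_exist")))        # -> "v000"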
""" -import pytorch_lightning -import torch -pytorch_lightning.seed_everything(777) -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False + +import pickle + +from argparse import ArgumentParser, Namespace +from pathlib import Path +from typing import Dict, List, Tuple, Union + +from omegaconf import DictConfig, OmegaConf + +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, LearningRateLogger, EarlyStopping +from pytorch_lightning.loggers import WandbLogger + +from torch.utils.data import DataLoader +import torchvision.transforms as transforms + +from src.dataset.dataset_registry import DataSetRegistry +from src.model.net import DeepHash +from src.runner.runner import Runner +from src.utils import get_next_version + +# from src.utils.preprocessing import PreProcessor + + +def get_config(hparams: Dict) -> DictConfig: + config = OmegaConf.create() + + config_dir = Path(".") + model_config = OmegaConf.load(config_dir / hparams.get("--model-config")) + dataset_config = OmegaConf.load(config_dir / hparams.get("--dataset-config")) + runner_config = OmegaConf.load(config_dir / hparams.get("--runner-config")) + + config.update(model=model_config, dataset=dataset_config, runner=runner_config) + OmegaConf.set_readonly(config, True) + + return config + + +def get_logger_and_callbacks( + config: DictConfig, +) -> Tuple[WandbLogger, Union[Callback, List[Callback]]]: + + root_dir = Path(config.runner.experiments.output_dir) / Path(config.runner.experiments.name) + next_version = get_next_version(root_dir) + checkpoint_dir = root_dir.joinpath(next_version) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + wandb_logger = WandbLogger( + name=str(config.runner.experiments.name), + save_dir=str(checkpoint_dir), + offline=True, + version=next_version, + ) + + checkpoint_prefix = f"{config.model.type}" + checkpoint_suffix = "_{epoch:02d}-{tr_loss:.2f}-{val_loss:.2f}-{tr_acc:.2f}-{val_acc:.2f}" + checkpoint_path = checkpoint_dir.joinpath(checkpoint_prefix + checkpoint_suffix) + checkpoint_callback = ModelCheckpoint( + filepath=checkpoint_path, save_top_k=2, save_weights_only=True + ) + + return wandb_logger, checkpoint_callback + + +def get_data_loaders(config: DictConfig) -> Tuple[DataLoader, DataLoader]: + + dataset = DataSetRegistry.get(config.dataset.type) + train_dataset = dataset( + root=config.dataset.params.path.train, + train=True, + transform=transforms.ToTensor(), + download=True, + ) + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=config.runner.dataloader.params.batch_size, + num_workers=config.runner.dataloader.params.num_workers, + drop_last=True, + shuffle=True, + ) + test_dataset = dataset( + root=config.dataset.params.path.test, + train=False, + transform=transforms.ToTensor(), + download=True, + ) + test_dataloader = DataLoader( + dataset=test_dataset, + batch_size=config.runner.dataloader.params.batch_size, + num_workers=config.runner.dataloader.params.num_workers, + drop_last=False, + shuffle=False, + ) + return train_dataloader, test_dataloader + + +def train(hparams: dict): + config = get_config(hparams=hparams) + + wandb_logger, checkpoint_callback = get_logger_and_callbacks(config=config) + lr_logger = LearningRateLogger() + + early_stop_callback = EarlyStopping( + monitor="val_acc", min_delta=0.00, patience=10, verbose=True, mode="max" + ) + + train_dataloader, test_dataloader = get_data_loaders(config=config) + + model = 
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..647100a
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,20 @@
+from pathlib import Path
+
+
+def get_next_version(root_dir: Path):
+    version_prefix = "v"
+    if not root_dir.exists():
+        next_version = 0
+    else:
+        existing_versions = []
+        for child_path in root_dir.iterdir():
+            if child_path.is_dir() and child_path.name.startswith(version_prefix):
+                existing_versions.append(int(child_path.name[len(version_prefix) :]))
+
+        if len(existing_versions) == 0:
+            last_version = -1
+        else:
+            last_version = max(existing_versions)
+
+        next_version = last_version + 1
+    return f"{version_prefix}{next_version:0>3}"
diff --git a/utils/collect_env.py b/utils/collect_env.py
index 78f1cc5..96d18fe 100644
--- a/utils/collect_env.py
+++ b/utils/collect_env.py
@@ -1,6 +1,7 @@
 # This script outputs relevant system environment info
 # Run it with `python collect_env.py`.
-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
 
 import locale
 import os
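The patch leaves retrieval itself for later. For orientation, a short sketch, not part of this commit, of how the sigmoid `features` head is typically binarized into hash codes, following the 0.5 thresholding scheme of the paper named in main.py's docstring; training would be launched with `python main.py train` once the remaining issues are fixed:

# Illustration, not in this commit: binarize the latent hash layer at 0.5.
import torch

from src.model.net import DeepHash

model = DeepHash(hash_bits=48).eval()
with torch.no_grad():
    features, _ = model(torch.randn(4, 3, 224, 224))
binary_codes = (features > 0.5).to(torch.uint8)  # shape (4, 48), values in {0, 1}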