feat: refactor the train and model code
    1. Drafted a rough skeleton
    2. The train code does not run yet

 Refs: #2
ssaru committed Aug 13, 2020
1 parent d0b2c24 commit 14c98e0
Showing 21 changed files with 376 additions and 45 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,3 +1,8 @@
.vscode
output/
wandb/
data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
5 changes: 5 additions & 0 deletions conf/dataset/dataset.yml
@@ -0,0 +1,5 @@
type: CIFAR10
params:
  path:
    train: data/cifar10/
    test: data/cifar10/
2 changes: 0 additions & 2 deletions conf/dataset/your_dataset.yml

This file was deleted.

3 changes: 3 additions & 0 deletions conf/model/model.yml
@@ -0,0 +1,3 @@
type: DeepHash
params:
  hash_bits: 48
2 changes: 0 additions & 2 deletions conf/model/your_model.yml

This file was deleted.

29 changes: 29 additions & 0 deletions conf/runner/runner.yml
@@ -0,0 +1,29 @@
type: Runner

dataloader:
  type: DataLoader
  params:
    num_workers: 48
    batch_size: 256

optimizer:
  type: SGD
  params:
    learning_rate: 1e-2
    momentum: .9

scheduler:
  type: MultiStepLR
  params:
    gamma: .1

trainer:
  type: Trainer
  params:
    max_epochs: 128
    gpus: -1
    distributed_backend: "ddp"

experiments:
  name: martin_deephash
  output_dir: output/runs
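
How these three config groups get composed into the single `config` object that `Runner` expects is not shown in this commit; a minimal sketch, assuming a plain OmegaConf merge keyed by group name:

from omegaconf import OmegaConf

# Hypothetical composition: one sub-config per conf/ subdirectory
config = OmegaConf.create(
    {
        "dataset": OmegaConf.load("conf/dataset/dataset.yml"),
        "model": OmegaConf.load("conf/model/model.yml"),
        "runner": OmegaConf.load("conf/runner/runner.yml"),
    }
)
print(config.runner.optimizer.type)          # SGD
print(config.runner.scheduler.params.gamma)  # 0.1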
27 changes: 27 additions & 0 deletions main.py
@@ -0,0 +1,27 @@
"""DeepHash: Deep Learning of Binary Hash Codes for Fast Image Retrieval
Usage:
main.py <command> [<args>...]
main.py (-h | --help)
Available commands:
train
predict
test
Options:
-h --help Show this.
See 'python main.py <command> --help' for more information on a specific command.
"""
from pathlib import Path

from type_docopt import docopt

if __name__ == "__main__":
args = docopt(__doc__, options_first=True)
argv = [args["<command>"]] + args["<args>"]

if args["<command>"] == "train":
from train import __doc__, train

train(docopt(__doc__, argv=argv, types={"path": Path}))

else:
raise NotImplementedError(f"Command does not exist: {args['<command>']}")
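
The `train` module imported here is not among the files shown in this (truncated) listing. A hypothetical sketch of the docstring it would need for the nested docopt call, using type_docopt's `[type: ...]` option annotations; the flag names are made up:

"""Train DeepHash.

Usage:
    train --dataset-config=<path> --model-config=<path> --runner-config=<path>

Options:
    --dataset-config=<path>  Path to dataset yaml [type: path]
    --model-config=<path>    Path to model yaml [type: path]
    --runner-config=<path>   Path to runner yaml [type: path]
"""


def train(args):
    ...  # build the config from the yaml paths, then fit (see src/runner/runner.py below)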
7 changes: 6 additions & 1 deletion requirements.txt
@@ -1,3 +1,8 @@
torch==1.5.0
torchvision==0.6.0
pytorch-lightning==0.8.5
omegaconf==2.0.1rc11
type-docopt==0.1.0
isort==5.4.0
black==19.10b0
wandb==0.9.4
4 changes: 0 additions & 4 deletions src/data.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/dataset/dataset_registry.py
@@ -0,0 +1,3 @@
from torchvision import datasets

DataSetRegistry = {"CIFAR10": datasets.CIFAR10}
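
A hedged sketch of how the registry and conf/dataset/dataset.yml might be wired together (the glue code is not part of this diff; the transform is an assumption, sized for the AlexNet backbone in src/model/net.py):

from omegaconf import OmegaConf
from torchvision import transforms

from src.dataset.dataset_registry import DataSetRegistry

conf = OmegaConf.load("conf/dataset/dataset.yml")
dataset_cls = DataSetRegistry[conf.type]  # "CIFAR10" -> torchvision.datasets.CIFAR10
train_dataset = dataset_cls(
    root=conf.params.path.train,
    train=True,
    download=True,
    # AlexNet expects 224x224 inputs, so upscale CIFAR10's 32x32 images
    transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
)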
4 changes: 0 additions & 4 deletions src/metric.py

This file was deleted.

Empty file added src/model/__init__.py
Empty file.
35 changes: 31 additions & 4 deletions src/model/net.py
@@ -1,4 +1,31 @@
"""
This script was made by Nick at 19/07/20.
To implement code of your network using operation from ops.py.
"""
import os

import torch
import torch.nn as nn
from torchvision import models

alexnet_model = models.alexnet(pretrained=True)


class DeepHash(nn.Module):
def __init__(self, hash_bits: int):
"""
Args:
bits (int): lenght of encoded binary bits
"""
super(DeepHash, self).__init__()
self.hash_bits = hash_bits
self.features = nn.Sequential(*list(alexnet_model.features.children()))
self.remain = nn.Sequential(*list(alexnet_model.classifier.children())[:-1])
self.Linear1 = nn.Linear(4096, self.hash_bits)
self.sigmoid = nn.Sigmoid()
self.Linear2 = nn.Linear(self.hash_bits, 10)

def forward(self, x: torch.Tensor):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.remain(x)
x = self.Linear1(x)
features = self.sigmoid(x)
result = self.Linear2(features)
return features, result
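
A quick shape check for the two outputs (a sketch; 224x224 is the input size the AlexNet feature extractor assumes):

import torch

model = DeepHash(hash_bits=48)
images = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
features, logits = model(images)
print(features.shape)  # torch.Size([2, 48]): sigmoid hash activations
print(logits.shape)    # torch.Size([2, 10]): class logits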
16 changes: 0 additions & 16 deletions src/model/ops.py

This file was deleted.

Empty file added src/runner/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions src/runner/metric.py
@@ -0,0 +1,3 @@
import torch.nn as nn

cross_entropy = nn.CrossEntropyLoss()
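
For reference, this is the shared criterion that src/runner/runner.py calls in each step; for example:

import torch

from src.runner.metric import cross_entropy

logits = torch.randn(4, 10)            # batch of 4, 10 classes
targets = torch.tensor([0, 3, 7, 1])   # ground-truth labels
loss = cross_entropy(logits, targets)  # scalar tensor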
79 changes: 79 additions & 0 deletions src/runner/runner.py
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
from omegaconf import DictConfig
from pytorch_lightning.core import LightningModule
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

from src.runner.metric import cross_entropy


class Runner(LightningModule):
    def __init__(self, model: nn.Module, config: DictConfig):
        super().__init__()
        self.model = model
        self.hparams.update({"dataset": f"{config.dataset.type}"})
        self.hparams.update({"model": f"{config.model.type}"})
        self.hparams.update(config.model.params)
        self.hparams.update(config.runner.dataloader.params)
        self.hparams.update({"optimizer": f"{config.runner.optimizer.type}"})
        self.hparams.update(config.runner.optimizer.params)
        self.hparams.update({"scheduler": f"{config.runner.scheduler.type}"})
        # keep gamma numeric so MultiStepLR can use it below
        self.hparams.update({"scheduler_gamma": config.runner.scheduler.params.gamma})
        self.hparams.update(config.runner.trainer.params)
        print(self.hparams)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        opt = SGD(
            params=self.model.parameters(),
            lr=self.hparams.learning_rate,
            momentum=self.hparams.momentum,  # from runner.yml's optimizer params
        )
        scheduler = MultiStepLR(opt, milestones=[self.hparams.max_epochs], gamma=self.hparams.scheduler_gamma)
        return [opt], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        # DeepHash.forward returns (hash features, class logits); classify on the logits
        _, y_hat = self(x)
        loss = cross_entropy(y_hat, y)

        prediction = torch.argmax(y_hat, dim=1)
        acc = (y == prediction).float().mean()

        return {"loss": loss, "train_acc": acc}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["train_acc"] for x in outputs]).mean()
        tqdm_dict = {"train_acc": avg_acc, "train_loss": avg_loss}
        return {**tqdm_dict, "progress_bar": tqdm_dict}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        _, y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        number_of_correct_pred = torch.sum(y == prediction).item()
        return {"val_loss": loss, "n_correct_pred": number_of_correct_pred, "n_pred": len(x)}

    def validation_epoch_end(self, outputs):
        total_count = sum([x["n_pred"] for x in outputs])
        total_n_correct_pred = sum([x["n_correct_pred"] for x in outputs])
        total_loss = torch.stack([x["val_loss"] * x["n_pred"] for x in outputs]).sum()
        val_loss = total_loss / total_count
        val_acc = total_n_correct_pred / total_count
        tqdm_dict = {"val_acc": val_acc, "val_loss": val_loss}
        return {**tqdm_dict, "progress_bar": tqdm_dict}

    def test_step(self, batch, batch_idx):
        x, y = batch
        _, y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        number_of_correct_pred = torch.sum(y == prediction).item()
        return {"loss": loss, "n_correct_pred": number_of_correct_pred, "n_pred": len(x)}

    def test_epoch_end(self, outputs):
        total_count = sum([x["n_pred"] for x in outputs])
        total_n_correct_pred = sum([x["n_correct_pred"] for x in outputs])
        total_loss = torch.stack([x["loss"] * x["n_pred"] for x in outputs]).sum().item()
        test_loss = total_loss / total_count
        test_acc = total_n_correct_pred / total_count
        return {"loss": test_loss, "acc": test_acc}
24 changes: 20 additions & 4 deletions src/utils.py
@@ -1,4 +1,20 @@
"""
This script was made by Nick at 19/07/20.
To implement code for utility.
"""
from pathlib import Path


def get_next_version(root_dir: Path):
version_prefix = "v"
if not root_dir.exists():
next_version = 0
else:
existing_versions = []
for child_path in root_dir.iterdir():
if child_path.is_dir() and child_path.name.startswith(version_prefix):
existing_versions.append(int(child_path.name[len(version_prefix) :]))

if len(existing_versions) == 0:
last_version = -1
else:
last_version = max(existing_versions)

next_version = last_version + 1
return f"{version_prefix}{next_version:0>3}"
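
For example, assuming output/runs already contains v000 and v001:

from pathlib import Path

from src.utils import get_next_version

print(get_next_version(Path("output/runs")))     # "v002"
print(get_next_version(Path("does/not/exist")))  # "v000" (directory missing)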
