generated from HephaestusProject/template
Commit
1. Drafted a rough skeleton 2. The train code does not run yet Refs: #2
Showing 21 changed files with 376 additions and 45 deletions.
`.gitignore`:

```
@@ -1,3 +1,8 @@
.vscode
output/
wandb/
data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
```
New dataset config:

```yaml
@@ -0,0 +1,5 @@
type: CIFAR10
params:
  path:
    train: data/cifar10/
    test: data/cifar10/
```
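Since the configs are plain YAML, they can be read with OmegaConf (pinned in `requirements.txt` below). A minimal sketch of loading this dataset config; the file path is an assumption, since the diff does not show file names:

```python
from omegaconf import OmegaConf
from torchvision import datasets

# Hypothetical path; the actual location of this config is not shown in the diff.
config = OmegaConf.load("configs/dataset.yml")

# config.type -> "CIFAR10"; config.params.path.train -> "data/cifar10/"
dataset_cls = getattr(datasets, config.type)
train_set = dataset_cls(root=config.params.path.train, train=True, download=True)
```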
This file was deleted.
New model config:

```yaml
@@ -0,0 +1,3 @@
type: DeepHash
params:
  hash_bits: 48
```
This file was deleted.
New runner config:

```yaml
@@ -0,0 +1,29 @@
type: Runner

dataloader:
  type: DataLoader
  params:
    num_workers: 48
    batch_size: 256

optimizer:
  type: SGD
  params:
    learning_rate: 1e-2
    momentum: .9

scheduler:
  type: MultiStepLR
  params:
    gamma: .1

trainer:
  type: Trainer
  params:
    max_epochs: 128
    gpus: -1
    distributed_backend: "ddp"

experiments:
  name: martin_deephash
  output_dir: output/runs
```
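For orientation, here is what the optimizer and scheduler sections describe in plain torch terms. This is a sketch with a stand-in model; the milestone epochs are invented, since the config does not list any:

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

model = nn.Linear(8, 2)  # stand-in for the DeepHash model

# optimizer: SGD with learning_rate 1e-2 and momentum .9
opt = SGD(model.parameters(), lr=1e-2, momentum=0.9)

# scheduler: MultiStepLR multiplies the lr by gamma (.1) at each milestone
# epoch; the milestones below are hypothetical, the config defines none.
sched = MultiStepLR(opt, milestones=[64, 96], gamma=0.1)

for epoch in range(128):  # trainer: max_epochs 128
    # ... one training epoch would run here ...
    sched.step()
```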
New `main.py`:

```python
@@ -0,0 +1,27 @@
"""DeepHash: Deep Learning of Binary Hash Codes for Fast Image Retrieval
Usage:
    main.py <command> [<args>...]
    main.py (-h | --help)
Available commands:
    train
    predict
    test
Options:
    -h --help     Show this.
See 'python main.py <command> --help' for more information on a specific command.
"""
from pathlib import Path

from type_docopt import docopt

if __name__ == "__main__":
    args = docopt(__doc__, options_first=True)
    argv = [args["<command>"]] + args["<args>"]

    if args["<command>"] == "train":
        from train import __doc__, train

        train(docopt(__doc__, argv=argv, types={"path": Path}))

    else:
        raise NotImplementedError(f"Command does not exist: {args['<command>']}")
```
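The git-style dispatch relies on `options_first=True`: docopt stops parsing at the first positional argument and forwards everything after it to the sub-command. A small sketch of the parse (the `--config` flag is an invented example, and the `train.py` this imports is not part of this diff):

```python
from type_docopt import docopt

usage = """Usage:
    main.py <command> [<args>...]
"""

# Simulates: python main.py train --config runner.yml
args = docopt(usage, argv=["train", "--config", "runner.yml"], options_first=True)
print(args["<command>"])  # "train"
print(args["<args>"])     # ["--config", "runner.yml"]
```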
`requirements.txt` (the extraction duplicated the unchanged `pytorch-lightning` line; shown once here):

```
@@ -1,3 +1,8 @@
torch==1.5.0
torchvision==0.6.0
pytorch-lightning==0.8.5
omegaconf==2.0.1rc11
type-docopt==0.1.0
isort==5.4.0
black==19.10b0
wandb==0.9.4
```
This file was deleted.
New dataset registry:

```python
@@ -0,0 +1,3 @@
from torchvision import datasets

DataSetRegistry = {"CIFAR10": datasets.CIFAR10}
```
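Together with the dataset config above, the registry resolves the `type` string to a dataset class. A minimal sketch of the lookup (the surrounding loader code is not part of this commit):

```python
from torchvision import datasets, transforms

DataSetRegistry = {"CIFAR10": datasets.CIFAR10}

# "CIFAR10" and "data/cifar10/" come from the dataset config above.
dataset_cls = DataSetRegistry["CIFAR10"]
train_set = dataset_cls(
    root="data/cifar10/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
```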
This file was deleted.
Empty file.
Model definition (DeepHash):

```python
@@ -1,4 +1,31 @@
"""
Created by Nick on 19/07/20.
Implements the network code, using operations from ops.py.
"""
import torch
import torch.nn as nn
from torchvision import models

alexnet_model = models.alexnet(pretrained=True)


class DeepHash(nn.Module):
    def __init__(self, hash_bits: int):
        """
        Args:
            hash_bits (int): length of the encoded binary code
        """
        super(DeepHash, self).__init__()
        self.hash_bits = hash_bits
        # Pretrained AlexNet backbone: the conv features plus every
        # classifier layer except the final 1000-way fc.
        self.features = nn.Sequential(*list(alexnet_model.features.children()))
        self.remain = nn.Sequential(*list(alexnet_model.classifier.children())[:-1])
        # Latent hash layer (sigmoid keeps activations in (0, 1)) and a
        # 10-way classification head on top of it.
        self.Linear1 = nn.Linear(4096, self.hash_bits)
        self.sigmoid = nn.Sigmoid()
        self.Linear2 = nn.Linear(self.hash_bits, 10)

    def forward(self, x: torch.Tensor):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)  # flatten AlexNet conv output
        x = self.remain(x)
        x = self.Linear1(x)
        features = self.sigmoid(x)
        result = self.Linear2(features)
        return features, result
```
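In the paper named in `main.py`'s docstring, the sigmoid activations of this latent layer are binarized by thresholding at 0.5 to obtain hash codes at retrieval time. That step is not part of this commit; a hedged sketch:

```python
import torch

model = DeepHash(hash_bits=48)  # 48 matches the model config above
model.eval()

with torch.no_grad():
    images = torch.randn(4, 3, 224, 224)  # AlexNet expects 224x224 inputs
    features, logits = model(images)

# Threshold the (0, 1) activations to get one 48-bit binary code per image.
hash_codes = (features > 0.5).to(torch.uint8)  # shape: (4, 48)
```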
This file was deleted.
Empty file.
New loss definition:

```python
@@ -0,0 +1,3 @@
import torch.nn as nn

cross_entropy = nn.CrossEntropyLoss()
```
New Runner (LightningModule):

```python
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
from omegaconf import DictConfig
from pytorch_lightning.core import LightningModule
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

# Assumed import path: `cross_entropy` is defined in the loss module above but
# was never imported here, one reason the train code does not run yet.
from loss import cross_entropy


class Runner(LightningModule):
    def __init__(self, model: nn.Module, config: DictConfig):
        super().__init__()
        self.model = model
        self.hparams.update({"dataset": f"{config.dataset.type}"})
        self.hparams.update({"model": f"{config.model.type}"})
        self.hparams.update(config.model.params)
        self.hparams.update(config.runner.dataloader.params)
        # The original read config.runner.optimizer.params.type, but the
        # runner config defines `type` at the optimizer level.
        self.hparams.update({"optimizer": f"{config.runner.optimizer.type}"})
        self.hparams.update(config.runner.optimizer.params)
        self.hparams.update({"scheduler": f"{config.runner.scheduler.type}"})
        self.hparams.update({"scheduler_gamma": f"{config.runner.scheduler.params.gamma}"})
        self.hparams.update(config.runner.trainer.params)
        print(self.hparams)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # momentum comes from the config but was dropped in the original.
        opt = SGD(
            params=self.model.parameters(),
            lr=self.hparams.learning_rate,
            momentum=self.hparams.momentum,
        )
        scheduler = MultiStepLR(opt, milestones=[self.hparams.max_epochs], gamma=self.hparams.scheduler_gamma)
        return [opt], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        # DeepHash.forward returns (hash_features, logits); the original passed
        # the whole tuple to the loss, so unpack and keep only the logits.
        _, y_hat = self(x)
        loss = cross_entropy(y_hat, y)

        prediction = torch.argmax(y_hat, dim=1)
        acc = (y == prediction).float().mean()

        return {"loss": loss, "train_acc": acc}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["train_acc"] for x in outputs]).mean()
        tqdm_dict = {"train_acc": avg_acc, "train_loss": avg_loss}
        return {**tqdm_dict, "progress_bar": tqdm_dict}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        _, y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        number_of_correct_pred = torch.sum(y == prediction).item()
        return {"val_loss": loss, "n_correct_pred": number_of_correct_pred, "n_pred": len(x)}

    def validation_epoch_end(self, outputs):
        total_count = sum([x["n_pred"] for x in outputs])
        total_n_correct_pred = sum([x["n_correct_pred"] for x in outputs])
        total_loss = torch.stack([x["val_loss"] * x["n_pred"] for x in outputs]).sum()
        val_loss = total_loss / total_count
        val_acc = total_n_correct_pred / total_count
        tqdm_dict = {"val_acc": val_acc, "val_loss": val_loss}
        return {**tqdm_dict, "progress_bar": tqdm_dict}

    def test_step(self, batch, batch_idx):
        x, y = batch
        _, y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        number_of_correct_pred = torch.sum(y == prediction).item()
        return {"loss": loss, "n_correct_pred": number_of_correct_pred, "n_pred": len(x)}

    def test_epoch_end(self, outputs):
        total_count = sum([x["n_pred"] for x in outputs])
        total_n_correct_pred = sum([x["n_correct_pred"] for x in outputs])
        total_loss = torch.stack([x["loss"] * x["n_pred"] for x in outputs]).sum().item()
        test_loss = total_loss / total_count
        test_acc = total_n_correct_pred / total_count
        return {"loss": test_loss, "acc": test_acc}
```
Utility helpers:

```python
@@ -1,4 +1,20 @@
"""
Created by Nick on 19/07/20.
Utility functions.
"""
from pathlib import Path


def get_next_version(root_dir: Path):
    """Return the next zero-padded run-version tag (e.g. "v003") under root_dir."""
    version_prefix = "v"
    if not root_dir.exists():
        next_version = 0
    else:
        existing_versions = []
        for child_path in root_dir.iterdir():
            if child_path.is_dir() and child_path.name.startswith(version_prefix):
                existing_versions.append(int(child_path.name[len(version_prefix):]))

        if len(existing_versions) == 0:
            last_version = -1
        else:
            last_version = max(existing_versions)

        next_version = last_version + 1
    return f"{version_prefix}{next_version:0>3}"
```