diff --git a/.gitignore b/.gitignore
index b6e4761..0f03a65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,5 @@
 # Byte-compiled / optimized / DLL files
-__pycache__/
+*/__pycache__/
 *.py[cod]
 *$py.class
 
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..3459361
--- /dev/null
+++ b/app.py
@@ -0,0 +1,99 @@
+from typing import Optional
+import dash
+import dash_core_components as dcc
+import dash_html_components as html
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+import json
+import numpy as np
+
+import glob
+import os
+
+
+
+def collect_first_best(pattern: str, neg: Optional[str] = None):
+    files = []
+    for file in sorted(glob.glob(pattern), reverse=True):
+        if neg and neg in file:
+            continue
+        files.append(file)
+
+    figures = []
+    for file in files:
+        title = os.path.split(file)[1].split('.')[0]
+        df1 = pd.read_csv(file)
+        figures.append(px.box(df1, x="map_name", y="counts",
+                              color="algo_type",
+                              notched=True, title=title))
+    return figures
+
+
+def collect_line_plots(pattern: str, neg: Optional[str] = None):
+    files = []
+    for file in sorted(glob.glob(pattern), reverse=True):
+        if neg and neg in file:
+            continue
+        files.append(file)
+
+    lineplots = []
+    for file in files:
+        with open(file) as f:
+            data = json.loads(f.readline().rstrip('\n'))
+        title = os.path.split(file)[1].split('.')[0]
+        fig = go.Figure()
+        fig.update_layout(dict(title=title))
+        for map_name, counts in data.items():
+            uniform_mean = np.array(counts['uniform']['mean'])
+            uniform_std = np.array(counts['uniform']['std'])
+
+            y_u1 = (uniform_mean - uniform_std).tolist()
+            y_u2 = (uniform_mean + uniform_std).tolist()
+
+            roi_mean = np.array(counts['roi']['mean'])
+            roi_std = np.array(counts['roi']['std'])
+            y_r1 = (roi_mean - 0.5 * roi_std).tolist()
+            y_r2 = (roi_mean + roi_std).tolist()
+
+            x1 = list(range(len(uniform_mean)))
+            x2 = list(range(len(roi_mean)))
+
+            fig.add_trace(go.Scatter(x=x1 + x1[::-1], y=y_u1 + y_u2[::-1], fill='toself', name='RRT*-uniform_' + map_name))
+            fig.add_trace(go.Scatter(x=x2 + x2[::-1], y=y_r1 + y_r2[::-1], fill='toself', name='RRT*-roi_' + map_name))
+
+            #fig.add_trace(go.Scatter(x=x1, y=uniform_mean, name='RRT*-uniform_' + map_name))
+            #fig.add_trace(go.Scatter(x=x2, y=roi_mean, name='RRT*-roi_' + map_name))
+        fig.update_traces(mode='lines')
+        lineplots.append(fig)
+    return lineplots
+
+
+figures = collect_first_best('logs/collected_stats_gan*.csv')
+figures.extend(collect_first_best('logs/collected_stats_pix2pix*.csv'))
+figures.extend(collect_first_best('logs/gan*.csv', neg='moving_ai'))
+figures.extend(collect_first_best('logs/pix2pix*.csv', neg='moving_ai'))
+figures.extend(collect_first_best('logs/gan_moving_ai*.csv'))
+figures.extend(collect_first_best('logs/pix2pix_moving_ai*.csv'))
+
+lineplots = collect_line_plots('logs/collected_stats_gan*.plot')
+lineplots.extend(collect_line_plots('logs/collected_stats_pix2pix*.plot'))
+lineplots.extend(collect_line_plots('logs/gan*.plot', neg='moving_ai'))
+lineplots.extend(collect_line_plots('logs/pix2pix*.plot', neg='moving_ai'))
+lineplots.extend(collect_line_plots('logs/gan_moving_ai*.plot'))
+lineplots.extend(collect_line_plots('logs/pix2pix_moving_ai*.plot'))
+
+with open('box_plots.html', 'w') as f:
+    for fig in figures:
+        f.write(fig.to_html(full_html=False, include_plotlyjs='cdn'))
+with open('line_plots.html', 'w') as f:
+    for line in lineplots:
+        f.write(line.to_html(full_html=False, include_plotlyjs='cdn'))
+
+
+app = dash.Dash(__name__)
+app.layout = html.Div([
+    html.Div([dcc.Graph(figure=fig) for fig in figures]),
+    html.Div([dcc.Graph(figure=lineplot) for lineplot in lineplots])
+])
+app.run_server(debug=True)
diff --git a/download.py b/download.py
index eee93a6..874d018 100644
--- a/download.py
+++ b/download.py
@@ -1,3 +1,4 @@
+import os
 import argparse
 from pathgan.data.utils import download_and_extract
 
@@ -5,7 +6,8 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(prog = "top", description="Training GAN (from original paper")
     parser.add_argument("--url", type=str, help="Url of file.")
-    parser.add_argument("--root", type=str, default=".", help="Root path.")
+    parser.add_argument("--root", type=str, default="data", help="Root path.")
     parser.add_argument("--filename", type=str, default=None, help="Filname.")
     args = parser.parse_args()
+    os.makedirs(args.root, exist_ok=True)
     download_and_extract(args.root, args.url, args.filename)
diff --git a/notebooks/analyzes_of_connectivity.ipynb b/notebooks/analyzes_of_connectivity.ipynb
index 98f752b..d91b033 100644
--- a/notebooks/analyzes_of_connectivity.ipynb
+++ b/notebooks/analyzes_of_connectivity.ipynb
@@ -2,14 +2,25 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "import os\n",
+    "sys.path.append(os.path.dirname(os.getcwd()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
     "import cv2\n",
     "import numpy as np\n",
     "import pandas as pd\n",
-    "from PathGAN.data.utils import rgb2binary\n",
+    "from pathgan.utils import rgb2binary\n",
     "import matplotlib.pyplot as plt\n",
     "from tqdm.notebook import tqdm\n",
     "from PIL import Image, ImageDraw\n",
@@ -19,7 +30,7 @@
     "import sys\n",
     "from pathlib import Path\n",
     "from collections import deque\n",
-    "from PathGAN.RRT_last import RRT"
+    "from pathgan.models.rrt import RRT"
    ]
   },
   {
@@ -28,8 +39,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataset_path = './PathGAN/data/dataset/'\n",
-    "data_path = './PathGAN/data/' "
+    "dataset_path = 'data/generated_dataset/dataset/'\n",
+    "data_path = 'data/generated_dataset/'"
    ]
   },
   {
@@ -1140,7 +1151,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -1154,9 +1165,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.9.7"
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
diff --git a/pathgan/data/mpr_dataset.py b/pathgan/data/mpr_dataset.py
index 35b16bc..95681f3 100644
--- a/pathgan/data/mpr_dataset.py
+++ b/pathgan/data/mpr_dataset.py
@@ -29,6 +29,8 @@ class MPRDataset(Dataset):
         Dataframe with map/task/roi pairs.
     transform: Callable
         Transforms for map/task/roi pairs.
+    to_binary: bool (default=False)
+        If to load data (image) as binary tensor.
""" def __init__( self, @@ -37,14 +39,16 @@ def __init__( roi_dir: str, csv_file: pd.DataFrame, transform: Optional[Callable] = None, - test: bool = False, + return_meta: bool = False, + to_binary: bool = False, ): self.map_dir = map_dir self.point_dir = point_dir self.roi_dir = roi_dir self.csv_file = csv_file self.transform = transform - self.test = test + self.return_meta = return_meta + self.to_binary = to_binary def __len__(self) -> int: return len(self.csv_file) @@ -54,19 +58,23 @@ def __getitem__(self, index: int) -> Dict[str, Any]: map_name = row["map"].split(".")[0] map_path = f"{self.map_dir}/{row['map']}" point_path = f"{self.point_dir}/{map_name}/{row['task']}" - if not self.test: - roi_path = f"{self.roi_dir}/{map_name}/{row['roi']}" + roi_path = f"{self.roi_dir}/{map_name}/{row['roi']}" + meta = {"map_path": map_path, "point_path": point_path, "roi_path": roi_path} - map_img = np.array(Image.open(map_path).convert('RGB')) - point_img = np.array(Image.open(point_path).convert('RGB')) - if not self.test: - roi_img = np.array(Image.open(roi_path).convert('RGB')) + if self.to_binary: + map_img = np.array(Image.open(map_path).convert("L")) + point_img = np.array(Image.open(point_path).convert("L")) + roi_img = np.array(Image.open(roi_path).convert("L")) + else: + map_img = np.array(Image.open(map_path).convert("RGB")) + point_img = np.array(Image.open(point_path).convert("RGB")) + roi_img = np.array(Image.open(roi_path).convert("RGB")) if self.transform is not None: map_img = self.transform(map_img) point_img = self.transform(point_img) - if not self.test: - roi_img = self.transform(roi_img) - if not self.test: - return map_img, point_img, roi_img - return map_img, point_img + roi_img = self.transform(roi_img) + + if self.return_meta: + return map_img, point_img, roi_img, meta + return map_img, point_img, roi_img diff --git a/pathgan/metrics/__init__.py b/pathgan/metrics/__init__.py index e69de29..1761d1a 100644 --- a/pathgan/metrics/__init__.py +++ b/pathgan/metrics/__init__.py @@ -0,0 +1 @@ +from .metrics import * diff --git a/pathgan/metrics/metrics.py b/pathgan/metrics/metrics.py index ba9c38e..b1b6674 100644 --- a/pathgan/metrics/metrics.py +++ b/pathgan/metrics/metrics.py @@ -1,3 +1,4 @@ +import numpy as np import torch import torch.nn as nn from torch.nn import functional as F @@ -5,6 +6,28 @@ from .functional import kl_divergence, covariance, frechet_distance +def intersection_over_union(roi_pred, roi_true): + roi_pred = np.sum(roi_pred, axis=-1) + roi_pred = (roi_pred < 0.5 * np.max(roi_pred)).astype(int) + roi_true = np.sum(roi_true, axis=-1) + roi_true = (roi_true < 0.5 * np.max(roi_true)).astype(int) + inter = (roi_pred * roi_true).sum() + union = roi_pred.sum() + roi_true.sum() + iou_value = inter / (union - inter + 1e-6) + return iou_value + + +def jaccard_coefficient(roi_pred, roi_true): + roi_pred = np.sum(roi_pred, axis=-1) + roi_pred = (roi_pred < 0.5 * np.max(roi_pred)).astype(int) + roi_true = np.sum(roi_true, axis=-1) + roi_true = (roi_true < 0.5 * np.max(roi_true)).astype(int) + inter = (roi_pred * roi_true).sum() + union = roi_pred.sum() + roi_true.sum() + dice_value = 2 * inter / (union + 1e-6) + return dice_value + + class KLDivergence(nn.Module): """Kullback–Leibler divergence.""" def __init__(self,): diff --git a/pathgan/models/rrt/rrt_base.py b/pathgan/models/rrt/rrt_base.py index a707813..a67a649 100644 --- a/pathgan/models/rrt/rrt_base.py +++ b/pathgan/models/rrt/rrt_base.py @@ -32,7 +32,7 @@ class PathDescription(object): """ def __init__( self, - 
-        path: List,
+        path: Optional[List] = None,
         cost: float = float('inf'),
         time_sec: float = 0.,
         time_it: int = 0,
@@ -40,7 +40,7 @@ def __init__(
         nodes_taken: int = 0,
     ):
         """Initialize."""
-        self.path = path
+        self.path = path if path is not None else []
         self.cost = cost
         self.time_sec = time_sec
         self.time_it = time_it
diff --git a/scripts/data/download_datasets.sh b/scripts/data/download_datasets.sh
new file mode 100644
index 0000000..1c354b2
--- /dev/null
+++ b/scripts/data/download_datasets.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+echo "Downloading generated dataset..."
+
+python download.py \
+    --url https://github.com/akanametov/pathgan/releases/download/2.0/dataset.zip \
+    --root data/generated_dataset \
+
+echo "Downloading movingai dataset..."
+
+python download.py \
+    --url https://github.com/akanametov/pathgan/releases/download/2.0/movingai_dataset.zip \
+    --root data/movingai_dataset \
diff --git a/scripts/data/download_results.sh b/scripts/data/download_results.sh
new file mode 100644
index 0000000..77ca42e
--- /dev/null
+++ b/scripts/data/download_results.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+
+echo "Downloading SAGAN results on generated dataset..."
+
+python download.py \
+    --url https://github.com/akanametov/pathgan/releases/download/2.0/results.zip \
+    --root data/generated_dataset \
+
+echo "Downloading SAGAN results on movingai dataset..."
+
+python download.py \
+    --url https://github.com/akanametov/pathgan/releases/download/2.0/movingai_results.zip \
+    --root data/movingai_dataset/movingai \
+
+echo "Downloading Pix2pix results on generated dataset..."
+
+python download.py \
+    --url https://github.com/akanametov/pathgan/releases/download/2.0/pixresults.zip \
+    --root data/generated_dataset \
+
+echo "Downloading Pix2pix results on movingai dataset..."
+
+python download.py \
+    --url https://github.com/akanametov/pathgan/releases/download/2.0/movingai_pixresults.zip \
+    --root data/movingai_dataset/movingai \
diff --git a/scripts/logs/get_logs_generated_pix2pix.sh b/scripts/logs/get_logs_generated_pix2pix.sh
new file mode 100644
index 0000000..46c6d5c
--- /dev/null
+++ b/scripts/logs/get_logs_generated_pix2pix.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+echo "Get Pix2pix logs on generated dataset..."
+
+python get_logs.py A \
+    --map_params "{'data_folder': './data/generated_dataset', 'maps_folder': 'maps', 'results_folder': 'pixresults', 'results_file': 'result.csv'}" \
+    --rrt_params "{'path_resolution': 1, 'step_len': 2, 'max_iter': 10000}" \
+    --mu 0.1 \
+    --gamma 10 \
+    --n 50 \
+    --output_dir logs/pix2pix \
+    --output_fname pix2pix_generated_logs.txt \
diff --git a/scripts/logs/get_logs_generated_sagan.sh b/scripts/logs/get_logs_generated_sagan.sh
new file mode 100644
index 0000000..375831e
--- /dev/null
+++ b/scripts/logs/get_logs_generated_sagan.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+echo "Get SAGAN logs on generated dataset..."
+
+python get_logs.py A \
+    --map_params "{'data_folder': './data/generated_dataset', 'maps_folder': 'maps', 'results_folder': 'results', 'results_file': 'results.csv'}" \
+    --rrt_params "{'path_resolution': 1, 'step_len': 2, 'max_iter': 10000}" \
+    --mu 0.1 \
+    --gamma 10 \
+    --n 50 \
+    --output_dir logs/ \
+    --output_fname sagan_generated_logs.txt \
diff --git a/scripts/logs/get_logs_movingai_pix2pix.sh b/scripts/logs/get_logs_movingai_pix2pix.sh
new file mode 100644
index 0000000..a1ebf1a
--- /dev/null
+++ b/scripts/logs/get_logs_movingai_pix2pix.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+echo "Get Pix2pix logs on movingai dataset..."
+
+python get_logs.py A \
+    --map_params "{'data_folder': './data/movingai_dataset', 'maps_folder': 'maps', 'results_folder': 'movingai/pixresults', 'results_file': 'pixresult.csv'}" \
+    --rrt_params "{'path_resolution': 1, 'step_len': 2, 'max_iter': 10000}" \
+    --mu 0.1 \
+    --gamma 10 \
+    --n 50 \
+    --output_dir logs/pix2pix \
+    --output_fname pix2pix_movingai_logs.txt \
diff --git a/scripts/logs/get_logs_movingai_sagan.sh b/scripts/logs/get_logs_movingai_sagan.sh
new file mode 100644
index 0000000..a07726d
--- /dev/null
+++ b/scripts/logs/get_logs_movingai_sagan.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+echo "Get SAGAN logs on movingai dataset..."
+
+python get_logs.py A \
+    --map_params "{'data_folder': './data/movingai_dataset', 'maps_folder': 'maps', 'results_folder': 'movingai/results', 'results_file': 'result.csv'}" \
+    --rrt_params "{'path_resolution': 1, 'step_len': 2, 'max_iter': 10000}" \
+    --mu 0.1 \
+    --gamma 10 \
+    --n 50 \
+    --output_dir logs/sagan \
+    --output_fname sagan_movingai_logs.txt \
diff --git a/test_pix2pix.py b/test_pix2pix.py
index 1a173aa..fb3606a 100644
--- a/test_pix2pix.py
+++ b/test_pix2pix.py
@@ -13,14 +13,15 @@
 
 from pathgan.data import MPRDataset
 from pathgan.models import Generator
+from pathgan.metrics import intersection_over_union, jaccard_coefficient
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog = 'top', description='Testing Pix2Pix GAN (our GAN)')
-    parser.add_argument('--checkpoint_path', default=None, help='Load directory to continue training (default: "None")')
-    parser.add_argument('--batch_size', type=int, default=1, help='"Batch size" with which GAN will be trained (default: 1)')
+    parser.add_argument('--checkpoint_path', default='checkpoints/pix2pix/generator.pt', help='Path to trained Generator')
+    parser.add_argument('--dataset_path', default='data/generated_dataset/dataset', help='Path to dataset')
     parser.add_argument('--save_dir', default='results/pix2pix', help='Save directory (default: "results/pix2pix")')
-    parser.add_argument('--device', type=str, default='cuda:0', help='Device (default: "cuda:0")')
+    parser.add_argument('--device', type=str, default='cuda', help='Device (default: "cuda")')
     args = parser.parse_args()
 
     device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
@@ -32,30 +33,51 @@
             std=(0.5, 0.5, 0.5),
         ),
     ])
-    df = pd.read_csv('dataset/test.csv')
+
     dataset = MPRDataset(
-        map_dir = 'dataset/maps',
-        point_dir = 'dataset/tasks',
-        roi_dir = 'dataset/tasks',
-        csv_file = df,
-        transform = transform,
+        map_dir=os.path.join(args.dataset_path, 'maps'),
+        point_dir=os.path.join(args.dataset_path, 'tasks'),
+        roi_dir=os.path.join(args.dataset_path, 'tasks'),
+        csv_file=pd.read_csv(os.path.join(args.dataset_path, 'test.csv')),
+        transform=transform,
+        return_meta=True,
     )
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
     generator = Generator()
-    print('=========== Loading weights for Generator ===========')
+    print(f"Loading weights from: {args.checkpoint_path}")
     generator.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu"))
     generator = generator.to(device)
     generator = generator.eval()
-    print('============== Testing Started ==============')
+    print("Start evaluation")
     os.makedirs(args.save_dir, exist_ok=True)
-    for i, (maps, points, rois) in enumerate(tqdm(dataloader)):
-        maps = maps.to(device)
-        points = points.to(device)
+    true_roi_paths = []
+    pred_roi_paths = []
+    iou_values = []
+    dice_values = []
+    for i in tqdm(range(len(dataset))):
+        maps, points, rois, meta = dataset[i]
+        maps = maps.unsqueeze(0).to(device)
+        points = points.unsqueeze(0).to(device)
         with torch.no_grad():
             pred_rois = generator(maps, points).detach().cpu()[0]
         pred_rois = pred_rois.permute(1,2,0).numpy()
         pred_rois = (pred_rois > 0).astype(np.uint8) * 255
-        roi_img = Image.fromarray(pred_rois)
-        roi_path = os.path.join(args.save_dir, f"roi_{i}.png")
-        roi_img.save(roi_path)
-    print('============== Testing Finished! ==============')
+
+        pred_roi_img = Image.fromarray(pred_rois)
+        pred_roi_path = os.path.join(args.save_dir, f"roi_{i}.png")
+        pred_roi_img.save(pred_roi_path)
+
+        iou = intersection_over_union(pred_rois, rois.permute(1,2,0).numpy())
+        dice = jaccard_coefficient(pred_rois, rois.permute(1,2,0).numpy())
+        true_roi_paths.append(meta["roi_path"])
+        pred_roi_paths.append(pred_roi_path)
+        iou_values.append(iou)
+        dice_values.append(dice)
+
+    csv_file = pd.DataFrame({
+        "true_roi": true_roi_paths,
+        "pred_roi": pred_roi_paths,
+        "iou": iou_values,
+        "dice": dice_values,
+    })
+    csv_file.to_csv(os.path.join(args.save_dir, "results.csv"), index=False)
+    print(f"Saving .csv file to: {args.save_dir}")
diff --git a/test_sagan.py b/test_sagan.py
index 1f111f6..6bafe6d 100644
--- a/test_sagan.py
+++ b/test_sagan.py
@@ -13,14 +13,15 @@
 
 from pathgan.data import MPRDataset
 from pathgan.models import SAGenerator
+from pathgan.metrics import intersection_over_union, jaccard_coefficient
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog = 'top', description='Testing SAGAN (from original paper)')
-    parser.add_argument('--checkpoint_path', default=None, help='Load directory to continue training (default: "None")')
-    parser.add_argument('--batch_size', type=int, default=1, help='"Batch size" with which GAN will be trained (default: 1)')
+    parser.add_argument('--checkpoint_path', default='checkpoints/sagan/generator.pt', help='Path to trained Generator')
+    parser.add_argument('--dataset_path', default='data/generated_dataset/dataset', help='Path to dataset')
     parser.add_argument('--save_dir', default='results/sagan', help='Save directory (default: "results/sagan")')
-    parser.add_argument('--device', type=str, default='cuda:0', help='Device (default: "cuda:0")')
+    parser.add_argument('--device', type=str, default='cuda', help='Device (default: "cuda")')
     args = parser.parse_args()
 
     device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
@@ -32,25 +33,30 @@
             std=(0.5, 0.5, 0.5),
         ),
     ])
-    df = pd.read_csv('dataset/test.csv')
+
     dataset = MPRDataset(
-        map_dir = 'dataset/maps',
-        point_dir = 'dataset/tasks',
-        roi_dir = 'dataset/tasks',
-        csv_file = df,
-        transform = transform,
+        map_dir=os.path.join(args.dataset_path, 'maps'),
+        point_dir=os.path.join(args.dataset_path, 'tasks'),
+        roi_dir=os.path.join(args.dataset_path, 'tasks'),
+        csv_file=pd.read_csv(os.path.join(args.dataset_path, 'test.csv')),
+        transform=transform,
+        return_meta=True,
     )
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
     generator = SAGenerator()
-    print('=========== Loading weights for Generator ===========')
+    print(f"Loading weights from: {args.checkpoint_path}")
     generator.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu"))
     generator = generator.to(device)
     generator = generator.eval()
-    print('============== Testing Started ==============')
+    print("Start evaluation")
     os.makedirs(args.save_dir, exist_ok=True)
-    for i, (maps, points, rois) in enumerate(tqdm(dataloader)):
-        maps = maps.to(device)
-        points = points.to(device)
+    true_roi_paths = []
+    pred_roi_paths = []
+    iou_values = []
+    dice_values = []
+    for i in tqdm(range(len(dataset))):
+        maps, points, rois, meta = dataset[i]
+        maps = maps.unsqueeze(0).to(device)
+        points = points.unsqueeze(0).to(device)
         b, _, h, w = maps.size()
         noise = torch.rand(b, 1, h, w)
         noise = noise.to(device)
@@ -58,7 +64,23 @@
             pred_rois = generator(maps, points, noise).detach().cpu()[0]
         pred_rois = pred_rois.permute(1,2,0).numpy()
         pred_rois = (pred_rois > 0).astype(np.uint8) * 255
-        roi_img = Image.fromarray(pred_rois)
-        roi_path = os.path.join(args.save_dir, f"roi_{i}.png")
-        roi_img.save(roi_path)
-    print('============== Testing Finished! ==============')
+
+        pred_roi_img = Image.fromarray(pred_rois)
+        pred_roi_path = os.path.join(args.save_dir, f"roi_{i}.png")
+        pred_roi_img.save(pred_roi_path)
+
+        iou = intersection_over_union(pred_rois, rois.permute(1,2,0).numpy())
+        dice = jaccard_coefficient(pred_rois, rois.permute(1,2,0).numpy())
+        true_roi_paths.append(meta["roi_path"])
+        pred_roi_paths.append(pred_roi_path)
+        iou_values.append(iou)
+        dice_values.append(dice)
+
+    csv_file = pd.DataFrame({
+        "true_roi": true_roi_paths,
+        "pred_roi": pred_roi_paths,
+        "iou": iou_values,
+        "dice": dice_values,
+    })
+    csv_file.to_csv(os.path.join(args.save_dir, "results.csv"), index=False)
+    print(f"Saving .csv file to: {args.save_dir}")
diff --git a/train_pix2pix.py b/train_pix2pix.py
index f731555..7972254 100644
--- a/train_pix2pix.py
+++ b/train_pix2pix.py
@@ -16,11 +16,11 @@
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog = 'top', description='Training Pix2Pix GAN (our GAN)')
+    parser.add_argument('--dataset_path', default='data/generated_dataset/dataset', help='Path to dataset')
     parser.add_argument('--batch_size', type=int, default=8, help='"Batch size" with which GAN will be trained (default: 8)')
     parser.add_argument('--epochs', type=int, default=3, help='Number of "epochs" GAN will be trained (default: 3)')
     parser.add_argument('--g_lr', type=float, default=0.001, help='"Learning rate" of Generator (default: 0.001)')
     parser.add_argument('--d_lr', type=float, default=0.0007, help='"Learning rate" of Discriminator (default: 0.0007)')
-    parser.add_argument('--load_dir', default=None, help='Load directory to continue training (default: "None")')
     parser.add_argument("--save_dir", default="checkpoints/pix2pix", help='Save directory (default: "checkpoints/pix2pix")')
     parser.add_argument("--device", type=str, default="cuda:0", help="Device (default: 'cuda:0')")
     args = parser.parse_args()
@@ -34,20 +34,16 @@
             std=(0.5, 0.5, 0.5),
         ),
     ])
-    df = pd.read_csv('dataset/train.csv')
     dataset = MPRDataset(
-        map_dir = 'dataset/maps',
-        point_dir = 'dataset/tasks',
-        roi_dir = 'dataset/tasks',
-        csv_file = df,
-        transform = transform,
+        map_dir=os.path.join(args.dataset_path, 'maps'),
+        point_dir=os.path.join(args.dataset_path, 'tasks'),
+        roi_dir=os.path.join(args.dataset_path, 'tasks'),
+        csv_file=pd.read_csv(os.path.join(args.dataset_path, 'train.csv')),
+        transform=transform,
     )
     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
     generator = Generator()
     discriminator = Discriminator()
-    if args.load_dir:
-        print('=========== Loading weights for Generator ===========')
-        generator.load_state_dict(torch.load(args.load_dir))
 
     g_criterion = GeneratorLoss()
     d_criterion = DiscriminatorLoss()
@@ -63,11 +59,10 @@
         d_optimizer=d_optimizer,
         device=device,
     )
-    print('============== Training Started ==============')
+    print("Start training")
     trainer.fit(dataloader, epochs=args.epochs, device=device)
-    print('============== Training Finished! ==============')
     if args.save_dir:
-        print('=========== Saving weights for Pix2Pix ===========')
+        print(f"Saving weights for Pix2pix to: {args.save_dir}")
         os.makedirs(args.save_dir, exist_ok=True)
         torch.save(generator.cpu().state_dict(), os.path.join(args.save_dir, "generator.pt"))
-        torch.save(discriminator.cpu().state_dict(), os.path.join(args.save_dir, "discriminator.pt"))
+        torch.save(discriminator.cpu().state_dict(), os.path.join(args.save_dir, "discriminator.pt"))
\ No newline at end of file
diff --git a/train_sagan.py b/train_sagan.py
index efb058e..47d08f4 100644
--- a/train_sagan.py
+++ b/train_sagan.py
@@ -16,12 +16,12 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(prog = "top", description="Training GAN (from original paper)")
+    parser.add_argument('--dataset_path', default='data/generated_dataset/dataset', help='Path to dataset')
     parser.add_argument("--batch_size", type=int, default=8, help="Batch size (default: 8)")
     parser.add_argument("--epochs", type=int, default=3, help="Number of `epochs` GAN will be trained (default: 3)")
     parser.add_argument("--g_lr", type=float, default=0.0001, help="Learning rate of Generator (default: 0.0001)")
     parser.add_argument("--md_lr", type=float, default=0.00005, help="Learning rate of Map Discriminator (default: 0.00005)")
     parser.add_argument("--pd_lr", type=float, default=0.00005, help="Learning rate of Point Discriminator (default: 0.00005)")
-    parser.add_argument("--load_dir", default=None, help='Load directory to continue training (default: "None")')
     parser.add_argument("--save_dir", default="checkpoints/sagan", help='Save directory (default: "checkpoints/sagan")')
     parser.add_argument("--device", type=str, default="cuda:0", help="Device (default: 'cuda:0')")
     args = parser.parse_args()
@@ -35,12 +35,11 @@
             std=(0.5, 0.5, 0.5),
         ),
     ])
-    df = pd.read_csv("dataset/train.csv")
     dataset = MPRDataset(
-        map_dir="dataset/maps",
-        point_dir="dataset/tasks",
-        roi_dir="dataset/tasks",
-        csv_file=df,
+        map_dir=os.path.join(args.dataset_path, 'maps'),
+        point_dir=os.path.join(args.dataset_path, 'tasks'),
+        roi_dir=os.path.join(args.dataset_path, 'tasks'),
+        csv_file=pd.read_csv(os.path.join(args.dataset_path, 'train.csv')),
         transform=transform,
     )
     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
@@ -48,10 +47,6 @@
     generator = SAGenerator()
     map_discriminator = MapDiscriminator()
     point_discriminator = PointDiscriminator()
-    # Load weights
-    if args.load_dir:
-        print('=========== Loading weights for Generator ===========')
-        generator.load_state_dict(torch.load(args.load_dir))
     # Losses
     g_criterion = AdaptiveSAGeneratorLoss()
     md_criterion = DiscriminatorLoss()
@@ -73,11 +68,10 @@
         pd_optimizer=pd_optimizer,
         device=device,
     )
-    print('============== Training Started ==============')
training") trainer.fit(dataloader, epochs=args.epochs, device=device) - print('============== Training Finished! ==============') if args.save_dir: - print('=========== Saving weights for SAGAN ===========') + print(f"Saving weights for SAGAN to: {args.save_dir}") os.makedirs(args.save_dir, exist_ok=True) torch.save(generator.cpu().state_dict(), os.path.join(args.save_dir, "generator.pt")) torch.save(map_discriminator.cpu().state_dict(), os.path.join(args.save_dir, "map_discriminator.pt"))