diff --git a/Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb b/Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb
new file mode 100644
index 0000000..39ffe8e
--- /dev/null
+++ b/Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb
@@ -0,0 +1,1006 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "toc": true
+ },
+ "source": [
+    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
+    "<div class=\"toc\"><ul class=\"toc-item\"></ul></div>"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:06:58.944902Z",
+ "start_time": "2021-04-22T11:06:56.623974Z"
+ },
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import random\n",
+ "import time\n",
+ "import json\n",
+ "import warnings \n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from utils import label_accuracy_score\n",
+ "import cv2\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+    "# libraries for preprocessing\n",
+ "from pycocotools.coco import COCO\n",
+ "import torchvision\n",
+ "import torchvision.transforms as transforms\n",
+ "\n",
+ "import albumentations as A\n",
+ "from albumentations.pytorch import ToTensorV2\n",
+ "\n",
+    "# libraries for visualization\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns; sns.set()\n",
+ "\n",
+ "plt.rcParams['axes.grid'] = False\n",
+ "\n",
+ "print('pytorch version: {}'.format(torch.__version__))\n",
+    "print('GPU available: {}'.format(torch.cuda.is_available()))\n",
+ "\n",
+ "print(torch.cuda.get_device_name(0))\n",
+ "print(torch.cuda.device_count())\n",
+ "\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"   # select the device depending on GPU availability"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Hyperparameter settings and seed fixing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:06:59.171980Z",
+ "start_time": "2021-04-22T11:06:59.167952Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = 16 # Mini-batch size\n",
+ "num_epochs = 20\n",
+ "learning_rate = 0.0001"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:06:59.446510Z",
+ "start_time": "2021-04-22T11:06:59.443508Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# fix random seeds for reproducibility\n",
+ "random_seed = 21\n",
+ "torch.manual_seed(random_seed)\n",
+ "torch.cuda.manual_seed(random_seed)\n",
+    "# torch.cuda.manual_seed_all(random_seed)  # if using multi-GPU\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "np.random.seed(random_seed)\n",
+ "random.seed(random_seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## EDA on the training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:04.139668Z",
+ "start_time": "2021-04-22T11:07:00.575728Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "\n",
+ "dataset_path = '../input/data'\n",
+ "anns_file_path = dataset_path + '/' + 'train.json'\n",
+ "\n",
+ "# Read annotations\n",
+ "with open(anns_file_path, 'r') as f:\n",
+ " dataset = json.loads(f.read())\n",
+ "\n",
+ "categories = dataset['categories']\n",
+ "anns = dataset['annotations']\n",
+ "imgs = dataset['images']\n",
+ "nr_cats = len(categories)\n",
+ "nr_annotations = len(anns)\n",
+ "nr_images = len(imgs)\n",
+ "\n",
+ "# Load categories and super categories\n",
+ "cat_names = []\n",
+ "super_cat_names = []\n",
+ "super_cat_ids = {}\n",
+ "super_cat_last_name = ''\n",
+ "nr_super_cats = 0\n",
+ "for cat_it in categories:\n",
+ " cat_names.append(cat_it['name'])\n",
+ " super_cat_name = cat_it['supercategory']\n",
+ " # Adding new supercat\n",
+ " if super_cat_name != super_cat_last_name:\n",
+ " super_cat_names.append(super_cat_name)\n",
+ " super_cat_ids[super_cat_name] = nr_super_cats\n",
+ " super_cat_last_name = super_cat_name\n",
+ " nr_super_cats += 1\n",
+ "\n",
+ "print('Number of super categories:', nr_super_cats)\n",
+ "print('Number of categories:', nr_cats)\n",
+ "print('Number of annotations:', nr_annotations)\n",
+ "print('Number of images:', nr_images)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:04.394832Z",
+ "start_time": "2021-04-22T11:07:04.141668Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Count annotations\n",
+ "cat_histogram = np.zeros(nr_cats,dtype=int)\n",
+ "for ann in anns:\n",
+ " cat_histogram[ann['category_id']] += 1\n",
+ "\n",
+ "# Initialize the matplotlib figure\n",
+ "f, ax = plt.subplots(figsize=(5,5))\n",
+ "\n",
+ "# Convert to DataFrame\n",
+ "df = pd.DataFrame({'Categories': cat_names, 'Number of annotations': cat_histogram})\n",
+    "df = df.sort_values('Number of annotations', axis=0, ascending=False)\n",
+ "\n",
+ "# Plot the histogram\n",
+    "plt.title(\"Category distribution of the train set\")\n",
+ "plot_1 = sns.barplot(x=\"Number of annotations\", y=\"Categories\", data=df, label=\"Total\", color=\"b\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:04.409808Z",
+ "start_time": "2021-04-22T11:07:04.395831Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# category labeling \n",
+ "sorted_temp_df = df.sort_index()\n",
+ "\n",
+    "# add a label for Background = 0, shifting all existing labels to label + 1\n",
+    "sorted_df = pd.DataFrame([\"Background\"], columns = [\"Categories\"])\n",
+    "sorted_df = sorted_df.append(sorted_temp_df, ignore_index=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:04.424832Z",
+ "start_time": "2021-04-22T11:07:04.411802Z"
+ },
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+    "# check the index for each class (Categories) (0~11 : 12 classes in total)\n",
+ "sorted_df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Define the data preprocessing class (Dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:04.439837Z",
+ "start_time": "2021-04-22T11:07:04.425804Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "category_names = list(sorted_df.Categories)\n",
+ "\n",
+ "def get_classname(classID, cats):\n",
+ " for i in range(len(cats)):\n",
+ " if cats[i]['id']==classID:\n",
+ " return cats[i]['name']\n",
+ " return \"None\"\n",
+ "\n",
+ "class CustomDataLoader(Dataset):\n",
+ " \"\"\"COCO format\"\"\"\n",
+ " def __init__(self, data_dir, mode = 'train', transform = None):\n",
+ " super().__init__()\n",
+ " self.mode = mode\n",
+ " self.transform = transform\n",
+ " self.coco = COCO(data_dir)\n",
+ " \n",
+ " def __getitem__(self, index: int):\n",
+    "        # the dataset is indexed like a list\n",
+ " image_id = self.coco.getImgIds(imgIds=index)\n",
+ " image_infos = self.coco.loadImgs(image_id)[0]\n",
+ " \n",
+    "        # load the image with cv2\n",
+ " images = cv2.imread(os.path.join(dataset_path, image_infos['file_name']))\n",
+ " images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)\n",
+ " images /= 255.0\n",
+ " \n",
+ " if (self.mode in ('train', 'val')):\n",
+ " ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])\n",
+ " anns = self.coco.loadAnns(ann_ids)\n",
+ "\n",
+ " # Load the categories in a variable\n",
+ " cat_ids = self.coco.getCatIds()\n",
+ " cats = self.coco.loadCats(cat_ids)\n",
+ "\n",
+    "            # masks : a 2D array of size (height x width)\n",
+    "            # each pixel is assigned \"category id + 1\"\n",
+ " # Background = 0\n",
+ " masks = np.zeros((image_infos[\"height\"], image_infos[\"width\"]))\n",
+ " # Unknown = 1, General trash = 2, ... , Cigarette = 11\n",
+ " for i in range(len(anns)):\n",
+ " className = get_classname(anns[i]['category_id'], cats)\n",
+ " pixel_value = category_names.index(className)\n",
+ " masks = np.maximum(self.coco.annToMask(anns[i])*pixel_value, masks)\n",
+ " masks = masks.astype(np.float32)\n",
+ "\n",
+    "            # transform -> apply the albumentations transforms\n",
+ " if self.transform is not None:\n",
+ " transformed = self.transform(image=images, mask=masks)\n",
+ " images = transformed[\"image\"]\n",
+ " masks = transformed[\"mask\"]\n",
+ " \n",
+ " return images, masks, image_infos\n",
+ " \n",
+ " if self.mode == 'test':\n",
+    "            # transform -> apply the albumentations transforms\n",
+ " if self.transform is not None:\n",
+ " transformed = self.transform(image=images)\n",
+ " images = transformed[\"image\"]\n",
+ " \n",
+ " return images, image_infos\n",
+ " \n",
+ " \n",
+ " def __len__(self) -> int:\n",
+    "        # return the size of the whole dataset\n",
+ " return len(self.coco.getImgIds())"
+ ]
+ },
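+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the `np.maximum` accumulation in `__getitem__` means that where annotations overlap, a pixel keeps the larger category index. A minimal sketch with toy 3x3 binary masks (hypothetical values, not taken from the dataset) illustrating the same combination rule:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# toy example (not part of the baseline): combining overlapping annotation masks\n",
+    "toy = np.zeros((3, 3))\n",
+    "mask_a = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]])  # annotation with pixel_value 2\n",
+    "mask_b = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]])  # annotation with pixel_value 5\n",
+    "toy = np.maximum(mask_a * 2, toy)\n",
+    "toy = np.maximum(mask_b * 5, toy)\n",
+    "print(toy)  # overlapping pixels end up with the larger index, 5"
+   ]
+  },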
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Define Datasets and assign DataLoaders"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:09.179806Z",
+ "start_time": "2021-04-22T11:07:04.440804Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+    "# paths to train.json / val.json / test.json\n",
+ "train_path = dataset_path + '/train.json'\n",
+ "val_path = dataset_path + '/val.json'\n",
+ "test_path = dataset_path + '/test.json'\n",
+ "\n",
+    "# a custom collate_fn is needed because each sample also carries an image_infos dict\n",
+ "def collate_fn(batch):\n",
+ " return tuple(zip(*batch))\n",
+ "\n",
+ "train_transform = A.Compose([\n",
+ " ToTensorV2()\n",
+ " ])\n",
+ "\n",
+ "val_transform = A.Compose([\n",
+ " ToTensorV2()\n",
+ " ])\n",
+ "\n",
+ "test_transform = A.Compose([\n",
+ " ToTensorV2()\n",
+ " ])\n",
+ "\n",
+ "# create own Dataset 1 (skip)\n",
+    "# if you want to split off a validation set yourself,\n",
+    "# split the dataset 8:2 with random_split\n",
+ "# train_size = int(0.8*len(dataset))\n",
+ "# val_size = int(len(dataset)-train_size)\n",
+ "# dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=transform)\n",
+ "# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])\n",
+ "\n",
+ "# create own Dataset 2\n",
+ "# train dataset\n",
+ "train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)\n",
+ "\n",
+ "# validation dataset\n",
+ "val_dataset = CustomDataLoader(data_dir=val_path, mode='val', transform=val_transform)\n",
+ "\n",
+ "# test dataset\n",
+ "test_dataset = CustomDataLoader(data_dir=test_path, mode='test', transform=test_transform)\n",
+ "\n",
+ "\n",
+ "# DataLoader\n",
+ "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, \n",
+ " batch_size=batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=4,\n",
+ " collate_fn=collate_fn)\n",
+ "\n",
+ "val_loader = torch.utils.data.DataLoader(dataset=val_dataset, \n",
+ " batch_size=batch_size,\n",
+ " shuffle=False,\n",
+ " num_workers=4,\n",
+ " collate_fn=collate_fn)\n",
+ "\n",
+ "test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n",
+ " batch_size=batch_size,\n",
+ " num_workers=4,\n",
+ " collate_fn=collate_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "### Visualize data samples (show example image and mask)\n",
+ "\n",
+ "- `train_loader` \n",
+ "- `val_loader` \n",
+ "- `test_loader` "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:09.779805Z",
+ "start_time": "2021-04-22T11:07:09.181803Z"
+ },
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+    "# check the output (image and mask) of train_loader\n",
+ "for imgs, masks, image_infos in train_loader:\n",
+ " image_infos = image_infos[0]\n",
+ " temp_images = imgs\n",
+ " temp_masks = masks\n",
+ " \n",
+ " break\n",
+ "\n",
+ "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))\n",
+ "\n",
+ "print('image shape:', list(temp_images[0].shape))\n",
+ "print('mask shape: ', list(temp_masks[0].shape))\n",
+    "print('Unique values, category of transformed mask : \\n', [(int(i), category_names[int(i)]) for i in list(np.unique(temp_masks[0]))])\n",
+ "\n",
+ "ax1.imshow(temp_images[0].permute([1,2,0]))\n",
+ "ax1.grid(False)\n",
+ "ax1.set_title(\"input image : {}\".format(image_infos['file_name']), fontsize = 15)\n",
+ "\n",
+ "ax2.imshow(temp_masks[0])\n",
+ "ax2.grid(False)\n",
+ "ax2.set_title(\"masks : {}\".format(image_infos['file_name']), fontsize = 15)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:10.469862Z",
+ "start_time": "2021-04-22T11:07:09.780831Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# check the output (image and mask) of val_loader\n",
+ "for imgs, masks, image_infos in val_loader:\n",
+ " image_infos = image_infos[0]\n",
+ " temp_images = imgs\n",
+ " temp_masks = masks\n",
+ " \n",
+ " break\n",
+ "\n",
+ "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))\n",
+ "\n",
+ "print('image shape:', list(temp_images[0].shape))\n",
+ "print('mask shape: ', list(temp_masks[0].shape))\n",
+ "\n",
+    "print('Unique values, category of transformed mask : \\n', [(int(i), category_names[int(i)]) for i in list(np.unique(temp_masks[0]))])\n",
+ "\n",
+ "ax1.imshow(temp_images[0].permute([1,2,0]))\n",
+ "ax1.grid(False)\n",
+ "ax1.set_title(\"input image : {}\".format(image_infos['file_name']), fontsize = 15)\n",
+ "\n",
+ "ax2.imshow(temp_masks[0])\n",
+ "ax2.grid(False)\n",
+ "ax2.set_title(\"masks : {}\".format(image_infos['file_name']), fontsize = 15)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:07:10.772294Z",
+ "start_time": "2021-04-22T11:07:10.470862Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# check the output (image) of test_loader\n",
+ "for imgs, image_infos in test_loader:\n",
+ " image_infos = image_infos[0]\n",
+ " temp_images = imgs\n",
+ " \n",
+ " break\n",
+ "\n",
+ "fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))\n",
+ "\n",
+ "print('image shape:', list(temp_images[0].shape))\n",
+ "\n",
+ "ax1.imshow(temp_images[0].permute([1,2,0]))\n",
+ "ax1.grid(False)\n",
+ "ax1.set_title(\"input image : {}\".format(image_infos['file_name']), fontsize = 15)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Baseline model\n",
+    "\n",
+    "### FCN8s (VGG ImageNet weights)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:15:29.119807Z",
+ "start_time": "2021-04-22T11:15:29.109808Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+    "from torchvision.models import vgg16\n",
+ "\n",
+ "class FCN8s(nn.Module):\n",
+ " def __init__(self, num_classes):\n",
+ " super(FCN8s,self).__init__()\n",
+ " self.pretrained_model = vgg16(pretrained = True)\n",
+ " features, classifiers = list(self.pretrained_model.features.children()), list(self.pretrained_model.classifier.children())\n",
+ "\n",
+ " self.features_map1 = nn.Sequential(*features[0:17])\n",
+ " self.features_map2 = nn.Sequential(*features[17:24])\n",
+ " self.features_map3 = nn.Sequential(*features[24:31])\n",
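+    "        # slices: features[0:17] -> through pool3 (256 ch, 1/8 scale), [17:24] -> pool4 (512 ch, 1/16), [24:31] -> pool5 (512 ch, 1/32)\n",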
+ " \n",
+ " # Score pool3\n",
+ " self.score_pool3_fr = nn.Conv2d(256, num_classes, 1)\n",
+ " \n",
+ " # Score pool4 \n",
+ " self.score_pool4_fr = nn.Conv2d(512, num_classes, 1) \n",
+ " \n",
+ " # fc6 ~ fc7\n",
+ " self.conv = nn.Sequential(nn.Conv2d(512, 4096, kernel_size = 1),\n",
+ " nn.ReLU(inplace=True),\n",
+ " nn.Dropout(),\n",
+ " nn.Conv2d(4096, 4096, kernel_size = 1),\n",
+ " nn.ReLU(inplace=True),\n",
+ " nn.Dropout()\n",
+ " )\n",
+ " \n",
+ " # Score\n",
+ " self.score_fr = nn.Conv2d(4096, num_classes, kernel_size = 1)\n",
+ " \n",
+ " # UpScore2 using deconv\n",
+ " self.upscore2 = nn.ConvTranspose2d(num_classes,\n",
+ " num_classes,\n",
+ " kernel_size=4,\n",
+ " stride=2,\n",
+ " padding=1)\n",
+ " \n",
+ " # UpScore2_pool4 using deconv\n",
+ " self.upscore2_pool4 = nn.ConvTranspose2d(num_classes, \n",
+ " num_classes, \n",
+ " kernel_size=4,\n",
+ " stride=2,\n",
+ " padding=1)\n",
+ " \n",
+ " # UpScore8 using deconv\n",
+ " self.upscore8 = nn.ConvTranspose2d(num_classes, \n",
+ " num_classes,\n",
+ " kernel_size=16,\n",
+ " stride=8,\n",
+ " padding=4)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " pool3 = h = self.features_map1(x)\n",
+ " pool4 = h = self.features_map2(h)\n",
+ " h = self.features_map3(h)\n",
+ " \n",
+ " h = self.conv(h)\n",
+ " h = self.score_fr(h)\n",
+ " \n",
+ " score_pool3c = self.score_pool3_fr(pool3) \n",
+ " score_pool4c = self.score_pool4_fr(pool4)\n",
+ " \n",
+ " # Up Score I\n",
+ " upscore2 = self.upscore2(h)\n",
+ " \n",
+ " # Sum I\n",
+ " h = upscore2 + score_pool4c\n",
+ " \n",
+ " # Up Score II\n",
+ " upscore2_pool4c = self.upscore2_pool4(h)\n",
+ " \n",
+ " # Sum II\n",
+ " h = upscore2_pool4c + score_pool3c\n",
+ " \n",
+ " # Up Score III\n",
+ " upscore8 = self.upscore8(h)\n",
+ " \n",
+ " return upscore8"
+ ]
+ },
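+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Each `ConvTranspose2d` above follows `out = (in - 1) * stride - 2 * padding + kernel_size`, so the chosen kernel/stride/padding combinations upsample by exactly x2 or x8. A quick sanity check (not part of the baseline) of the spatial sizes for a 512x512 input:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ConvTranspose2d output size: out = (in - 1) * stride - 2 * padding + kernel_size\n",
+    "def deconv_out(size, kernel_size, stride, padding):\n",
+    "    return (size - 1) * stride - 2 * padding + kernel_size\n",
+    "\n",
+    "print(deconv_out(16, 4, 2, 1))   # 32  : upscore2, 1/32 -> 1/16 (matches pool4)\n",
+    "print(deconv_out(32, 4, 2, 1))   # 64  : upscore2_pool4, 1/16 -> 1/8 (matches pool3)\n",
+    "print(deconv_out(64, 16, 8, 4))  # 512 : upscore8, 1/8 -> full resolution"
+   ]
+  },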
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:15:34.624277Z",
+ "start_time": "2021-04-22T11:15:30.068347Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# test the implemented model by feeding a random input and checking the output shape\n",
+ "\n",
+ "model = FCN8s(num_classes=12)\n",
+ "x = torch.randn([1, 3, 512, 512])\n",
+ "print(\"input shape : \", x.shape)\n",
+ "out = model(x).to(device)\n",
+ "print(\"output shape : \", out.size())\n",
+ "\n",
+ "model = model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Define the train, validation, and test functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:15:38.201874Z",
+ "start_time": "2021-04-22T11:15:38.187884Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def train(num_epochs, model, data_loader, val_loader, criterion, optimizer, saved_dir, val_every, device):\n",
+ " print('Start training..')\n",
+ " best_loss = 9999999\n",
+ " for epoch in range(num_epochs):\n",
+ " model.train()\n",
+ " for step, (images, masks, _) in enumerate(data_loader):\n",
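+    "            # collate_fn returns tuples of samples, so stack them into batch tensors\n",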
+ " images = torch.stack(images) # (batch, channel, height, width)\n",
+    "            masks = torch.stack(masks).long()  # (batch, height, width)\n",
+ " \n",
+    "            # move tensors to the device for GPU computation\n",
+ " images, masks = images.to(device), masks.to(device)\n",
+ " \n",
+ " # inference\n",
+ " outputs = model(images)\n",
+ " \n",
+    "            # compute the loss (cross-entropy)\n",
+ " loss = criterion(outputs, masks)\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " \n",
+    "            # print the loss every 25 steps\n",
+ " if (step + 1) % 25 == 0:\n",
+ " print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(\n",
+    "                    epoch+1, num_epochs, step+1, len(data_loader), loss.item()))\n",
+ " \n",
+    "        # run validation periodically and save the best model\n",
+ " if (epoch + 1) % val_every == 0:\n",
+ " avrg_loss = validation(epoch + 1, model, val_loader, criterion, device)\n",
+ " if avrg_loss < best_loss:\n",
+ " print('Best performance at epoch: {}'.format(epoch + 1))\n",
+ " print('Save model in', saved_dir)\n",
+ " best_loss = avrg_loss\n",
+ " save_model(model, saved_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:15:38.901226Z",
+ "start_time": "2021-04-22T11:15:38.888195Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def validation(epoch, model, data_loader, criterion, device):\n",
+ " print('Start validation #{}'.format(epoch))\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " total_loss = 0\n",
+ " cnt = 0\n",
+ " mIoU_list = []\n",
+ " for step, (images, masks, _) in enumerate(data_loader):\n",
+ " \n",
+ " images = torch.stack(images) # (batch, channel, height, width)\n",
+    "            masks = torch.stack(masks).long()  # (batch, height, width)\n",
+ "\n",
+ " images, masks = images.to(device), masks.to(device) \n",
+ "\n",
+ " outputs = model(images)\n",
+ " loss = criterion(outputs, masks)\n",
+ " total_loss += loss\n",
+ " cnt += 1\n",
+ " \n",
+    "            outputs = torch.argmax(outputs, dim=1).detach().cpu().numpy()\n",
+ "\n",
+ " mIoU = label_accuracy_score(masks.detach().cpu().numpy(), outputs, n_class=12)[2]\n",
+ " mIoU_list.append(mIoU)\n",
+ " \n",
+ " avrg_loss = total_loss / cnt\n",
+ " print('Validation #{} Average Loss: {:.4f}, mIoU: {:.4f}'.format(epoch, avrg_loss, np.mean(mIoU_list)))\n",
+ "\n",
+ " return avrg_loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Define the model saving function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:15:41.634492Z",
+ "start_time": "2021-04-22T11:15:41.627493Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# define the model saving function\n",
+ "val_every = 1 \n",
+ "\n",
+ "saved_dir = './saved'\n",
+ "if not os.path.isdir(saved_dir): \n",
+ " os.mkdir(saved_dir)\n",
+ " \n",
+ "def save_model(model, saved_dir, file_name='fcn8s_best_model(pretrained).pt'):\n",
+    "    output_path = os.path.join(saved_dir, file_name)\n",
+    "    torch.save(model.state_dict(), output_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Create the model and define the loss function and optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-22T11:15:43.106368Z",
+ "start_time": "2021-04-22T11:15:43.096368Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# define the loss function\n",
+ "criterion = nn.CrossEntropyLoss()\n",
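+    "# note: CrossEntropyLoss consumes the raw (N, 12, H, W) logits from FCN8s and (N, H, W) integer masks\n",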
+ "\n",
+    "# define the optimizer\n",
+ "optimizer = torch.optim.Adam(params = model.parameters(), lr = learning_rate, weight_decay=1e-6)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2021-04-22T11:15:43.700Z"
+ },
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "train(num_epochs, model, train_loader, val_loader, criterion, optimizer, saved_dir, val_every, device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Load the saved model (after training)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-16T19:44:21.050200Z",
+ "start_time": "2021-04-16T19:44:20.802200Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+    "# path to the saved best model\n",
+ "model_path = './saved/fcn8s_best_model(pretrained).pt'\n",
+ "\n",
+    "# load the best model\n",
+ "checkpoint = torch.load(model_path, map_location=device)\n",
+ "model.load_state_dict(checkpoint)\n",
+ "\n",
+    "# must be set before running inference (puts batch normalization and dropout into evaluation mode)\n",
+ "# model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-16T19:44:24.939227Z",
+ "start_time": "2021-04-16T19:44:24.518228Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# check inference results on the first batch\n",
+ "for imgs, image_infos in test_loader:\n",
+ " temp_images = imgs\n",
+ " \n",
+ " model.eval()\n",
+ " # inference\n",
+ " outs = model(torch.stack(temp_images).to(device))\n",
+    "    oms = torch.argmax(outs, dim=1).detach().cpu().numpy()\n",
+ " \n",
+ " break\n",
+ "\n",
+ "i = 3\n",
+ "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 16))\n",
+ "\n",
+ "print('Shape of Original Image :', list(temp_images[i].shape))\n",
+ "print('Shape of Predicted : ', list(oms[i].shape))\n",
+    "print('Unique values, category of transformed mask : \\n', [(int(v), category_names[int(v)]) for v in list(np.unique(oms[i]))])\n",
+ "\n",
+ "# Original image\n",
+ "ax1.imshow(temp_images[i].permute([1,2,0]))\n",
+ "ax1.grid(False)\n",
+ "ax1.set_title(\"Original image : {}\".format(image_infos[i]['file_name']), fontsize = 15)\n",
+ "\n",
+ "# Predicted\n",
+ "ax2.imshow(oms[i])\n",
+ "ax2.grid(False)\n",
+ "ax2.set_title(\"Predicted : {}\".format(image_infos[i]['file_name']), fontsize = 15)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Define the test function for submission"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-16T19:44:27.469285Z",
+ "start_time": "2021-04-16T19:44:27.456021Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def test(model, data_loader, device):\n",
+ " size = 256\n",
+    "    transform = A.Compose([A.Resize(size, size)])\n",
+ " print('Start prediction.')\n",
+ " model.eval()\n",
+ " \n",
+ " file_name_list = []\n",
+    "    preds_array = np.empty((0, size*size), dtype=np.int64)\n",
+ " \n",
+ " with torch.no_grad():\n",
+    "        for step, (imgs, image_infos) in enumerate(data_loader):\n",
+ "\n",
+ " # inference (512 x 512)\n",
+ " outs = model(torch.stack(imgs).to(device))\n",
+    "            oms = torch.argmax(outs, dim=1).detach().cpu().numpy()\n",
+ " \n",
+ " # resize (256 x 256)\n",
+ " temp_mask = []\n",
+    "            for img, mask in zip(np.stack(imgs), oms):\n",
+ " transformed = transform(image=img, mask=mask)\n",
+ " mask = transformed['mask']\n",
+ " temp_mask.append(mask)\n",
+ "\n",
+ " oms = np.array(temp_mask)\n",
+ " \n",
+ " oms = oms.reshape([oms.shape[0], size*size]).astype(int)\n",
+ " preds_array = np.vstack((preds_array, oms))\n",
+ " \n",
+ " file_name_list.append([i['file_name'] for i in image_infos])\n",
+ " print(\"End prediction.\")\n",
+ " file_names = [y for x in file_name_list for y in x]\n",
+ " \n",
+ " return file_names, preds_array"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Create submission.csv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-04-16T19:45:42.235310Z",
+ "start_time": "2021-04-16T19:44:30.499016Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+    "# open sample_submission.csv\n",
+ "submission = pd.read_csv('./submission/sample_submission.csv', index_col=None)\n",
+ "\n",
+    "# predict on the test set\n",
+ "file_names, preds = test(model, test_loader, device)\n",
+ "\n",
+    "# fill in PredictionString\n",
+ "for file_name, string in zip(file_names, preds):\n",
+ " submission = submission.append({\"image_id\" : file_name, \"PredictionString\" : ' '.join(str(e) for e in string.tolist())}, \n",
+ " ignore_index=True)\n",
+ "\n",
+    "# save as submission.csv\n",
+ "submission.to_csv(\"./submission/Baseline_FCN8s(pretrained).csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Reference\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "hide_input": false,
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.1"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": true,
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "297.278px"
+ },
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/Segmentation_Baseline_Code/requirements.txt b/Segmentation_Baseline_Code/requirements.txt
new file mode 100644
index 0000000..c23a199
--- /dev/null
+++ b/Segmentation_Baseline_Code/requirements.txt
@@ -0,0 +1,134 @@
+albumentations==0.5.2
+anyio==2.2.0
+argon2-cffi==20.1.0
+async-generator==1.10
+attrs==20.3.0
+Babel==2.9.0
+backcall==0.2.0
+bcrypt==3.2.0
+beautifulsoup4==4.9.1
+bleach==3.3.0
+boto3==1.17.56
+botocore==1.20.56
+certifi==2020.12.5
+cffi==1.14.0
+chardet==3.0.4
+conda==4.8.3
+conda-build==3.18.11
+conda-package-handling==1.7.0
+cryptography==2.9.2
+cycler==0.10.0
+Cython==0.29.23
+decorator==4.4.2
+defusedxml==0.7.1
+deprecation==2.1.0
+efficientnet-pytorch==0.6.3
+entrypoints==0.3
+filelock==3.0.12
+gevent==21.1.2
+glob2==0.7
+greenlet==1.0.0
+idna==2.9
+imageio==2.9.0
+imgaug==0.4.0
+importlib-metadata==4.0.1
+inotify-simple==1.2.1
+ipykernel==5.5.3
+ipython==7.16.1
+ipython-genutils==0.2.0
+ipywidgets==7.6.3
+jedi==0.17.1
+Jinja2==2.11.2
+jmespath==0.10.0
+json5==0.9.4
+jsonschema==3.2.0
+jupyter-client==6.1.12
+jupyter-core==4.7.1
+jupyter-packaging==0.9.2
+jupyter-server==1.6.4
+jupyterlab==3.0.14
+jupyterlab-pygments==0.1.2
+jupyterlab-server==2.4.0
+jupyterlab-widgets==1.0.0
+kiwisolver==1.3.1
+libarchive-c==2.9
+MarkupSafe==1.1.1
+matplotlib==3.0.2
+mistune==0.8.4
+mkl-fft==1.1.0
+mkl-random==1.1.1
+mkl-service==2.3.0
+munch==2.5.0
+nbclassic==0.2.7
+nbclient==0.5.3
+nbconvert==6.0.7
+nbformat==5.1.3
+nest-asyncio==1.5.1
+networkx==2.5.1
+notebook==6.3.0
+numpy==1.19.5
+olefile==0.46
+opencv-python==4.2.0.34
+opencv-python-headless==4.5.1.48
+packaging==20.9
+pandas==1.1.1
+pandocfilters==1.4.3
+paramiko==2.7.2
+parso==0.7.0
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==7.2.0
+pkginfo==1.5.0.1
+pretrainedmodels==0.7.4
+prometheus-client==0.10.1
+prompt-toolkit==3.0.5
+protobuf==3.15.8
+psutil==5.7.0
+ptyprocess==0.6.0
+pycocotools==2.0.0
+pycosat==0.6.3
+pycparser==2.20
+Pygments==2.6.1
+PyNaCl==1.4.0
+pyOpenSSL==19.1.0
+pyparsing==2.4.7
+pyrsistent==0.17.3
+PySocks==1.7.1
+python-dateutil==2.8.1
+pytz==2020.1
+PyWavelets==1.1.1
+PyYAML==5.3.1
+pyzmq==22.0.3
+requests==2.23.0
+retrying==1.3.3
+ruamel-yaml==0.15.87
+s3transfer==0.4.2
+sagemaker-training==3.9.1
+scikit-image==0.18.1
+scipy==1.6.2
+seaborn==0.9.0
+segmentation-models-pytorch==0.1.3
+Send2Trash==1.5.0
+Shapely==1.7.1
+six==1.14.0
+sniffio==1.2.0
+soupsieve==2.0.1
+terminado==0.9.4
+testpath==0.4.4
+tifffile==2021.4.8
+timm==0.3.2
+tomlkit==0.7.0
+torch==1.4.0
+torchvision==0.5.0
+tornado==6.1
+tqdm==4.60.0
+traitlets==4.3.3
+typing-extensions==3.7.4.3
+urllib3==1.25.8
+wcwidth==0.2.5
+webencodings==0.5.1
+Werkzeug==1.0.1
+widgetsnbextension==3.5.1
+zipp==3.4.1
+zope.event==4.5.0
+zope.interface==5.4.0
\ No newline at end of file
diff --git a/Segmentation_Baseline_Code/utils.py b/Segmentation_Baseline_Code/utils.py
new file mode 100644
index 0000000..12b37b3
--- /dev/null
+++ b/Segmentation_Baseline_Code/utils.py
@@ -0,0 +1,34 @@
+# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
+import numpy as np
+
+
+def _fast_hist(label_true, label_pred, n_class):
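+    # Encode each (true, pred) pair as a single index n_class * true + pred,
+    # histogram with bincount, and reshape into an n_class x n_class confusion
+    # matrix (rows: ground truth, cols: prediction); out-of-range labels are
+    # dropped via `mask`.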
+ mask = (label_true >= 0) & (label_true < n_class)
+ hist = np.bincount(
+ n_class * label_true[mask].astype(int) +
+ label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
+ return hist
+
+
+def label_accuracy_score(label_trues, label_preds, n_class):
+ """Returns accuracy score evaluation result.
+ - overall accuracy
+ - mean accuracy
+ - mean IU
+ - fwavacc
+ """
+ hist = np.zeros((n_class, n_class))
+ for lt, lp in zip(label_trues, label_preds):
+ hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
+ acc = np.diag(hist).sum() / hist.sum()
+ with np.errstate(divide='ignore', invalid='ignore'):
+ acc_cls = np.diag(hist) / hist.sum(axis=1)
+ acc_cls = np.nanmean(acc_cls)
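+    # per-class IoU: intersection (diagonal) over union (row sum + col sum - diagonal)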
+ with np.errstate(divide='ignore', invalid='ignore'):
+ iu = np.diag(hist) / (
+ hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
+ )
+ mean_iu = np.nanmean(iu)
+ freq = hist.sum(axis=1) / hist.sum()
+ fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
+ return acc, acc_cls, mean_iu, fwavacc
\ No newline at end of file