From af343c4d487886b78acbffa4d06f4a33c3bebf4b Mon Sep 17 00:00:00 2001 From: whale Date: Mon, 26 Apr 2021 16:33:24 +0900 Subject: [PATCH] baseline baseline code --- ...FCN8s baseline (VGG imageNet weight).ipynb | 1006 +++++++++++++++++ Segmentation_Baseline_Code/requirements.txt | 134 +++ Segmentation_Baseline_Code/utils.py | 34 + 3 files changed, 1174 insertions(+) create mode 100644 Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb create mode 100644 Segmentation_Baseline_Code/requirements.txt create mode 100644 Segmentation_Baseline_Code/utils.py diff --git a/Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb b/Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb new file mode 100644 index 0000000..39ffe8e --- /dev/null +++ b/Segmentation_Baseline_Code/FCN8s baseline (VGG imageNet weight).ipynb @@ -0,0 +1,1006 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "toc": true + }, + "source": [ + "

Table of Contents

\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:06:58.944902Z", + "start_time": "2021-04-22T11:06:56.623974Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import time\n", + "import json\n", + "import warnings \n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from utils import label_accuracy_score\n", + "import cv2\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "# 전처리를 위한 라이브러리\n", + "from pycocotools.coco import COCO\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "\n", + "# 시각화를 위한 라이브러리\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns; sns.set()\n", + "\n", + "plt.rcParams['axes.grid'] = False\n", + "\n", + "print('pytorch version: {}'.format(torch.__version__))\n", + "print('GPU 사용 가능 여부: {}'.format(torch.cuda.is_available()))\n", + "\n", + "print(torch.cuda.get_device_name(0))\n", + "print(torch.cuda.device_count())\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # GPU 사용 가능 여부에 따라 device 정보 저장" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 하이퍼파라미터 세팅 및 seed 고정" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:06:59.171980Z", + "start_time": "2021-04-22T11:06:59.167952Z" + } + }, + "outputs": [], + "source": [ + "batch_size = 16 # Mini-batch size\n", + "num_epochs = 20\n", + "learning_rate = 0.0001" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:06:59.446510Z", + "start_time": "2021-04-22T11:06:59.443508Z" + } + }, + "outputs": [], + "source": [ + "# seed 고정\n", + "random_seed = 21\n", + "torch.manual_seed(random_seed)\n", + "torch.cuda.manual_seed(random_seed)\n", + "# torch.cuda.manual_seed_all(random_seed) # if use multi-GPU\n", + "torch.backends.cudnn.deterministic = True\n", + "torch.backends.cudnn.benchmark = False\n", + "np.random.seed(random_seed)\n", + "random.seed(random_seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 학습 데이터 EDA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:04.139668Z", + "start_time": "2021-04-22T11:07:00.575728Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "dataset_path = '../input/data'\n", + "anns_file_path = dataset_path + '/' + 'train.json'\n", + "\n", + "# Read annotations\n", + "with open(anns_file_path, 'r') as f:\n", + " dataset = json.loads(f.read())\n", + "\n", + "categories = dataset['categories']\n", + "anns = dataset['annotations']\n", + "imgs = dataset['images']\n", + "nr_cats = len(categories)\n", + "nr_annotations = len(anns)\n", + "nr_images = len(imgs)\n", + "\n", + "# Load categories and super categories\n", + "cat_names = []\n", + "super_cat_names = []\n", + "super_cat_ids = {}\n", + "super_cat_last_name = ''\n", + "nr_super_cats = 0\n", + "for cat_it in categories:\n", + " cat_names.append(cat_it['name'])\n", + " super_cat_name = cat_it['supercategory']\n", + " # Adding new supercat\n", + " if super_cat_name != super_cat_last_name:\n", + " super_cat_names.append(super_cat_name)\n", + " super_cat_ids[super_cat_name] = nr_super_cats\n", + " super_cat_last_name = super_cat_name\n", + " nr_super_cats += 1\n", + "\n", + "print('Number of super categories:', nr_super_cats)\n", + "print('Number of categories:', nr_cats)\n", + "print('Number of annotations:', nr_annotations)\n", + "print('Number of images:', nr_images)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:04.394832Z", + "start_time": "2021-04-22T11:07:04.141668Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "# Count annotations\n", + "cat_histogram = np.zeros(nr_cats,dtype=int)\n", + "for ann in anns:\n", + " cat_histogram[ann['category_id']] += 1\n", + "\n", + "# Initialize the matplotlib figure\n", + "f, ax = plt.subplots(figsize=(5,5))\n", + "\n", + "# Convert to DataFrame\n", + "df = pd.DataFrame({'Categories': cat_names, 'Number of annotations': cat_histogram})\n", + "df = df.sort_values('Number of annotations', 0, False)\n", + "\n", + "# Plot the histogram\n", + "plt.title(\"category distribution of train set \")\n", + "plot_1 = sns.barplot(x=\"Number of annotations\", y=\"Categories\", data=df, label=\"Total\", color=\"b\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:04.409808Z", + "start_time": "2021-04-22T11:07:04.395831Z" + } + }, + "outputs": [], + "source": [ + "# category labeling \n", + "sorted_temp_df = df.sort_index()\n", + "\n", + "# background = 0 에 해당되는 label 추가 후 기존들을 모두 label + 1 로 설정\n", + "sorted_df = pd.DataFrame([\"Backgroud\"], columns = [\"Categories\"])\n", + "sorted_df = sorted_df.append(sorted_temp_df, ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:04.424832Z", + "start_time": "2021-04-22T11:07:04.411802Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "# class (Categories) 에 따른 index 확인 (0~11 : 총 12개)\n", + "sorted_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 데이터 전처리 함수 정의 (Dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:04.439837Z", + "start_time": "2021-04-22T11:07:04.425804Z" + } + }, + "outputs": [], + "source": [ + "category_names = list(sorted_df.Categories)\n", + "\n", + "def get_classname(classID, cats):\n", + " for i in range(len(cats)):\n", + " if cats[i]['id']==classID:\n", + " return cats[i]['name']\n", + " return \"None\"\n", + "\n", + "class CustomDataLoader(Dataset):\n", + " \"\"\"COCO format\"\"\"\n", + " def __init__(self, data_dir, mode = 'train', transform = None):\n", + " super().__init__()\n", + " self.mode = mode\n", + " self.transform = transform\n", + " self.coco = COCO(data_dir)\n", + " \n", + " def __getitem__(self, index: int):\n", + " # dataset이 index되어 list처럼 동작\n", + " image_id = self.coco.getImgIds(imgIds=index)\n", + " image_infos = self.coco.loadImgs(image_id)[0]\n", + " \n", + " # cv2 를 활용하여 image 불러오기\n", + " images = cv2.imread(os.path.join(dataset_path, image_infos['file_name']))\n", + " images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)\n", + " images /= 255.0\n", + " \n", + " if (self.mode in ('train', 'val')):\n", + " ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])\n", + " anns = self.coco.loadAnns(ann_ids)\n", + "\n", + " # Load the categories in a variable\n", + " cat_ids = self.coco.getCatIds()\n", + " cats = self.coco.loadCats(cat_ids)\n", + "\n", + " # masks : size가 (height x width)인 2D\n", + " # 각각의 pixel 값에는 \"category id + 1\" 할당\n", + " # Background = 0\n", + " masks = np.zeros((image_infos[\"height\"], image_infos[\"width\"]))\n", + " # Unknown = 1, General trash = 2, ... , Cigarette = 11\n", + " for i in range(len(anns)):\n", + " className = get_classname(anns[i]['category_id'], cats)\n", + " pixel_value = category_names.index(className)\n", + " masks = np.maximum(self.coco.annToMask(anns[i])*pixel_value, masks)\n", + " masks = masks.astype(np.float32)\n", + "\n", + " # transform -> albumentations 라이브러리 활용\n", + " if self.transform is not None:\n", + " transformed = self.transform(image=images, mask=masks)\n", + " images = transformed[\"image\"]\n", + " masks = transformed[\"mask\"]\n", + " \n", + " return images, masks, image_infos\n", + " \n", + " if self.mode == 'test':\n", + " # transform -> albumentations 라이브러리 활용\n", + " if self.transform is not None:\n", + " transformed = self.transform(image=images)\n", + " images = transformed[\"image\"]\n", + " \n", + " return images, image_infos\n", + " \n", + " \n", + " def __len__(self) -> int:\n", + " # 전체 dataset의 size를 return\n", + " return len(self.coco.getImgIds())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset 정의 및 DataLoader 할당" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:09.179806Z", + "start_time": "2021-04-22T11:07:04.440804Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "# train.json / validation.json / test.json 디렉토리 설정\n", + "train_path = dataset_path + '/train.json'\n", + "val_path = dataset_path + '/val.json'\n", + "test_path = dataset_path + '/test.json'\n", + "\n", + "# collate_fn needs for batch\n", + "def collate_fn(batch):\n", + " return tuple(zip(*batch))\n", + "\n", + "train_transform = A.Compose([\n", + " ToTensorV2()\n", + " ])\n", + "\n", + "val_transform = A.Compose([\n", + " ToTensorV2()\n", + " ])\n", + "\n", + "test_transform = A.Compose([\n", + " ToTensorV2()\n", + " ])\n", + "\n", + "# create own Dataset 1 (skip)\n", + "# validation set을 직접 나누고 싶은 경우\n", + "# random_split 사용하여 data set을 8:2 로 분할\n", + "# train_size = int(0.8*len(dataset))\n", + "# val_size = int(len(dataset)-train_size)\n", + "# dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=transform)\n", + "# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])\n", + "\n", + "# create own Dataset 2\n", + "# train dataset\n", + "train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)\n", + "\n", + "# validation dataset\n", + "val_dataset = CustomDataLoader(data_dir=val_path, mode='val', transform=val_transform)\n", + "\n", + "# test dataset\n", + "test_dataset = CustomDataLoader(data_dir=test_path, mode='test', transform=test_transform)\n", + "\n", + "\n", + "# DataLoader\n", + "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, \n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " num_workers=4,\n", + " collate_fn=collate_fn)\n", + "\n", + "val_loader = torch.utils.data.DataLoader(dataset=val_dataset, \n", + " batch_size=batch_size,\n", + " shuffle=False,\n", + " num_workers=4,\n", + " collate_fn=collate_fn)\n", + "\n", + "test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n", + " batch_size=batch_size,\n", + " num_workers=4,\n", + " collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 데이터 샘플 시각화 (Show example image and mask)\n", + "\n", + "- `train_loader` \n", + "- `val_loader` \n", + "- `test_loader` " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:09.779805Z", + "start_time": "2021-04-22T11:07:09.181803Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "# train_loader의 output 결과(image 및 mask) 확인\n", + "for imgs, masks, image_infos in train_loader:\n", + " image_infos = image_infos[0]\n", + " temp_images = imgs\n", + " temp_masks = masks\n", + " \n", + " break\n", + "\n", + "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))\n", + "\n", + "print('image shape:', list(temp_images[0].shape))\n", + "print('mask shape: ', list(temp_masks[0].shape))\n", + "print('Unique values, category of transformed mask : \\n', [{int(i),category_names[int(i)]} for i in list(np.unique(temp_masks[0]))])\n", + "\n", + "ax1.imshow(temp_images[0].permute([1,2,0]))\n", + "ax1.grid(False)\n", + "ax1.set_title(\"input image : {}\".format(image_infos['file_name']), fontsize = 15)\n", + "\n", + "ax2.imshow(temp_masks[0])\n", + "ax2.grid(False)\n", + "ax2.set_title(\"masks : {}\".format(image_infos['file_name']), fontsize = 15)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:10.469862Z", + "start_time": "2021-04-22T11:07:09.780831Z" + } + }, + "outputs": [], + "source": [ + "# val_loader의 output 결과(image 및 mask) 확인\n", + "for imgs, masks, image_infos in val_loader:\n", + " image_infos = image_infos[0]\n", + " temp_images = imgs\n", + " temp_masks = masks\n", + " \n", + " break\n", + "\n", + "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))\n", + "\n", + "print('image shape:', list(temp_images[0].shape))\n", + "print('mask shape: ', list(temp_masks[0].shape))\n", + "\n", + "print('Unique values, category of transformed mask : \\n', [{int(i),category_names[int(i)]} for i in list(np.unique(temp_masks[0]))])\n", + "\n", + "ax1.imshow(temp_images[0].permute([1,2,0]))\n", + "ax1.grid(False)\n", + "ax1.set_title(\"input image : {}\".format(image_infos['file_name']), fontsize = 15)\n", + "\n", + "ax2.imshow(temp_masks[0])\n", + "ax2.grid(False)\n", + "ax2.set_title(\"masks : {}\".format(image_infos['file_name']), fontsize = 15)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:07:10.772294Z", + "start_time": "2021-04-22T11:07:10.470862Z" + } + }, + "outputs": [], + "source": [ + "# test_loader의 output 결과(image) 확인\n", + "for imgs, image_infos in test_loader:\n", + " image_infos = image_infos[0]\n", + " temp_images = imgs\n", + " \n", + " break\n", + "\n", + "fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))\n", + "\n", + "print('image shape:', list(temp_images[0].shape))\n", + "\n", + "ax1.imshow(temp_images[0].permute([1,2,0]))\n", + "ax1.grid(False)\n", + "ax1.set_title(\"input image : {}\".format(image_infos['file_name']), fontsize = 15)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## baseline model\n", + "\n", + "### FCN8s (VGG imageNet weight)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:15:29.119807Z", + "start_time": "2021-04-22T11:15:29.109808Z" + } + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torchvision import models\n", + "from torchvision.models import vgg16\n", + "\n", + "class FCN8s(nn.Module):\n", + " def __init__(self, num_classes):\n", + " super(FCN8s,self).__init__()\n", + " self.pretrained_model = vgg16(pretrained = True)\n", + " features, classifiers = list(self.pretrained_model.features.children()), list(self.pretrained_model.classifier.children())\n", + "\n", + " self.features_map1 = nn.Sequential(*features[0:17])\n", + " self.features_map2 = nn.Sequential(*features[17:24])\n", + " self.features_map3 = nn.Sequential(*features[24:31])\n", + " \n", + " # Score pool3\n", + " self.score_pool3_fr = nn.Conv2d(256, num_classes, 1)\n", + " \n", + " # Score pool4 \n", + " self.score_pool4_fr = nn.Conv2d(512, num_classes, 1) \n", + " \n", + " # fc6 ~ fc7\n", + " self.conv = nn.Sequential(nn.Conv2d(512, 4096, kernel_size = 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(),\n", + " nn.Conv2d(4096, 4096, kernel_size = 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout()\n", + " )\n", + " \n", + " # Score\n", + " self.score_fr = nn.Conv2d(4096, num_classes, kernel_size = 1)\n", + " \n", + " # UpScore2 using deconv\n", + " self.upscore2 = nn.ConvTranspose2d(num_classes,\n", + " num_classes,\n", + " kernel_size=4,\n", + " stride=2,\n", + " padding=1)\n", + " \n", + " # UpScore2_pool4 using deconv\n", + " self.upscore2_pool4 = nn.ConvTranspose2d(num_classes, \n", + " num_classes, \n", + " kernel_size=4,\n", + " stride=2,\n", + " padding=1)\n", + " \n", + " # UpScore8 using deconv\n", + " self.upscore8 = nn.ConvTranspose2d(num_classes, \n", + " num_classes,\n", + " kernel_size=16,\n", + " stride=8,\n", + " padding=4)\n", + " \n", + " def forward(self, x):\n", + " pool3 = h = self.features_map1(x)\n", + " pool4 = h = self.features_map2(h)\n", + " h = self.features_map3(h)\n", + " \n", + " h = self.conv(h)\n", + " h = self.score_fr(h)\n", + " \n", + " score_pool3c = self.score_pool3_fr(pool3) \n", + " score_pool4c = self.score_pool4_fr(pool4)\n", + " \n", + " # Up Score I\n", + " upscore2 = self.upscore2(h)\n", + " \n", + " # Sum I\n", + " h = upscore2 + score_pool4c\n", + " \n", + " # Up Score II\n", + " upscore2_pool4c = self.upscore2_pool4(h)\n", + " \n", + " # Sum II\n", + " h = upscore2_pool4c + score_pool3c\n", + " \n", + " # Up Score III\n", + " upscore8 = self.upscore8(h)\n", + " \n", + " return upscore8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:15:34.624277Z", + "start_time": "2021-04-22T11:15:30.068347Z" + } + }, + "outputs": [], + "source": [ + "# 구현된 model에 임의의 input을 넣어 output이 잘 나오는지 test\n", + "\n", + "model = FCN8s(num_classes=12)\n", + "x = torch.randn([1, 3, 512, 512])\n", + "print(\"input shape : \", x.shape)\n", + "out = model(x).to(device)\n", + "print(\"output shape : \", out.size())\n", + "\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## train, validation, test 함수 정의" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:15:38.201874Z", + "start_time": "2021-04-22T11:15:38.187884Z" + } + }, + "outputs": [], + "source": [ + "def train(num_epochs, model, data_loader, val_loader, criterion, optimizer, saved_dir, val_every, device):\n", + " print('Start training..')\n", + " best_loss = 9999999\n", + " for epoch in range(num_epochs):\n", + " model.train()\n", + " for step, (images, masks, _) in enumerate(data_loader):\n", + " images = torch.stack(images) # (batch, channel, height, width)\n", + " masks = torch.stack(masks).long() # (batch, channel, height, width)\n", + " \n", + " # gpu 연산을 위해 device 할당\n", + " images, masks = images.to(device), masks.to(device)\n", + " \n", + " # inference\n", + " outputs = model(images)\n", + " \n", + " # loss 계산 (cross entropy loss)\n", + " loss = criterion(outputs, masks)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " # step 주기에 따른 loss 출력\n", + " if (step + 1) % 25 == 0:\n", + " print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(\n", + " epoch+1, num_epochs, step+1, len(train_loader), loss.item()))\n", + " \n", + " # validation 주기에 따른 loss 출력 및 best model 저장\n", + " if (epoch + 1) % val_every == 0:\n", + " avrg_loss = validation(epoch + 1, model, val_loader, criterion, device)\n", + " if avrg_loss < best_loss:\n", + " print('Best performance at epoch: {}'.format(epoch + 1))\n", + " print('Save model in', saved_dir)\n", + " best_loss = avrg_loss\n", + " save_model(model, saved_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:15:38.901226Z", + "start_time": "2021-04-22T11:15:38.888195Z" + } + }, + "outputs": [], + "source": [ + "def validation(epoch, model, data_loader, criterion, device):\n", + " print('Start validation #{}'.format(epoch))\n", + " model.eval()\n", + " with torch.no_grad():\n", + " total_loss = 0\n", + " cnt = 0\n", + " mIoU_list = []\n", + " for step, (images, masks, _) in enumerate(data_loader):\n", + " \n", + " images = torch.stack(images) # (batch, channel, height, width)\n", + " masks = torch.stack(masks).long() # (batch, channel, height, width)\n", + "\n", + " images, masks = images.to(device), masks.to(device) \n", + "\n", + " outputs = model(images)\n", + " loss = criterion(outputs, masks)\n", + " total_loss += loss\n", + " cnt += 1\n", + " \n", + " outputs = torch.argmax(outputs.squeeze(), dim=1).detach().cpu().numpy()\n", + "\n", + " mIoU = label_accuracy_score(masks.detach().cpu().numpy(), outputs, n_class=12)[2]\n", + " mIoU_list.append(mIoU)\n", + " \n", + " avrg_loss = total_loss / cnt\n", + " print('Validation #{} Average Loss: {:.4f}, mIoU: {:.4f}'.format(epoch, avrg_loss, np.mean(mIoU_list)))\n", + "\n", + " return avrg_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 모델 저장 함수 정의" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:15:41.634492Z", + "start_time": "2021-04-22T11:15:41.627493Z" + } + }, + "outputs": [], + "source": [ + "# 모델 저장 함수 정의\n", + "val_every = 1 \n", + "\n", + "saved_dir = './saved'\n", + "if not os.path.isdir(saved_dir): \n", + " os.mkdir(saved_dir)\n", + " \n", + "def save_model(model, saved_dir, file_name='fcn8s_best_model(pretrained).pt'):\n", + " check_point = {'net': model.state_dict()}\n", + " output_path = os.path.join(saved_dir, file_name)\n", + " torch.save(model.state_dict(), output_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 모델 생성 및 Loss function, Optimizer 정의" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-22T11:15:43.106368Z", + "start_time": "2021-04-22T11:15:43.096368Z" + } + }, + "outputs": [], + "source": [ + "# Loss function 정의\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "# Optimizer 정의\n", + "optimizer = torch.optim.Adam(params = model.parameters(), lr = learning_rate, weight_decay=1e-6)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "start_time": "2021-04-22T11:15:43.700Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "train(num_epochs, model, train_loader, val_loader, criterion, optimizer, saved_dir, val_every, device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 저장된 model 불러오기 (학습된 이후) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-16T19:44:21.050200Z", + "start_time": "2021-04-16T19:44:20.802200Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "# best model 저장된 경로\n", + "model_path = './saved/fcn8s_best_model(pretrained).pt'\n", + "\n", + "# best model 불러오기\n", + "checkpoint = torch.load(model_path, map_location=device)\n", + "model.load_state_dict(checkpoint)\n", + "\n", + "# 추론을 실행하기 전에는 반드시 설정 (batch normalization, dropout 를 평가 모드로 설정)\n", + "# model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-16T19:44:24.939227Z", + "start_time": "2021-04-16T19:44:24.518228Z" + } + }, + "outputs": [], + "source": [ + "# 첫번째 batch의 추론 결과 확인\n", + "for imgs, image_infos in test_loader:\n", + " image_infos = image_infos\n", + " temp_images = imgs\n", + " \n", + " model.eval()\n", + " # inference\n", + " outs = model(torch.stack(temp_images).to(device))\n", + " oms = torch.argmax(outs.squeeze(), dim=1).detach().cpu().numpy()\n", + " \n", + " break\n", + "\n", + "i = 3\n", + "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 16))\n", + "\n", + "print('Shape of Original Image :', list(temp_images[i].shape))\n", + "print('Shape of Predicted : ', list(oms[i].shape))\n", + "print('Unique values, category of transformed mask : \\n', [{int(i),category_names[int(i)]} for i in list(np.unique(oms[i]))])\n", + "\n", + "# Original image\n", + "ax1.imshow(temp_images[i].permute([1,2,0]))\n", + "ax1.grid(False)\n", + "ax1.set_title(\"Original image : {}\".format(image_infos[i]['file_name']), fontsize = 15)\n", + "\n", + "# Predicted\n", + "ax2.imshow(oms[i])\n", + "ax2.grid(False)\n", + "ax2.set_title(\"Predicted : {}\".format(image_infos[i]['file_name']), fontsize = 15)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## submission을 위한 test 함수 정의" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-16T19:44:27.469285Z", + "start_time": "2021-04-16T19:44:27.456021Z" + } + }, + "outputs": [], + "source": [ + "def test(model, data_loader, device):\n", + " size = 256\n", + " transform = A.Compose([A.Resize(256, 256)])\n", + " print('Start prediction.')\n", + " model.eval()\n", + " \n", + " file_name_list = []\n", + " preds_array = np.empty((0, size*size), dtype=np.long)\n", + " \n", + " with torch.no_grad():\n", + " for step, (imgs, image_infos) in enumerate(test_loader):\n", + "\n", + " # inference (512 x 512)\n", + " outs = model(torch.stack(imgs).to(device))\n", + " oms = torch.argmax(outs.squeeze(), dim=1).detach().cpu().numpy()\n", + " \n", + " # resize (256 x 256)\n", + " temp_mask = []\n", + " for img, mask in zip(np.stack(temp_images), oms):\n", + " transformed = transform(image=img, mask=mask)\n", + " mask = transformed['mask']\n", + " temp_mask.append(mask)\n", + "\n", + " oms = np.array(temp_mask)\n", + " \n", + " oms = oms.reshape([oms.shape[0], size*size]).astype(int)\n", + " preds_array = np.vstack((preds_array, oms))\n", + " \n", + " file_name_list.append([i['file_name'] for i in image_infos])\n", + " print(\"End prediction.\")\n", + " file_names = [y for x in file_name_list for y in x]\n", + " \n", + " return file_names, preds_array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## submission.csv 생성" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-16T19:45:42.235310Z", + "start_time": "2021-04-16T19:44:30.499016Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "# sample_submisson.csv 열기\n", + "submission = pd.read_csv('./submission/sample_submission.csv', index_col=None)\n", + "\n", + "# test set에 대한 prediction\n", + "file_names, preds = test(model, test_loader, device)\n", + "\n", + "# PredictionString 대입\n", + "for file_name, string in zip(file_names, preds):\n", + " submission = submission.append({\"image_id\" : file_name, \"PredictionString\" : ' '.join(str(e) for e in string.tolist())}, \n", + " ignore_index=True)\n", + "\n", + "# submission.csv로 저장\n", + "submission.to_csv(\"./submission/Baseline_FCN8s(pretrained).csv\", index=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reference\n", + "\n" + ] + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "297.278px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Segmentation_Baseline_Code/requirements.txt b/Segmentation_Baseline_Code/requirements.txt new file mode 100644 index 0000000..c23a199 --- /dev/null +++ b/Segmentation_Baseline_Code/requirements.txt @@ -0,0 +1,134 @@ +albumentations==0.5.2 +anyio==2.2.0 +argon2-cffi==20.1.0 +async-generator==1.10 +attrs==20.3.0 +Babel==2.9.0 +backcall==0.2.0 +bcrypt==3.2.0 +beautifulsoup4==4.9.1 +bleach==3.3.0 +boto3==1.17.56 +botocore==1.20.56 +certifi==2020.12.5 +cffi==1.14.0 +chardet==3.0.4 +conda==4.8.3 +conda-build==3.18.11 +conda-package-handling==1.7.0 +cryptography==2.9.2 +cycler==0.10.0 +Cython==0.29.23 +decorator==4.4.2 +defusedxml==0.7.1 +deprecation==2.1.0 +efficientnet-pytorch==0.6.3 +entrypoints==0.3 +filelock==3.0.12 +gevent==21.1.2 +glob2==0.7 +greenlet==1.0.0 +idna==2.9 +imageio==2.9.0 +imgaug==0.4.0 +importlib-metadata==4.0.1 +inotify-simple==1.2.1 +ipykernel==5.5.3 +ipython==7.16.1 +ipython-genutils==0.2.0 +ipywidgets==7.6.3 +jedi==0.17.1 +Jinja2==2.11.2 +jmespath==0.10.0 +json5==0.9.4 +jsonschema==3.2.0 +jupyter-client==6.1.12 +jupyter-core==4.7.1 +jupyter-packaging==0.9.2 +jupyter-server==1.6.4 +jupyterlab==3.0.14 +jupyterlab-pygments==0.1.2 +jupyterlab-server==2.4.0 +jupyterlab-widgets==1.0.0 +kiwisolver==1.3.1 +libarchive-c==2.9 +MarkupSafe==1.1.1 +matplotlib==3.0.2 +mistune==0.8.4 +mkl-fft==1.1.0 +mkl-random==1.1.1 +mkl-service==2.3.0 +munch==2.5.0 +nbclassic==0.2.7 +nbclient==0.5.3 +nbconvert==6.0.7 +nbformat==5.1.3 +nest-asyncio==1.5.1 +networkx==2.5.1 +notebook==6.3.0 +numpy==1.19.5 +olefile==0.46 +opencv-python==4.2.0.34 +opencv-python-headless==4.5.1.48 +packaging==20.9 +pandas==1.1.1 +pandocfilters==1.4.3 +paramiko==2.7.2 +parso==0.7.0 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==7.2.0 +pkginfo==1.5.0.1 +pretrainedmodels==0.7.4 +prometheus-client==0.10.1 +prompt-toolkit==3.0.5 +protobuf==3.15.8 +psutil==5.7.0 +ptyprocess==0.6.0 +pycocotools==2.0.0 +pycosat==0.6.3 +pycparser==2.20 +Pygments==2.6.1 +PyNaCl==1.4.0 +pyOpenSSL==19.1.0 +pyparsing==2.4.7 +pyrsistent==0.17.3 +PySocks==1.7.1 +python-dateutil==2.8.1 +pytz==2020.1 +PyWavelets==1.1.1 +PyYAML==5.3.1 +pyzmq==22.0.3 +requests==2.23.0 +retrying==1.3.3 +ruamel-yaml==0.15.87 +s3transfer==0.4.2 +sagemaker-training==3.9.1 +scikit-image==0.18.1 +scipy==1.6.2 +seaborn==0.9.0 +segmentation-models-pytorch==0.1.3 +Send2Trash==1.5.0 +Shapely==1.7.1 +six==1.14.0 +sniffio==1.2.0 +soupsieve==2.0.1 +terminado==0.9.4 +testpath==0.4.4 +tifffile==2021.4.8 +timm==0.3.2 +tomlkit==0.7.0 +torch==1.4.0 +torchvision==0.5.0 +tornado==6.1 +tqdm==4.60.0 +traitlets==4.3.3 +typing-extensions==3.7.4.3 +urllib3==1.25.8 +wcwidth==0.2.5 +webencodings==0.5.1 +Werkzeug==1.0.1 +widgetsnbextension==3.5.1 +zipp==3.4.1 +zope.event==4.5.0 +zope.interface==5.4.0 \ No newline at end of file diff --git a/Segmentation_Baseline_Code/utils.py b/Segmentation_Baseline_Code/utils.py new file mode 100644 index 0000000..12b37b3 --- /dev/null +++ b/Segmentation_Baseline_Code/utils.py @@ -0,0 +1,34 @@ +# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py +import numpy as np + + +def _fast_hist(label_true, label_pred, n_class): + mask = (label_true >= 0) & (label_true < n_class) + hist = np.bincount( + n_class * label_true[mask].astype(int) + + label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) + return hist + + +def label_accuracy_score(label_trues, label_preds, n_class): + """Returns accuracy score evaluation result. + - overall accuracy + - mean accuracy + - mean IU + - fwavacc + """ + hist = np.zeros((n_class, n_class)) + for lt, lp in zip(label_trues, label_preds): + hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) + acc = np.diag(hist).sum() / hist.sum() + with np.errstate(divide='ignore', invalid='ignore'): + acc_cls = np.diag(hist) / hist.sum(axis=1) + acc_cls = np.nanmean(acc_cls) + with np.errstate(divide='ignore', invalid='ignore'): + iu = np.diag(hist) / ( + hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + ) + mean_iu = np.nanmean(iu) + freq = hist.sum(axis=1) / hist.sum() + fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() + return acc, acc_cls, mean_iu, fwavacc \ No newline at end of file