Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
k1smet14 committed Apr 29, 2021
1 parent af343c4 commit b71becb
Showing 1 changed file with 105 additions and 1 deletion.
106 changes: 105 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
@@ -1 +1,105 @@
print('Hello 분리수거!')
import os
import json
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

import albumentations as A
from albumentations.pytorch import ToTensorV2

from my_utils import *
from dataloader import *
#from loss import *

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"


def collate_fn(batch):
return tuple(zip(*batch))


def test(model, data_loader, device):
size = 256
transform = A.Compose([A.Resize(256, 256)])
print('Start prediction.')
model.eval()

file_name_list = []
preds_array = np.empty((0, size*size), dtype=np.long)

with torch.no_grad():
for step, (imgs, image_infos) in tqdm(enumerate(data_loader)):

# inference (512 x 512)
outs = model(torch.stack(imgs).to(device))
oms = torch.argmax(outs, dim=1).detach().cpu().numpy()
# resize (256 x 256)
# temp_mask = []
# for img, mask in zip(np.stack(imgs), oms):
# transformed = transform(image=img, mask=mask)
# mask = transformed['mask']
# temp_mask.append(mask)

# oms = np.array(temp_mask)

oms = oms.reshape([oms.shape[0], size*size]).astype(int)
preds_array = np.vstack((preds_array, oms))

file_name_list.append([i['file_name'] for i in image_infos])
print(f"step:{step+1:3d}/{len(data_loader)}")
print("End prediction.")
file_names = [y for x in file_name_list for y in x]

return file_names, preds_array


def main():
dataset_path = '../input/data'
test_path = dataset_path + '/test.json'

test_transform = A.Compose([
A.Resize(256, 256),
ToTensorV2()
])

test_dataset = CustomDataLoader(data_dir=test_path, mode='test', transform=test_transform)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=16,
num_workers=4,
collate_fn=collate_fn)

model = smp.DeepLabV3Plus(
encoder_name='resnext50_32x4d',
encoder_weights='ssl',
classes=12
)
load_model(model, device, saved_dir="models", file_name="deeplabv3plus_resnext50_32x4d.pt")
model.to(device)
#load_model(model, device, saved_dir, file_name)
# sample_submisson.csv 열기

submission = pd.read_csv('./submission/sample_submission.csv', index_col=None)

# test set에 대한 prediction
file_names, preds = test(model, test_loader, device)

# PredictionString 대입
for file_name, string in zip(file_names, preds):
submission = submission.append({"image_id" : file_name, "PredictionString" : ' '.join(str(e) for e in string.tolist())},
ignore_index=True)

# submission.csv로 저장
submission.to_csv("./submission/deeplabv3plus_resnext50_32x4d.csv", index=False)

if __name__ == '__main__':
main()

0 comments on commit b71becb

Please sign in to comment.