-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataset_dataloader.py
102 lines (84 loc) · 2.8 KB
/
dataset_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import numpy as np
import pandas as pd
import cv2
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensor, ToTensorV2
from albumentations import (HorizontalFlip,
VerticalFlip,
Normalize,
Compose)
class LungsDataset(Dataset):
    """Dataset of image / binary-mask pairs listed in a DataFrame.

    Every row of ``df`` names one image file (column ``"ImageId"``) and one
    mask file (column ``"MaskId"``); the names are joined onto ``imgs_dir``
    and ``masks_dir`` respectively.
    """

    def __init__(self,
                 imgs_dir: str,
                 masks_dir: str,
                 df: pd.DataFrame,
                 phase: str):
        """Store the data locations and build the augmentation pipeline.

        Args:
            imgs_dir: Directory that holds the image files.
            masks_dir: Directory that holds the mask files.
            df: Table with ``"ImageId"`` and ``"MaskId"`` columns.
            phase: Forwarded to ``get_augmentations`` to pick the pipeline.
        """
        self.root_imgs_dir = imgs_dir
        self.root_masks_dir = masks_dir
        self.df = df
        self.augmentations = get_augmentations(phase)

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    def _sample_paths(self, idx):
        """Return ``(image_path, mask_path)`` for the row at position ``idx``."""
        # Positional lookup on purpose: samplers hand out 0..len-1, which only
        # matches the index *labels* when the frame has a default RangeIndex.
        img_name = self.df["ImageId"].iloc[idx]
        mask_name = self.df["MaskId"].iloc[idx]
        return (os.path.join(self.root_imgs_dir, img_name),
                os.path.join(self.root_masks_dir, mask_name))

    @staticmethod
    def _read_image(path):
        """Load ``path`` with OpenCV, raising if nothing could be read."""
        pixels = cv2.imread(path)
        # cv2.imread reports failure by returning None rather than raising,
        # so turn that into an error that names the offending file.
        if pixels is None:
            raise FileNotFoundError(
                f"Image file is missing or unreadable: {path}")
        return pixels

    def __getitem__(self, idx):
        """Load, binarize and augment the sample at position ``idx``.

        Args:
            idx: Zero-based position of the sample in ``df``.

        Returns:
            Tuple ``(img, mask)`` taken from the augmentation output; the
            mask has its last axis moved to the front via ``permute(2, 0, 1)``.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        img_path, mask_path = self._sample_paths(idx)
        # NOTE(review): cv2.imread yields BGR channel order by default, while
        # the Normalize mean/std look RGB-ordered -- confirm this is intended.
        img = self._read_image(img_path)
        mask = self._read_image(mask_path)

        mask[mask < 240] = 0  # remove artifacts
        mask[mask > 0] = 1    # whatever survived becomes foreground

        augmented = self.augmentations(image=img,
                                       mask=mask.astype(np.float32))
        img = augmented['image']
        # presumably the mask comes back channels-last; TODO confirm
        mask = augmented['mask'].permute(2, 0, 1)
        return img, mask
def get_augmentations(phase,
                      mean: tuple = (0.485, 0.456, 0.406),
                      std: tuple = (0.229, 0.224, 0.225),):
    """Build the albumentations pipeline for the given phase.

    Args:
        phase: ``"train"`` prepends a vertical flip applied with probability
            0.5; any other value yields only the shared steps.
        mean: Per-channel mean handed to ``Normalize``.
        std: Per-channel standard deviation handed to ``Normalize``.

    Returns:
        A ``Compose`` of the selected transforms, always ending with
        ``Normalize`` followed by ``ToTensorV2``.
    """
    # Random flipping is applied to training data only.
    train_only = [VerticalFlip(p=0.5)] if phase == "train" else []
    # Steps shared by every phase.
    shared = [
        Normalize(mean=mean, std=std, p=1),
        ToTensorV2(),
    ]
    return Compose(train_only + shared)
def get_dataloader(
    imgs_dir: str,
    masks_dir: str,
    path_to_csv: str,
    phase: str,
    batch_size: int = 8,
    num_workers: int = 6,
    test_size: float = 0.2,
):
    """Create a shuffled DataLoader over the train or validation split.

    The CSV is read and split with a fixed seed (``random_state=69``), so
    separate calls for different phases see the same partition.

    Args:
        imgs_dir: Directory with the image files.
        masks_dir: Directory with the mask files.
        path_to_csv: CSV file listing the samples.
        phase: ``"train"`` selects the training split; any other value
            selects the validation split. Also forwarded to ``LungsDataset``.
        batch_size: Samples per batch.
        num_workers: Worker count passed to ``DataLoader``.
        test_size: Fraction handed to ``train_test_split`` as ``test_size``.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``pin_memory=True``.
    """
    samples = pd.read_csv(path_to_csv)
    split_parts = train_test_split(samples,
                                   test_size=test_size,
                                   random_state=69)
    # Renumber both parts from zero so each has a default index.
    train_part, val_part = (part.reset_index(drop=True)
                            for part in split_parts)

    if phase == "train":
        selected = train_part
    else:
        selected = val_part

    return DataLoader(
        LungsDataset(imgs_dir, masks_dir, selected, phase),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=True,
    )