-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset_AID.py
71 lines (56 loc) · 2.41 KB
/
dataset_AID.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding:utf-8 -*-
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
import torch
class CarDateSet(data.Dataset):
    """Image-classification dataset driven by a tab-separated list file.

    Every non-blank line of ``lists`` is split on tab characters. Field 1 is
    an image path relative to ``root`` (e.g. ``irrigated land_1.tif``) and
    field 2 is an integer class label. Field 0 and any further fields are
    ignored.

    Args:
        root: Directory joined (via ``os.path.join``) in front of every image
            path read from the list file.
        lists: Path to the tab-separated list file.
        transforms: Callable applied to each RGB ``PIL.Image``. When ``None``,
            images are resized to 448x448 and converted with ``ToTensor``.
        train: Accepted for interface compatibility; not used by this class.
        test: Stored as ``self.test``; not otherwise used by this class.
    """

    def __init__(self, root, lists, transforms=None, train=True, test=False):
        self.test = test
        self.imgs, self.labels = self._read_list(root, lists)
        if transforms is None:
            # Resize to a fixed (h, w) = (448, 448), which does not preserve
            # the aspect ratio, then convert the PIL image to a float tensor
            # scaled to [0, 1]. No mean/std normalization is applied.
            self.transforms = T.Compose([
                T.Resize((448, 448)),
                T.ToTensor(),
            ])
        else:
            self.transforms = transforms

    @staticmethod
    def _read_list(root, lists):
        """Parse the list file into parallel lists of image paths and int labels."""
        imgs = []
        labels = []
        with open(lists, 'r') as f:
            for line in f:
                # Remove only the line terminator so field positions are
                # unchanged, and skip blank lines (e.g. a trailing empty line),
                # which previously raised IndexError.
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue
                fields = line.split('\t')
                imgs.append(os.path.join(root, fields[1]))
                labels.append(int(fields[2]))
        return imgs, labels

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``.

        The image is opened, converted to RGB and passed through
        ``self.transforms``. An out-of-range index raises ``IndexError`` from
        the underlying list lookup.
        """
        img_path = self.imgs[index]
        label = self.labels[index]
        # convert() returns a new, loaded copy, so the source file can be
        # closed right away instead of leaking one handle per sample.
        with Image.open(img_path) as src:
            rgb = src.convert('RGB')
        return self.transforms(rgb), label

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.imgs)
if __name__ == '__main__':
    # Smoke test: load every listed sample and print its tensor size, mean
    # value and label. Paths may be overridden on the command line; the
    # defaults are the original hard-coded locations.
    import argparse

    parser = argparse.ArgumentParser(
        description='Load every sample of a CarDateSet and print per-sample stats.')
    parser.add_argument('root', nargs='?', default='G:\\LULC\\PytorchLULC\\Shenzhen56\\',
                        help='directory prepended to each image path in the list file')
    parser.add_argument('lists', nargs='?', default='./data/train.txt',
                        help='tab-separated list file of samples')
    args = parser.parse_args()

    dataset = CarDateSet(args.root, args.lists)
    # The Dataset defines no __iter__, so this for-loop uses the
    # __getitem__ protocol: it stops when an index raises IndexError.
    for img, label in dataset:
        print(img.size(), img.float().mean(), label)