# train_segmentation.py
# Forked from isl-org/Open3D-PointNet.
from __future__ import print_function
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets import PartDataset
from pointnet import PointNetDenseCls
import torch.nn.functional as F
if torch.cuda.is_available():
import torch.backends.cudnn as cudnn
# Command-line interface: every option has a default, so the script can be
# launched without any arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--batchSize', default=32, type=int, help='input batch size')
parser.add_argument(
    '--workers', default=4, type=int, help='number of data loading workers')
parser.add_argument(
    '--nepoch', default=25, type=int, help='number of epochs to train for')
parser.add_argument(
    '--outf', default='seg', type=str, help='output folder')
parser.add_argument(
    '--model', default='', type=str, help='model path')

# Parse sys.argv and echo the resulting namespace for the run log.
opt = parser.parse_args()
print(opt)
# Draw a new seed in [1, 10000] on every run and print it, so a particular run
# can be identified from its log.
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)

# Feed the same seed to both PyTorch's and Python's random number generators.
torch.manual_seed(opt.manualSeed)
random.seed(opt.manualSeed)
# Training data: segmentation mode (classification=False), restricted to the
# 'Chair' category.
dataset = PartDataset(
    root='shapenetcore_partanno_segmentation_benchmark_v0',
    classification=False,
    class_choice=['Chair'],
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
)

# Same dataset arguments with train=False (presumably the held-out split;
# confirm against PartDataset). Shuffling is kept on for this loader as well.
test_dataset = PartDataset(
    root='shapenetcore_partanno_segmentation_benchmark_v0',
    classification=False,
    class_choice=['Chair'],
    train=False,
)
testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
)

# Report the size of both datasets and the number of segmentation classes.
print(len(dataset), len(test_dataset))
num_classes = dataset.num_seg_classes
print('classes', num_classes)
# Create the folder that will receive the per-epoch checkpoints.
try:
    os.makedirs(opt.outf)
except OSError:
    # A directory that already exists is fine. Any other failure (permission
    # denied, a regular file in the way, ...) must not be silently ignored,
    # otherwise the run would only fail later when saving the first checkpoint.
    if not os.path.isdir(opt.outf):
        raise
def blue(x):
    """Return the string *x* wrapped in ANSI escape codes.

    The text is prefixed with the SGR sequence 94 (bright blue foreground)
    and followed by the reset sequence 0, so only *x* is colored when
    printed to a terminal that understands ANSI codes.
    """
    return '\033[94m' + x + '\033[0m'
# Build the segmentation network with k = num_classes.
classifier = PointNetDenseCls(k=num_classes)

if opt.model != '':
    # Resume from a saved state dict. map_location keeps every tensor on the
    # CPU while loading, so a checkpoint written on a GPU machine can also be
    # restored on a CPU-only host; the model is moved to the GPU below.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    classifier.load_state_dict(
        torch.load(opt.model, map_location=lambda storage, loc: storage))

# Move the model to the GPU *before* creating the optimizer, as recommended
# by the torch.optim documentation.
if torch.cuda.is_available():
    classifier.cuda()

optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

# Number of batches per epoch, used only for progress output. len(dataloader)
# is an int and also counts a final partial batch, unlike
# len(dataset) / batchSize.
num_batch = len(dataloader)
def _run_batch(points, target):
    """Run the classifier on one batch and return ``(loss, accuracy)``.

    ``points`` has its axes 1 and 2 swapped before being fed to the network
    (presumably (batch, n_points, 3) -> (batch, 3, n_points); confirm against
    PartDataset / PointNetDenseCls). Predictions and labels are flattened so
    that every point is one sample for ``F.nll_loss``. Labels are shifted down
    by one before the loss is computed (presumably the dataset labels are
    1-based; confirm against PartDataset).

    The accuracy is the fraction of correctly labelled points among the
    points actually present in this batch, so it stays correct for a final
    partial batch and for any number of points per cloud.
    """
    points = points.transpose(2, 1)
    if torch.cuda.is_available():
        points, target = points.cuda(), target.cuda()
    pred, _ = classifier(points)
    pred = pred.view(-1, num_classes)
    target = target.view(-1, 1)[:, 0] - 1
    loss = F.nll_loss(pred, target)
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum()
    accuracy = correct.item() / float(target.numel())
    return loss, accuracy


for epoch in range(opt.nepoch):
    for i, data in enumerate(dataloader, 0):
        # --- one optimization step on a training batch ---
        points, target = data
        optimizer.zero_grad()
        # Re-enable training mode every step: the periodic evaluation below
        # switches the model to eval mode.
        classifier.train()
        loss, accuracy = _run_batch(points, target)
        loss.backward()
        optimizer.step()
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (
            epoch, i, num_batch, loss.item(), accuracy))

        if i % 10 == 0:
            # --- quick progress check on a single test batch ---
            # A fresh iterator is created each time; since the test loader
            # shuffles, this evaluates on a different batch each time.
            points, target = next(iter(testdataloader))
            classifier.eval()
            # No gradients are needed for evaluation; skipping the autograd
            # graph saves memory and time.
            with torch.no_grad():
                loss, accuracy = _run_batch(points, target)
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (
                epoch, i, num_batch, blue('test'), loss.item(), accuracy))

    # Save one checkpoint (state dict only) at the end of every epoch.
    torch.save(classifier.state_dict(), '%s/seg_model_%d.pth' % (opt.outf, epoch))