forked from amirgholami/PyHessian
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
executable file
·112 lines (100 loc) · 4.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#*
# @file Different utility functions
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
# This file is part of PyHessian library.
#
# PyHessian is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# PyHessian is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with PyHessian. If not, see <http://www.gnu.org/licenses/>.
#*
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
def getData(name='cifar10', train_bs=128, test_bs=1000):
    """
    Build the CIFAR-10 train and test dataloaders.

    Args:
        name: selects the training-time preprocessing.
            'cifar10' applies RandomCrop(32, padding=4) and
            RandomHorizontalFlip before ToTensor/Normalize.
            'cifar10_without_dataaugmentation' applies only
            ToTensor/Normalize.
        train_bs: batch size of the training loader (shuffled).
        test_bs: batch size of the test loader (not shuffled).

    Returns:
        A (train_loader, test_loader) tuple of torch DataLoader objects.
        The test loader always uses ToTensor/Normalize only.

    Raises:
        ValueError: if `name` is not one of the supported names.
    """
    supported = ('cifar10', 'cifar10_without_dataaugmentation')
    if name not in supported:
        # Fail early with a clear message instead of reaching the return
        # statement with the loaders never assigned.
        raise ValueError(
            'Unknown dataset name {!r}; expected one of {}'.format(
                name, supported))

    # Per-channel mean / std handed to Normalize; shared by every pipeline.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    base_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]

    # Only the plain 'cifar10' variant augments the training images.
    augment_steps = []
    if name == 'cifar10':
        augment_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    transform_train = transforms.Compose(augment_steps + base_steps)
    transform_test = transforms.Compose(list(base_steps))

    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=True)

    # download=False: the files are expected to be present already after
    # the training-set construction above.
    testset = datasets.CIFAR10(root='../data',
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=False)

    return train_loader, test_loader
def test(model, test_loader, cuda=True):
    """
    Get the test performance (top-1 accuracy) of a model.

    Args:
        model: module called as model(data); its output is reduced with
            max over dimension 1 to obtain the predicted class index.
        test_loader: iterable yielding (data, target) batches.
        cuda: if True, move every batch to the GPU before the forward
            pass.

    Returns:
        Fraction of correctly classified samples as a float, or 0.0 when
        the loader yields no samples.

    Side effects:
        Puts the model in eval mode and prints the accuracy.
    """
    model.eval()
    correct = 0
    total_num = 0
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch.
    with torch.no_grad():
        for data, target in test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # index of the highest score along the class dimension
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_num += len(data)

    # Guard against an empty loader instead of dividing by zero.
    accuracy = correct / total_num if total_num else 0.0
    print('testing_correct: ', accuracy, '\n')
    return accuracy


# This is a utility, not a test case: opt out of pytest collection, which
# would otherwise pick up any module-level callable whose name starts with
# "test" when this module is star-imported into a test file.
test.__test__ = False