-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
70 lines (56 loc) · 1.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging
import yaml
import numpy as np
import matplotlib.pyplot as plt
import torchvision as tv
import torch
from models import RockPaperScissorsClassifier
def init_logger():
    """
    Configure the root logger for the application.

    Records at INFO level and above are sent both to ``app.log`` (opened in
    ``"w"`` mode, so it is truncated each time this runs) and to the console
    via a ``StreamHandler``.

    :return: None
    """
    log_format = "%(asctime)s [%(levelname)-5.5s] %(message)s"
    file_handler = logging.FileHandler("app.log", mode="w")
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[file_handler, console_handler],
    )
def load_config(path):
    """
    Load a YAML configuration file (e.g. task_2_table.yaml).

    The document is parsed with PyYAML's safe loader, so only plain YAML
    types (mappings, sequences, scalars) are constructed.

    :param path: str - path to the YAML file
    :return: the parsed YAML document (None if the file is empty)
    """
    # The context manager closes the file handle even if parsing raises.
    with open(path, "r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
def load_model(path, map_location=None):
    """
    Build a RockPaperScissorsClassifier, load saved weights into it and
    switch it to evaluation mode.

    :param path: str - path to a file produced by ``torch.save``; assumed to
        contain a state_dict for RockPaperScissorsClassifier (TODO confirm).
    :param map_location: optional device remapping forwarded to
        ``torch.load`` (e.g. ``"cpu"`` or ``get_device()``), which allows a
        checkpoint saved on a GPU to be loaded on a CPU-only machine.
        ``None`` (the default) keeps the storage locations recorded in the file.
    :return: RockPaperScissorsClassifier in eval mode
    """
    logging.info("Loading model from %s", path)
    model = RockPaperScissorsClassifier()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True where the torch version supports it).
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    # eval() switches layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    return model
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Display a normalized image tensor with matplotlib.

    :param inp: torch.Tensor laid out as (C, H, W); the (1, 2, 0) transpose
        below requires exactly three dimensions. It is assumed to have been
        normalized as ``(x - mean) / std``.
    :param title: optional plot title.
    :param mean: per-channel mean to add back. The defaults match the
        ImageNet statistics commonly used with torchvision; confirm they
        match the dataset transforms.
    :param std: per-channel std to multiply back (same caveat as ``mean``).
    :return: None
    """
    # detach() and cpu() let tensors that require grad or live on a GPU be
    # converted; Tensor.numpy() raises for both.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo the normalization, then clamp to the [0, 1] range imshow expects
    # for float images.
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
def preview_images(dataloader, class_names):
    """
    Show the first batch produced by ``dataloader`` as a single image grid.

    :param dataloader: iterable yielding ``(inputs, labels)`` batches
    :param class_names: not used by this function at present; kept for
        interface compatibility
    :return: None
    """
    batch_iter = iter(dataloader)
    images, _labels = next(batch_iter)
    # Tile the whole batch into one image so it can be shown in one figure.
    grid = tv.utils.make_grid(images)
    imshow(grid, title="Examples")
def get_device():
    """
    Return the CUDA device when CUDA is available, otherwise the CPU device.

    :return: torch.device
    """
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_type)