forked from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference_example.py
49 lines (37 loc) · 1.35 KB
/
inference_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from PIL import Image
from torchvision import transforms
from src.models.mnist_model import MNISTLitModel
def predict(ckpt_path: str = "last.ckpt", image_path: str = "data/example_img.png"):
    """Example of inference with a trained model.

    Loads a trained image classification model from a checkpoint, then loads
    an example image, preprocesses it, runs it through the model and prints
    the model's raw output.

    Args:
        ckpt_path: Checkpoint to restore the model from. Defaults to
            ``"last.ckpt"``. The checkpoint can also be given as a URL.
        image_path: Image file to run inference on. Defaults to
            ``"data/example_img.png"``.

    Returns:
        The value returned by ``trained_model(batch)`` for the single-image
        batch (the same value that is printed).
    """
    # load model from checkpoint
    # model __init__ parameters will be loaded from ckpt automatically
    # you can also pass some parameter explicitly to override it
    trained_model = MNISTLitModel.load_from_checkpoint(checkpoint_path=ckpt_path)

    # print model hyperparameters
    print(trained_model.hparams)

    # switch to evaluation mode and freeze parameters for inference
    trained_model.eval()
    trained_model.freeze()

    # load data; the `with` block releases the file handle opened by Pillow
    with Image.open(image_path) as raw_img:
        # mode "L" = single-channel 8-bit grayscale
        # (use .convert("RGB") instead for a model trained on 3-channel input)
        img = raw_img.convert("L")

    # preprocess: PIL image -> tensor, resize to 28x28, normalize
    mnist_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((28, 28)),
            # presumably the mean/std used at training time — keep in sync
            # with the training preprocessing
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    img = mnist_transforms(img)
    batch = img.unsqueeze(0)  # add a leading batch dimension of size 1

    # inference
    output = trained_model(batch)
    print(output)
    return output
if "__main__" == __name__:
    # Only run the inference example when executed as a script, not on import.
    predict()