generated from nogibjj/python-template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ui.py
58 lines (50 loc) · 1.89 KB
/
ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from torch import nn
import copy
import torchvision
from torchvision import transforms as T
import gradio as gr
def launch(model_path):
    '''
    Launch the user interface that provides the predictions and corresponding
    class probabilities based on trained models from input model file path.

    A ResNet-18 with a 2-class head is built, the weights stored at
    model_path are loaded into it on the CPU, and a Gradio interface is
    started that classifies an uploaded image as 'artificial' or 'human'.

    Input:
    model_path: String (path to a state dict readable by torch.load)
    Output:
    None (will launch the user interface)
    '''
    class_names = ['artificial', 'human']

    # load model from model file path
    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoint paths should be passed in.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    # Build the bare architecture without downloading ImageNet weights:
    # load_state_dict below is strict, so every parameter and buffer is
    # overwritten by the checkpoint anyway.
    model = torchvision.models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    model.load_state_dict(state_dict)

    # The model is only used for inference: freeze it and switch to eval
    # mode once here instead of on every prediction.
    model.requires_grad_(False)
    model.eval()

    # Preprocessing pipeline, built once and reused by every prediction.
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # prediction function
    def model_pred(img):
        '''
        Return a {class name: probability} mapping for a PIL image, or None
        when no image was provided.
        '''
        if img is None:
            # Submitted without an image: give back an empty result instead
            # of crashing inside the transform.
            return None
        # Normalize expects exactly three channels, so force RGB.
        image_tensor = transform(img.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            pred_probs = torch.softmax(model(image_tensor), dim=1)[0]
        return {
            name: float(prob) for name, prob in zip(class_names, pred_probs)
        }

    inputs = gr.Image(type="pil")
    interface = gr.Interface(
        fn=model_pred,
        inputs=inputs,
        outputs=gr.Label(num_top_classes=2, label="Classification"),
        title="Original vs AI-generated art Classification",
        description="Provide an image and get the predicted class label.")
    interface.launch()
if __name__ == "__main__":
    # Script entry point: serve the UI with the default trained checkpoint.
    launch(model_path="models/model1.pth")