extract_clip_features_test.py (forked from Alibaba-MIIL/ML_Decoder)
"""Extract CLIP image features for the NUS-WIDE test splits and store them,
together with the labels copied from the existing VGG19 feature files, in .h5 files."""
import os

import h5py
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

import eval_clip  # local copy of OpenAI's CLIP (provides load() and encode_image())
class CLIP_Data(Dataset):
    """Reads image paths from a CSV/txt list and returns CLIP-preprocessed tensors."""

    def __init__(self, dir_, img_listpath, transform, preprocess):
        super(CLIP_Data, self).__init__()
        self.dir_ = dir_
        self.img_listpath = img_listpath
        self.imglist = pd.read_csv(img_listpath, header=None)
        self.transform = transform
        self.preprocess = preprocess

    def __len__(self):
        return len(self.imglist)

    def __getitem__(self, index):
        # The image lists use Windows-style separators; normalize for POSIX paths.
        filename = os.path.join(self.dir_, self.imglist.iloc[index, 0].replace('\\', '/'))
        # Convert to RGB so grayscale/CMYK images don't break the preprocess pipeline.
        img = Image.open(filename).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        img = self.preprocess(img)
        return img
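# Usage sketch (illustrative, not part of the original pipeline): each item is a
# preprocessed image tensor ready for model.encode_image, e.g.
#   ds = CLIP_Data(dataroot, test_img, None, preprocess)
#   x = ds[0]  # torch.FloatTensor of shape [3, 224, 224]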
########## CLIP feature extraction ##########
batch_size = 64
workers = 2
dataroot = '/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/extract_features/data/Flickr'
train_img = '/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/extract_features/ImageList/TrainImagelist.txt'
test_img = '/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/extract_features/ImageList/TestImagelist.txt'

# Unused below: the dataset is built with transform=None because CLIP's own
# preprocess already resizes, crops, and normalizes. Kept for reference; note
# that Normalize requires a tensor, so ToTensor must precede it if re-enabled.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # bilinear interpolation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = eval_clip.load("ViT-B/32", device=device)
dataset = CLIP_Data(dataroot, test_img, None, preprocess)
# shuffle must be False: the labels copied below from the VGG .h5 files assume
# the features appear in the same order as the image list.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=workers)
# Extract CLIP image embeddings batch by batch, keeping everything on the CPU
# afterwards so the full feature matrix can be assembled with NumPy.
img_features = []
for data in dataloader:
    data = data.to(device)
    with torch.no_grad():
        image_features = model.encode_image(data)
    img_features.append(image_features.cpu().numpy())
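# Note: assuming eval_clip mirrors OpenAI's reference clip.load(), the model is
# loaded in fp16 on CUDA, so the arrays appended above are float16; cast with
# .astype(np.float32) downstream if the consumer expects float32.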
# GZSL test split: pair the freshly extracted CLIP features with the labels
# copied from the existing VGG19 feature file.
read_filename = "/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/NUS-WIDE/nus_wide_vgg_features/nus_gzsl_test_vgg19.h5"
rf = h5py.File(read_filename, 'r')
labels = rf['labels'][()]  # load the label matrix into memory
write_filename = "/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/NUS-WIDE/nus_wide_clip_features/nus_gzsl_test_clip.h5"
wf = h5py.File(write_filename, 'w')
img_features = np.concatenate(img_features)  # (num_images, feature_dim)
wf.create_dataset("features", data=img_features)
wf.create_dataset("labels", data=labels)
wf.close()
rf.close()
# ZSL test split: the images are the same as in the GZSL split, so the CLIP
# features extracted above are reused; only the label matrix differs.
read_filename = "/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/NUS-WIDE/nus_wide_vgg_features/nus_zsl_test_vgg19.h5"
rf = h5py.File(read_filename, 'r')
labels = rf['labels'][()]
write_filename = "/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/NUS-WIDE/nus_wide_clip_features/nus_zsl_test_clip.h5"
wf = h5py.File(write_filename, 'w')
wf.create_dataset("features", data=img_features)  # already concatenated above
wf.create_dataset("labels", data=labels)
wf.close()
rf.close()
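
# --- Optional sanity check (not in the original script): a minimal sketch that
# reads back the written files and confirms feature and label counts agree.
# _check is a hypothetical helper added here for illustration.
def _check(path):
    with h5py.File(path, 'r') as f:
        feats, labs = f['features'][()], f['labels'][()]
        print(path, feats.shape, labs.shape)
        assert feats.shape[0] == labs.shape[0], "feature/label count mismatch"

_check("/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/NUS-WIDE/nus_wide_clip_features/nus_gzsl_test_clip.h5")
_check("/home/muhammad.ali/Mul_Lab/Generative_MLZSL/datasets/NUS-WIDE/nus_wide_clip_features/nus_zsl_test_clip.h5")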