forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
imagenet.py
53 lines (45 loc) · 1.6 KB
/
imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# To get the ImageNet data, download prepare.sh and run it.
import glob, random
import json
import numpy as np
from PIL import Image
import functools, pathlib
# Root of the local dataset copy: an "imagenet" directory next to this file.
BASEDIR = pathlib.Path(__file__).parent / "imagenet"
# Class index loaded from JSON. read_text() opens, reads and closes the file,
# whereas passing a bare open() to json.load left the handle open.
ci = json.loads((BASEDIR / "imagenet_class_index.json").read_text())
# Reverse lookup: first element of each class entry -> integer class index.
cir = {v[0]: int(k) for k, v in ci.items()}
@functools.lru_cache(None)
def get_train_files():
    """Return the training image paths listed in ``BASEDIR / "train_files"``.

    The listing file is read as newline-separated entries and each entry is
    joined onto ``BASEDIR / "train"``.

    Returns:
        A list of ``pathlib.Path`` objects. The result is cached, so every
        call returns the same list object; callers should not mutate it.
    """
    # Use a context manager so the listing's file handle is closed promptly.
    with open(BASEDIR / "train_files") as listing:
        train_files = listing.read().strip().split("\n")
    return [(BASEDIR / "train" / x) for x in train_files]
@functools.lru_cache(None)
def get_val_files():
    """Return every path matching ``val/*/*`` under BASEDIR.

    Returns:
        The list of path strings produced by ``glob.glob``. The result is
        cached, so the directory is only scanned on the first call.
    """
    pattern = str(BASEDIR / "val/*/*")
    return glob.glob(pattern)
#rrc = transforms.RandomResizedCrop(224)
import torchvision.transforms.functional as F
def image_load(fn):
    """Load one image file and return it as a numpy array.

    The image is converted to RGB, resized with ``F.resize(..., 256)`` using
    bilinear interpolation, then center-cropped to 224.

    Args:
        fn: Path of the image file, in any form ``Image.open`` accepts.

    Returns:
        ``np.ndarray`` built from the cropped PIL image
        (presumably 224x224x3 -- TODO confirm).
    """
    rgb = Image.open(fn).convert('RGB')
    resized = F.resize(rgb, 256, Image.BILINEAR)
    return np.array(F.center_crop(resized, 224))
def iterate(bs=32, val=True, shuffle=True):
    """Yield ``(X, Y)`` batches covering the chosen split exactly once.

    Args:
        bs: Batch size; the final batch may be smaller.
        val: Use the validation files when true, otherwise the training files.
        shuffle: Visit the files in a random order (via ``random.shuffle``).

    Yields:
        Tuples ``(np.array(images), np.array(labels))`` where the images come
        from ``image_load`` and the labels are looked up in ``cir`` by the name
        of the directory containing each file.
    """
    from multiprocessing import Pool

    files = get_val_files() if val else get_train_files()
    order = list(range(len(files)))
    if shuffle:
        random.shuffle(order)
    # The context manager shuts the worker pool down when the generator is
    # exhausted or closed; the bare Pool(16) used before was never released.
    with Pool(16) as pool:
        for start in range(0, len(files), bs):
            batch = [files[j] for j in order[start:start + bs]]
            X = pool.map(image_load, batch)
            # Entries are str for val and pathlib.Path for train; going through
            # pathlib handles both (Path has no .split()).
            # NOTE(review): assumes files are laid out as <class id>/<image>
            # in both splits -- confirm for the training listing.
            Y = [cir[pathlib.Path(f).parent.name] for f in batch]
            yield (np.array(X), np.array(Y))
def fetch_batch(bs, val=False):
    """Return one random batch ``(X, Y)`` from the chosen split.

    Args:
        bs: Number of samples to draw. Indices come from
            ``np.random.randint``, so a file can appear more than once.
        val: Draw from the validation files when true, otherwise from the
            training files.

    Returns:
        ``(np.array(images), np.array(labels))`` where the images come from
        ``image_load`` and the labels are looked up in ``cir`` by the name of
        the directory containing each file.
    """
    files = get_val_files() if val else get_train_files()
    samp = np.random.randint(0, len(files), size=(bs))
    files = [files[i] for i in samp]
    X = [image_load(x) for x in files]
    # The previous x.split("/")[0] raised AttributeError on the Path objects of
    # the training split and picked the first path component (not the class
    # directory) for validation paths. Use the parent directory name instead,
    # which also matches the label lookup in iterate().
    # NOTE(review): assumes files are laid out as <class id>/<image> in both
    # splits -- confirm for the training listing.
    Y = [cir[pathlib.Path(x).parent.name] for x in files]
    return np.array(X), np.array(Y)
if __name__ == "__main__":
    # Smoke test: fetch one training batch of 64 and show its shape and labels.
    images, labels = fetch_batch(64)
    print(images.shape, labels)