diff --git a/eval/verification.py b/eval/verification.py index 2b846c6..367062e 100644 --- a/eval/verification.py +++ b/eval/verification.py @@ -197,7 +197,8 @@ def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): return tpr, fpr, accuracy, val, val_std, far @torch.no_grad() -def load_bin(path, image_size): +def load_bin(path, image_size):\ + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") try: with open(path, 'rb') as f: bins, issame_list = pickle.load(f) # py2 @@ -217,7 +218,7 @@ def load_bin(path, image_size): for flip in [0, 1]: if flip == 1: img = mx.ndarray.flip(data=img, axis=2) - data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) + data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()).float().to(device) if idx % 1000 == 0: print('loading bin', idx) print(data_list[0].shape) @@ -225,6 +226,7 @@ def load_bin(path, image_size): @torch.no_grad() def test(data_set, backbone, batch_size, nfolds=10): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print('testing verification..') data_list = data_set[0] issame_list = data_set[1] @@ -240,7 +242,7 @@ def test(data_set, backbone, batch_size, nfolds=10): _data = data[bb - batch_size: bb] time0 = datetime.datetime.now() img = ((_data / 255) - 0.5) / 0.5 - net_out: torch.Tensor = backbone(img) + net_out: torch.cuda.Tensor = backbone(img.float().to(device)) _embeddings = net_out.detach().cpu().numpy() time_now = datetime.datetime.now() diff = time_now - time0