diff --git a/data/data_pipe.py b/data/data_pipe.py index bd67f02c..6b5d4898 100644 --- a/data/data_pipe.py +++ b/data/data_pipe.py @@ -88,7 +88,8 @@ def load_mx_rec(rec_path): for idx in tqdm(range(1,max_idx)): img_info = imgrec.read_idx(idx) header, img = mx.recordio.unpack_img(img_info) - label = int(header.label) + # label = int(header.label) + label = int(header.label[0]) img = Image.fromarray(img) label_path = save_path/str(label) if not label_path.exists(): @@ -119,4 +120,4 @@ def load_mx_rec(rec_path): # if self.h_flip: # img = de_preprocess(img) # img = self.transform(img) -# return img, label \ No newline at end of file +# return img, label