fixbug: imagenet remove torch #275

Merged 1 commit on Nov 27, 2022
28 changes: 25 additions & 3 deletions vega/datasets/common/imagenet.py
@@ -16,8 +16,6 @@

 """This is a class for Imagenet dataset."""

-from torchvision.datasets import ImageFolder
-from vega.datasets.transforms.Compose import Compose
 from vega.common import ClassFactory, ClassType
 from vega.common import FileOps
 from vega.datasets.conf.imagenet import ImagenetConfig
@@ -43,7 +41,27 @@ def __init__(self, **kwargs):
         self.args.data_path = FileOps.download_dataset(self.args.data_path)
         split = 'train' if self.mode == 'train' else 'val'
         local_data_path = FileOps.join_path(self.args.data_path, split)
-        self.image_folder = ImageFolder(root=local_data_path, transform=Compose(self.transforms.__transform__))
+        self.frame_type = False
+        if self.args.backend in ["m", "mindspore"]:
+            self.frame_type = True
+        if self.frame_type:
+            from mindspore.dataset import ImageFolderDataset
+            from mindspore.dataset import vision
+            import mindspore.dataset.transforms as transforms
+            self.image_folders = ImageFolderDataset(dataset_dir=local_data_path, num_parallel_workers=8)
+            transform = transforms.Compose([vision.Decode(to_pil=True),
+                                            vision.RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=vision.Inter.BILINEAR),
+                                            vision.RandomHorizontalFlip(prob=0.5),
+                                            vision.RandomColorAdjust(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=[-0.2, 0.2]),
+                                            vision.ToTensor(),
+                                            vision.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], is_hwc=False),
+                                            ])
+
+            self.image_folder = self.image_folders.map(operations=transform, input_columns=["image"])
+        else:
+            from torchvision.datasets import ImageFolder
+            from vega.datasets.transforms.Compose import Compose
+            self.image_folder = ImageFolder(root=local_data_path, transform=Compose(self.transforms.__transform__))

     @property
     def input_channels(self):
@@ -65,8 +83,12 @@ def input_size(self):

     def __len__(self):
         """Get the length of the dataset."""
+        if self.frame_type:
+            return self.image_folder.get_dataset_size()
         return self.image_folder.__len__()

     def __getitem__(self, index):
         """Get an item of the dataset according to the index."""
+        if self.frame_type:
+            return next(self.image_folder.create_tuple_iterator())
         return self.image_folder.__getitem__(index)
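For context, the MindSpore branch above builds its data pipeline quite differently from the torchvision one: the transform chain is attached with map, the length comes from get_dataset_size(), and samples are consumed through an iterator rather than index-based access, which is why __len__ and __getitem__ branch on frame_type. Below is a minimal, standalone sketch of that pattern, not part of the PR; the "./data/train" path and the shortened transform list are placeholders, while the MindSpore calls (ImageFolderDataset, the vision transforms, map, get_dataset_size, create_tuple_iterator) are the same ones used in the diff.

# Minimal standalone sketch (not part of the PR) of the MindSpore data pipeline
# used in the diff above. "./data/train" is a placeholder for an ImageNet-style
# directory with one subfolder per class.
from mindspore.dataset import ImageFolderDataset, vision
import mindspore.dataset.transforms as transforms

dataset = ImageFolderDataset(dataset_dir="./data/train", num_parallel_workers=8)
transform = transforms.Compose([
    vision.Decode(to_pil=True),
    vision.RandomResizedCrop(size=(224, 224)),
    vision.RandomHorizontalFlip(prob=0.5),
    vision.ToTensor(),
    vision.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], is_hwc=False),
])
dataset = dataset.map(operations=transform, input_columns=["image"])

# Unlike torchvision's ImageFolder, a MindSpore dataset is not indexed directly:
# report its size with get_dataset_size() and pull samples from an iterator.
print(dataset.get_dataset_size())
for image, label in dataset.create_tuple_iterator():
    print(image.shape, label)
    break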