-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
26 lines (21 loc) · 860 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
def prepare_mnist_data(data_path, batch_size, *, image_size=32, num_workers=0):
    """Build MNIST train/test DataLoaders with resizing and normalization.

    Args:
        data_path: Root directory where torchvision stores (and, if the
            files are missing, downloads) the MNIST dataset.
        batch_size: Number of samples per batch for both loaders.
        image_size: Side length in pixels that every image is resized to
            (square output). Defaults to 32, the size this helper has
            always produced.
        num_workers: Number of worker processes each DataLoader uses.
            0 (DataLoader's own default) loads data in the main process.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        reshuffles every epoch. The test loader keeps a fixed order so
        evaluation runs are reproducible.
    """
    trans = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Commonly quoted mean/std of MNIST training pixels, measured after
        # ToTensor scales them into [0, 1]. Centers the input near zero.
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Pass download=True for both splits so neither relies on the other
    # having been constructed first. torchvision skips the download when
    # the files are already present under `data_path`.
    train_dataset = torchvision.datasets.MNIST(
        root=data_path, train=True, transform=trans, download=True)
    test_dataset = torchvision.datasets.MNIST(
        root=data_path, train=False, transform=trans, download=True)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader