import os
from typing import Any, Optional, Tuple
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTDataModule(LightningDataModule):
"""Example of LightningDataModule for MNIST dataset.
A DataModule implements 5 key methods:
- prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
- setup (things to do on every accelerator in distributed mode)
- train_dataloader (the training dataloader)
- val_dataloader (the validation dataloader(s))
- test_dataloader (the test dataloader(s))
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
"""
def __init__(
self,
data_dir: str = "data/",
train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
persistent_workers: bool = False,
) -> None:
        super().__init__()

        # this line allows accessing the init params via the `self.hparams` attribute
self.save_hyperparameters()
# data transformations
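        # (0.1307,) / (0.3081,) are the commonly used per-channel mean / std for MNIST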
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

@property
def num_classes(self) -> int:
        return 10

    def prepare_data(self) -> None:
"""Download data if needed.
This method is called only from a single GPU.
Do not use it to assign state (self.x = y).
"""
mnist_dir = os.path.join(self.hparams.data_dir, "MNIST")
_download = not os.path.exists(mnist_dir)
MNIST(self.hparams.data_dir, train=True, download=_download)
        MNIST(self.hparams.data_dir, train=False, download=_download)

    def setup(self, stage: Any = None) -> None:
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
so be careful not to execute the random split twice! The `stage` can be used to
differentiate whether it's called before trainer.fit()` or `trainer.test()`.
"""
        # only build the split once, even though `setup()` runs for both fit and test
        if stage in ("fit", "test", None) and not self.data_train:
trainset = MNIST(
root=self.hparams.data_dir, train=True, transform=self.transform
)
testset = MNIST(
root=self.hparams.data_dir, train=False, transform=self.transform
)
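            # merge the official train (60k) and test (10k) sets, then draw a deterministic
            # train/val/test partition (55k/5k/10k by default) from all 70k images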
dataset = ConcatDataset(datasets=[trainset, testset])
self.data_train, self.data_val, self.data_test = random_split(
dataset=dataset,
lengths=self.hparams.train_val_test_split,
generator=torch.Generator().manual_seed(42),
            )

    def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_train,
batch_size=self.hparams.batch_size,
shuffle=True,
num_workers=self.hparams.num_workers,
persistent_workers=self.hparams.persistent_workers,
pin_memory=self.hparams.pin_memory,
        )

    def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_val,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
persistent_workers=self.hparams.persistent_workers,
pin_memory=self.hparams.pin_memory,
        )

    def test_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_test,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
persistent_workers=self.hparams.persistent_workers,
pin_memory=self.hparams.pin_memory,
)
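

# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative, not part of the datamodule itself):
# instantiate the datamodule, run the prepare/setup hooks manually, and inspect
# one training batch. Paths and batch size below are just example values.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    dm = MNISTDataModule(data_dir="data/", batch_size=64)
    dm.prepare_data()  # downloads MNIST into `data_dir` if it is not there yet
    dm.setup()         # builds the deterministic 55k/5k/10k split
    images, labels = next(iter(dm.train_dataloader()))
    # with the default transforms each batch is [batch_size, 1, 28, 28] images
    # and a [batch_size] tensor of integer class labels in [0, 9]
    print(images.shape, labels.shape, dm.num_classes)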