Topic recognition #183

Open · wants to merge 25 commits into main
10 changes: 10 additions & 0 deletions recognition/README.md
@@ -0,0 +1,10 @@
# Recognition Tasks
Various recognition tasks solved in deep learning frameworks.

Tasks may include:
* Image segmentation
* Object detection
* Graph node classification
* Image super-resolution
* Disease classification
* Generative modelling with StyleGAN and Stable Diffusion
79 changes: 79 additions & 0 deletions recognition/vision-transformer-4696689/README.md
@@ -0,0 +1,79 @@
# ADNI brain data classification with Vision Transformer

## Summary

The goal of this project is to classify Alzheimer's disease (normal control or AD) in the
ADNI brain data using a Vision Transformer. Each sample consists of 20 greyscale slices of
240x256 pixels belonging to one patient, and the whole stack is classified as either NC
or AD. Experiments were also done with data augmentation.

## How to use

There are four files: dataset.py, modules.py, train.py and predict.py. The only files which
need to be run are train.py and predict.py. train.py is responsible for training (and
testing) the model, with the option of saving the model as well as the loss and
validation accuracy of each epoch, for later use in predict.py. predict.py can load
this data and re-test the model on any of the dataloaders (train, validation, test) or
plot the loss/accuracy curves with matplotlib.
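
dataset.py builds the data splits at import time and exposes them through small helper
functions. The exact contents of predict.py are not shown in this diff, so the snippet
below is only an illustration of how those helpers can be consumed (it requires the
AD_NC data to be available at the paths configured in dataset.py):

```python
from dataset import trainloader, trainshape

print(trainshape())                      # e.g. torch.Size([N, 1, 20, 240, 256]) after the permute in dataset.py
for x, y in trainloader(batchsize=16):
    print(x.shape, y.shape)              # (16, 1, 20, 240, 256) volumes and (16,) integer labels
    break
```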

Key point: dataset.py contains local directory paths for the images. Make sure these
point to the correct locations on your machine.
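
For reference, the defaults at the bottom of dataset.py look like this; adjust them to
wherever the AD_NC data lives locally:

```python
# In dataset.py -- point these at the local copy of the ADNI data
trainDIRs = ['AD_NC/train/AD/', 'AD_NC/train/NC']
testDIRs = ['AD_NC/test/AD/', 'AD_NC/test/NC']
```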

Key point: The save-model section of train.py is commented out. Make sure to uncomment
it to use this functionality.

Key point: The test section of predict.py is commented out. Make sure to uncomment it
to use this functionality.

Key point: Since the dataset is so large, training might need to be done on 4x P100 GPUs
(e.g. on the Rangpur cluster).
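
train.py is not shown in this diff, so whether it does this is not confirmed here; as a
hedged sketch only, a common way to spread such a model over several GPUs in PyTorch is
nn.DataParallel (the ViT hyperparameters below mirror the run recorded in
extra/parameters.txt):

```python
import torch
import torch.nn as nn
from modules import VisionTransformer

model = VisionTransformer(classes=2, inputsize=(1, 192, 120),
                          heads=4, embed=360, fflscale=2, nblocks=4)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)   # split each batch across the visible GPUs
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
```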

## Architecture

The default Vision Transformer is upgraded to include a pre-convolutional module, of
which there are two designs. The convolutional layers produce fewer, smaller patches,
which speeds the model up; they are also intended to introduce inductive bias into the
model. 3D patches are used, offering large speed gains. Data augmentation is done by
flipping images, yielding 4x as much data, which is reported to be particularly
important for transformer models.
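
The flip-based augmentation corresponds to the four torchvision pipelines used in
dataset.py (identity, horizontal flip, vertical flip, and both), condensed here for
reference:

```python
import torchvision.transforms as transforms

# p=1.0 makes the "random" flips deterministic, so each slice yields exactly four variants
base  = transforms.Compose([transforms.PILToTensor()])
hflip = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), transforms.PILToTensor()])
vflip = transforms.Compose([transforms.RandomVerticalFlip(p=1.0), transforms.PILToTensor()])
dflip = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0),
                            transforms.RandomVerticalFlip(p=1.0),
                            transforms.PILToTensor()])
augmentations = [base, hflip, vflip, dflip]   # applied per slice in dataset.py, giving 4x the training volumes
```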

![Basic Transformer Model](extra/ViT.png)

The standard Vision Transformer works by feeding embeddings of image patches, along
with a positional encoding, into a transformer model. Only the encoder is used, and
cross-entropy loss is used for the classification. Switching the order of normalisation
(applying LayerNorm before attention, rather than after) allows better gradient
propagation and training stability. When using this patch-based model it is important
to use 3D patches for both speed and performance. The later design instead uses a CNN to
reduce the image into channels (similar in size to patches), which are then fed in. This
further improves speed without impacting performance.
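
The modules.py in this submission uses the convolutional front-end, so raw 3D patching
does not appear there. Purely as an illustration, 3D patches could be cut from a patient
volume with `unfold` as below; the 4x16x16 patch size is an assumption for the example,
not a value taken from the code:

```python
import torch

def make_3d_patches(vol, pd=4, ph=16, pw=16):
    # vol: (N, 1, D, H, W) greyscale volume, e.g. (N, 1, 20, 240, 256)
    N, C, D, H, W = vol.shape
    patches = vol.unfold(2, pd, pd).unfold(3, ph, ph).unfold(4, pw, pw)
    # -> (N, C, D/pd, H/ph, W/pw, pd, ph, pw); flatten each 3D patch into one token
    return patches.contiguous().view(N, -1, C * pd * ph * pw)

x = torch.randn(2, 1, 20, 240, 256)
print(make_3d_patches(x).shape)   # torch.Size([2, 1200, 1024]) -- far fewer tokens than 2D 16x16 patches per slice
```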

## Training

Training is done for 100 epochs, which was found experimentally to be long enough.
The AdamW optimiser is used with a learning rate of 3e-4; this was decreased from 1e-3
(which did not train well) but also increased from 1e-4. The data is split into train,
validation and test sets. The majority of the data is in the train set, and the
validation and test sets are of equal size.

Hyperparameter tuning was done manually. Learning rate schedulers (e.g. cyclic, warm-up)
were found to be ineffective. A learning rate of 1e-3 did not permit training, while 1e-4
was too slow and did not perform as well as the final 3e-4. The 20 slices for each
patient are kept together, so the train/validation/test split is at the patient level.
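
train.py itself is not visible in this diff, so the following is only a sketch of a
training loop consistent with the settings above (AdamW at 3e-4, cross-entropy, 100
epochs); the model configuration mirrors the run recorded in extra/parameters.txt and is
an assumption rather than the exact contents of train.py:

```python
import torch
import torch.nn as nn
from dataset import trainloader
from modules import VisionTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# (1, 192, 120): 192 tokens of dimension 120, the shape ConvLayer produces from a 1x20x240x256 volume
model = VisionTransformer(classes=2, inputsize=(1, 192, 120),
                          heads=4, embed=360, fflscale=2, nblocks=4).to(device)
optimiser = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    model.train()
    running = 0.0
    for x, y in trainloader(batchsize=16):
        x, y = x.to(device), y.to(device)
        optimiser.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimiser.step()
        running += loss.item()
    print(f"epoch {epoch}: loss {running:.4f}")
```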

## Result

Overall, the test accuracy was 68.0%, which is acceptable. The test accuracy was
the same as the validation accuracy, the latter of which became stable during training.
This was at about the same time the loss had rapidly decreased and also become stable.
This could indicate that the model has adapted very well to the training set and is
not generalising; this was the key motivator for data augmentation. However, it could
also indicate that the learning rate is too small and training is stuck in a local
optimum, which was the key motivator for increasing the learning rate from 1e-4 to 3e-4.

![Training accuracy and epoch](extra/train.png)
![Validation accuracy and epoch](extra/acc.png)
![Training Loss and epoch](extra/loss.png)

## References

Dosovitskiy, A. (2021) An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, Papers with code. Available at: https://paperswithcode.com/paper/an-image-is-worth-16x16-words-transformers-1 (Accessed: 18 November 2023).
126 changes: 126 additions & 0 deletions recognition/vision-transformer-4696689/dataset.py
@@ -0,0 +1,126 @@
"""
Imports Here
"""
"""numpy and torch"""
import numpy as np
import torch

"""PIL"""
from PIL import Image

"""torchvision and utils"""
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

"""os"""
import os

"""
Loading data from local file
"""
"""Assumes images have pixel values in range [0,255]"""
def getImages(trainDIRs, testDIRs):
"""Get image to tensor"""
transform = transforms.Compose([
transforms.PILToTensor()
])
hflip = transforms.Compose([
transforms.RandomHorizontalFlip(p=1.0),
transforms.PILToTensor()
])
vflip = transforms.Compose([
transforms.RandomVerticalFlip(p=1.0),
transforms.PILToTensor()
])
dflip = transforms.Compose([
transforms.RandomHorizontalFlip(p=1.0),
transforms.RandomVerticalFlip(p=1.0),
transforms.PILToTensor()
])
tlist = [transform, hflip, vflip, dflip]
"""Loading data into arrays"""
    xtrain, ytrain, xtest, ytest = [], [], [], []
"""training data"""
size = [0, 0]
for i, DIR in enumerate(trainDIRs):
for t in tlist:
px = []
j = 0
for filename in sorted(os.listdir(DIR)):
f = os.path.join(DIR, filename)
img = Image.open(f)
                tensor = t(img).float()  # requires_grad is not needed for input data
px.append(tensor/255)
j = (j+1) % 20
if j == 0:
xtrain.append(torch.stack(px))
px = []
size[i] += 1
xtrain = torch.stack(xtrain)
ytrain = torch.from_numpy(np.concatenate((np.ones(size[0]), np.zeros(size[1])), axis=0))

"""testing data"""
size = [0, 0]
for i, DIR in enumerate(testDIRs):
for t in tlist:
px = []
j = 0
for filename in sorted(os.listdir(DIR)):
f = os.path.join(DIR, filename)
img = Image.open(f)
                tensor = t(img).float()  # requires_grad is not needed for input data
px.append(tensor/255)
j = (j+1) % 20
if j == 0:
xtest.append(torch.stack(px))
px = []
size[i] += 1
xtest = torch.stack(xtest)
idx = torch.randperm(xtest.size(0))
xtest = xtest[idx, :]
splitsize = int(xtest.shape[0]/2)
xval, xtest = xtest.split(splitsize, dim=0)
ytest = torch.from_numpy(np.concatenate((np.ones(size[0]), np.zeros(size[1])), axis=0))
ytest = ytest[idx]
yval, ytest = ytest.split(splitsize, dim=0)
return xtrain, ytrain, xtest, ytest, xval, yval
"""
Dataloader
"""
class DatasetWrapper(Dataset):
def __init__(self, X, y=None):
self.X, self.y = X, y

def __len__(self):
return len(self.X)

def __getitem__(self, idx):
if self.y is None:
return self.X[idx]
else:
return self.X[idx], self.y[idx]

trainDIRs = ['AD_NC/train/AD/', 'AD_NC/train/NC']
testDIRs = ['AD_NC/test/AD/', 'AD_NC/test/NC']
xtrain, ytrain, xtest, ytest, xval, yval = getImages(trainDIRs, testDIRs)
ytrain, yval, ytest = ytrain.type(torch.LongTensor), yval.type(torch.LongTensor), ytest.type(torch.LongTensor)
xtrain = xtrain.permute(0, 2, 1, 3, 4)
xtest = xtest.permute(0, 2, 1, 3, 4)
xval = xval.permute(0, 2, 1, 3, 4)

def trainloader(batchsize=16):
return DataLoader(DatasetWrapper(xtrain, ytrain), batch_size=batchsize, shuffle=True, pin_memory=True)

def valloader():
return DataLoader(DatasetWrapper(xval, yval), batch_size=1, shuffle=True, pin_memory=True)

def testloader():
return DataLoader(DatasetWrapper(xtest, ytest), batch_size=1, shuffle=True, pin_memory=True)

def trainshape():
return xtrain.shape

def testshape():
return xtest.shape
37 changes: 37 additions & 0 deletions recognition/vision-transformer-4696689/extra/conv-block.py
@@ -0,0 +1,37 @@
"""
Conv v2
"""
class ConvLayer2(nn.Module):
def __init__(self):
super().__init__()
#pool
self.pool = nn.MaxPool2d(kernel_size=3, stride=2)
self.relu = nn.ReLU()
#first layer
self.conv11_x = nn.Conv2d(20, 48, kernel_size=(11,11), stride=(4,4), padding=(0,0))
self.conv11_y = nn.Conv2d(240, 48, kernel_size=(11,3), stride=(4,1), padding=(0,0))
self.conv11_z = nn.Conv2d(256, 48, kernel_size=(3,11), stride=(1,4), padding=(0,0))
#second layer
self.conv5_x = nn.Conv2d(48, 192, kernel_size=(5,5), stride=(2,2), padding=(0,0))
self.conv5_y = nn.Conv2d(48, 192, kernel_size=(5,3), stride=(2,1), padding=(0,0))
self.conv5_z = nn.Conv2d(48, 192, kernel_size=(3,5), stride=(1,2), padding=(0,0))
#projection
self.l_x = nn.Linear(30, 32)
self.l_y = nn.Linear(12, 32)
self.l_z = nn.Linear(10, 32)

def forward(self, imgs):
#input N, C, L, W, H
#first layer
x_x = self.relu(self.pool(self.conv11_x(imgs.flatten(1,2))))
x_y = self.relu(self.pool(self.conv11_y(imgs.permute(0,1,3,4,2).flatten(1,2))))
x_z = self.relu(self.pool(self.conv11_z(imgs.permute(0,1,4,2,3).flatten(1,2))))
#second layer
x_x = self.relu(self.pool(self.conv5_x(x_x)))
x_y = self.relu(self.pool(self.conv5_y(x_y)))
x_z = self.relu(self.pool(self.conv5_z(x_z)))
#projection
x_x = self.l_x(x_x.flatten(2,3))
x_y = self.l_y(x_y.flatten(2,3))
x_z = self.l_z(x_z.flatten(2,3))
return torch.cat([x_x, x_y, x_z], dim=2)
20 changes: 20 additions & 0 deletions recognition/vision-transformer-4696689/extra/parameters.txt
@@ -0,0 +1,20 @@
AdamW lr=1e-4, 175 epochs, 192, 120, heads=4, embed=360, fflscale=2, nblocks=4
LOSS = [0.72875, 0.70531, 0.66767, 0.61233, 0.53435, 0.49842, 0.43119, 0.45669, 0.38625, 0.35263, 0.36537, 0.32514, 0.26318, 0.2506, 0.24311, 0.18782, 0.17435, 0.13011, 0.14882, 0.17382, 0.10999, 0.13796, 0.07506, 0.06944, 0.06198, 0.03524, 0.07395, 0.09999, 0.04692, 0.03988, 0.0566, 0.02929, 0.01366, 0.01277, 0.01246, 0.01824, 0.04371, 0.0791, 0.04064, 0.04082, 0.01846, 0.00784, 0.00725, 0.00714, 0.0071, 0.00703, 0.00697, 0.00684, 0.00686, 0.00677, 0.00665, 0.00629, 0.00595, 0.01606, 0.11788, 0.21843, 0.02893, 0.01473, 0.04044, 0.02642, 0.02621, 0.00663, 0.00604, 0.00071, 0.00035, 0.00026, 0.00022, 0.0002, 0.00018, 0.00016, 0.00015, 0.00014, 0.00013, 0.00012, 0.00011, 0.0001, 0.0001, 9e-05, 8e-05, 8e-05, 7e-05, 7e-05, 7e-05, 6e-05, 6e-05, 6e-05, 5e-05, 5e-05, 5e-05, 5e-05, 4e-05, 4e-05, 4e-05, 4e-05, 4e-05, 4e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
ACC = [50.67, 51.11, 58.67, 63.11, 57.78, 62.67, 63.56, 66.22, 66.22, 67.11, 66.67, 65.78, 67.56, 65.33, 68.0, 68.44, 67.11, 64.89, 64.89, 67.56, 68.0, 69.33, 67.11, 67.56, 68.0, 67.56, 66.22, 71.11, 69.33, 67.11, 66.67, 69.78, 69.33, 69.78, 69.78, 68.0, 66.67, 68.89, 69.78, 69.78, 68.44, 67.56, 67.11, 67.56, 67.56, 67.56, 68.0, 68.0, 68.0, 68.0, 68.0, 67.56, 67.56, 68.0, 66.22, 70.67, 67.56, 66.67, 68.89, 65.33, 66.67, 70.22, 68.0, 69.78, 68.89, 68.0, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44, 68.44]

to plot:
import matplotlib.pyplot as plt
steps = range(175)
plt.plot(steps, LOSS)
plt.ylabel('LOSS')
plt.xlabel('epoch')
plt.show()
plt.plot(steps, ACC)
plt.ylabel('ACCURACY')
plt.xlabel('epoch')
plt.show()

cuda
training time: 27699.315416812897
test acc: tensor(0.6800)
TIME = [147.563, 146.343, 144.501, 147.546, 144.388, 143.652, 146.672, 144.336, 145.402, 146.032, 144.47, 144.527, 145.94, 145.326, 144.034, 145.458, 146.047, 143.858, 146.212, 144.663, 144.781, 146.169, 143.851, 146.982, 143.694, 145.329, 145.16, 146.066, 144.08, 145.364, 145.876, 143.906, 145.965, 144.99, 144.381, 147.893, 146.199, 144.357, 145.847, 144.55, 144.047, 145.702, 144.852, 143.926, 145.867, 144.55, 144.213, 146.131, 144.313, 144.568, 145.913, 144.292, 147.893, 147.291, 148.067, 148.66, 149.459, 148.164, 148.963, 149.543, 144.27, 145.208, 145.364, 143.899, 146.17, 143.49, 146.005, 144.319, 144.524, 145.954, 143.908, 145.923, 149.609, 148.143, 149.126, 147.25, 143.868, 145.934, 144.889, 144.385, 146.232, 144.071, 145.286, 145.871, 143.787, 145.719, 148.777, 147.816, 149.28, 148.8, 148.009, 149.313, 149.438, 147.923, 148.943, 149.355, 148.399, 148.242, 149.209, 149.388, 148.377, 148.594, 149.603, 148.353, 148.588, 149.617, 148.425, 148.436, 149.528, 148.536, 148.31, 149.578, 148.509, 148.387, 149.569, 148.542, 148.188, 149.53, 148.641, 148.101, 149.468, 148.894, 148.149, 148.935, 149.422, 148.588, 148.187, 149.229, 149.147, 149.19, 148.44, 148.16, 149.419, 148.88, 148.568, 148.514, 148.583, 148.594, 148.789, 148.996, 149.07, 149.142, 148.768, 148.309, 148.454, 148.685, 149.076, 149.272, 148.759, 148.253, 148.44, 149.121, 149.245, 148.525, 148.261, 148.695, 149.247, 149.253, 148.579, 148.307, 149.357, 147.468, 148.775, 147.945, 149.511, 148.644, 148.232, 149.552, 148.53, 148.147, 149.467, 148.824, 148.064, 149.387, 149.3]
106 changes: 106 additions & 0 deletions recognition/vision-transformer-4696689/modules.py
@@ -0,0 +1,106 @@
"""
Imports Here
"""
import numpy as np
import torch
import torch.nn as nn

class Attention(nn.Module):
def __init__(self, heads, embed):
super().__init__()
self.heads = heads
self.attn = nn.MultiheadAttention(embed, heads, batch_first=True)
self.Q = nn.Linear(embed, embed, bias=False)
self.K = nn.Linear(embed, embed, bias=False)
self.V = nn.Linear(embed, embed, bias=False)

def forward(self, x):
Q = self.Q(x)
K = self.K(x)
V = self.V(x)
attnout, attnweights = self.attn(Q, K, V)
return attnout

class TransBlock(nn.Module):
def __init__(self, heads, embed, fflsize):
super().__init__()
self.fnorm = nn.LayerNorm(embed)
self.snorm = nn.LayerNorm(embed)
self.attn = Attention(heads, embed)
self.ffl = nn.Sequential(
nn.Linear(embed, fflsize),
nn.GELU(),
nn.Linear(fflsize, embed)
)

def forward(self, x):
"""
Switching to pre-MHA LayerNorm is supposed to give better performance,
this is used in other models such as LLMs like GPT. Gradients are meant
to be stabilised. This is different to the original ViT paper.
"""
x = x + self.attn(self.fnorm(x))
x = x + self.ffl(self.snorm(x))
return x
"""
Convolution pre
"""
class ConvLayer(nn.Module):
def __init__(self):
super().__init__()
self.pool = nn.MaxPool3d(kernel_size=3, stride=2)
self.relu = nn.ReLU()
self.conv11 = nn.Conv3d(1, 48, kernel_size=(3,11,11), stride=(1,4,4), padding=(1,0,0))
self.conv5 = nn.Conv3d(48, 192, kernel_size=(3,5,5), stride=(1,2,2), padding=(1,0,0))

def forward(self, imgs):
x = self.conv11(imgs)
x = self.relu(self.pool(x))
x = self.conv5(x)
x = self.relu(self.pool(x))
return x
"""
Vision Transformer Class to create a vision transformer model
"""
class VisionTransformer(nn.Module):
def __init__(self, classes=2, inputsize=(1,1,1), heads=2, embed=64, fflscale=2, nblocks=1):
super().__init__()
(self.N, self.Np, self.P) = inputsize
"""components"""
self.proj = nn.Linear(self.P, embed)
self.clstoken = nn.Parameter(torch.randn(1, 1, embed))
        # register as a buffer so the fixed positional encoding moves to the GPU with the model
        self.register_buffer('posembed', self.embedding(self.Np+1, embed))
        # build nblocks independent transformer blocks (repeating one instance would share its weights)
        self.transformer = nn.Sequential(
            *[TransBlock(heads, embed, int(fflscale*embed)) for _ in range(nblocks)]
        )
self.classifier = nn.Sequential(
nn.LayerNorm(embed),
nn.Linear(embed, classes)
)
"""convolutional components"""
self.conv = ConvLayer()

def embedding(self, npatches, embed, freq=10000): #10000 is described in ViT paper
posembed = torch.zeros(npatches, embed)
for i in range(npatches):
for j in range(embed):
if j % 2 == 0:
posembed[i][j] = np.sin(i/(freq**(j/embed)))
else:
posembed[i][j] = np.cos(i/(freq**((j-1)/embed)))
return posembed

    def forward(self, imgs): #imgs assumed already shaped as (N, 1, depth, H, W)
"""Convolutional layer"""
imgs = self.conv(imgs)
imgs = imgs.flatten(2,4)
"""Linear Projection and Positional Embedding"""
tokens = self.proj(imgs) #perform linear projection
clstoken = self.clstoken.repeat(imgs.shape[0], 1, 1)
tokens = torch.cat([clstoken, tokens], dim=1) #concat the class token
x = tokens + self.posembed.repeat(imgs.shape[0], 1, 1) #add positional encoding
"""Transformer"""
x = self.transformer(x)
"""Classification"""
y = x[:,0]
return self.classifier(y)