Topic recognition - s4647936 #165

Open · wants to merge 51 commits into base: main

Commits (51)
c3aff8b
Added recognition branch and README for info.
shakes76 Sep 17, 2023
7e5bc0e
Initial setup for Siamese network Alzheimer’s disease classification.
tranvicky Oct 4, 2023
8d8b7f2
Renamed README.MD to README.md
tranvicky Oct 4, 2023
9aabb12
Update to contain rough outline for README.md
tranvicky Oct 4, 2023
7e03b36
Initial draft of Siamese network architecture in modules.py
tranvicky Oct 10, 2023
7d27279
Added section in code to help determine dataset image dimensions
tranvicky Oct 10, 2023
7cacc40
Merge branch 'topic-recognition' of https://github.com/tranvicky/Patt…
tranvicky Oct 10, 2023
dea52f6
•TripletLoss setup: Create an initial structure for the triplet loss …
tranvicky Oct 10, 2023
cbb5bea
Added TripletDataset to dataset.py to generate anchor, positive (same…
tranvicky Oct 10, 2023
e096a1e
Refactored TripletDataset for enhanced adaptability, ensuring balance…
tranvicky Oct 14, 2023
0628f8c
Added image transformations and initialized train and test datasets i…
tranvicky Oct 14, 2023
3387cc5
Added functionality to display and save sample triplet images while c…
tranvicky Oct 14, 2023
0d39d54
Refactored TripletDataset to ensure patient-wise split for train-test…
tranvicky Oct 14, 2023
33b1d47
Added testing to determine size of training sets after splitting
tranvicky Oct 14, 2023
e5b4344
Merge branch 'topic-recognition' of https://github.com/tranvicky/Patt…
tranvicky Oct 14, 2023
b04d801
Added additional data augmentation steps (random rotation, horizontal…
tranvicky Oct 14, 2023
cc2038f
Integrated the training loop for the Siamese Network into train.py. A…
tranvicky Oct 23, 2023
a1058d9
Fixed method save_image in train.py so that it properly saves images.
tranvicky Oct 23, 2023
22adb4c
Added validation/testing loop to evaluate model on test set post-trai…
tranvicky Oct 23, 2023
2896300
Implement predict.py to load the trained Siamese Network model and pe…
tranvicky Oct 23, 2023
561d8a3
Integrated PCA-based visualization for model embeddings in train.py a…
tranvicky Oct 23, 2023
3c30104
Changed embedding plots from PCA to T-SNE
tranvicky Oct 23, 2023
eda84c8
Added functionality to extract and store embeddings for the entire da…
tranvicky Oct 23, 2023
0cc7e8d
Modified dataset.py to include labels, allowing for easier extraction…
tranvicky Oct 23, 2023
6a77f09
Added SimpleClassifier class in the neural network module to facilita…
tranvicky Oct 23, 2023
3e7fc30
Implemented the training loop for the Simple Classifier using extract…
tranvicky Oct 23, 2023
f0ba8da
Added evaluation phase for the Simple Classifier using the test embed…
tranvicky Oct 23, 2023
5f9c135
Changed labels to detect if AD or NC in its path.
tranvicky Oct 23, 2023
2afb4e3
Added method plot_confusion_matrix to plot confusion matrix so can vi…
tranvicky Oct 23, 2023
3010854
Changed train.py so code can print and out plot of train vs validatio…
tranvicky Oct 23, 2023
879bc08
Added tracking and visualization for the classifier's training and va…
tranvicky Oct 23, 2023
5f66d81
Changed code to fix issue of tsne graph displaying only one colour
tranvicky Oct 23, 2023
6f2924f
Add code for SNN to be stopped early if validation loss doesn't impro…
tranvicky Oct 23, 2023
c4a4810
Added changes to train.py to resolve issue of incorrect indexing with…
tranvicky Oct 24, 2023
2b5d722
Changed train.py code to save plot of siamese training vs losses.
tranvicky Oct 24, 2023
a9e5b91
Solved issues with confusion matrix plot not showing and issue with …
tranvicky Oct 25, 2023
bc90ea3
Introduced more convolutional layers and added dropout after each fc …
tranvicky Oct 25, 2023
a11b716
Adjusted hyperparameters for learning rate and batch size.
tranvicky Oct 25, 2023
ca94001
Updated README.md to fill out all of the uncompleted sections.
tranvicky Oct 25, 2023
1af6bb3
Changed code to resolve dimension bugs after changed number of fc lay…
tranvicky Oct 25, 2023
dbdaff3
Updated README.md to include visualisations/plots produced during trai…
tranvicky Oct 25, 2023
28c03b8
Updated README.md to reflect updated dependencies.
tranvicky Oct 25, 2023
bad03a6
Updated predict.py so that if user wants to visualise the model, they…
tranvicky Oct 25, 2023
bf8a1a4
Updated README.md to detail a more in-depth setup guide and running f…
tranvicky Oct 25, 2023
53f5e2f
Updated train.py to resolve issue about numpy to tensor error.
tranvicky Oct 25, 2023
1fb431b
Updated references for README.md
tranvicky Oct 25, 2023
5814d1b
Removed unnecessary print statements from train.py, modules.py and dat…
tranvicky Oct 25, 2023
8d8ad25
Provide folder for images/plots/visualisations needed for README.md.
tranvicky Oct 25, 2023
77b19d5
Update README.md to resolve issue of images only showing links instea…
tranvicky Oct 25, 2023
fb57496
Updated README.md and fixed image pathways.
tranvicky Oct 25, 2023
6aee14a
Changed variable outputs to resolve tensor error in train.py.
tranvicky Oct 25, 2023
10 changes: 10 additions & 0 deletions recognition/README.md
@@ -0,0 +1,10 @@
# Recognition Tasks
Various recognition tasks solved in deep learning frameworks.

Tasks may include:
* Image Segmentation
* Object detection
* Graph node classification
* Image super resolution
* Disease classification
* Generative modelling with StyleGAN and Stable Diffusion
216 changes: 216 additions & 0 deletions recognition/alzheimers_snn_s4647936/README.md
@@ -0,0 +1,216 @@
# Classifying Alzheimer's Disease using Siamese Networks

## 📣 Introduction
The task was to "create a classifier based on the Siamese network to classify either Alzheimer’s disease (normal and AD) of the ADNI brain data set or classify the ISIC 2020 Kaggle Challenge data set (normal and melanoma) having an accuracy of around 0.8 on the test set."

As such, this project aims to distinguish between Alzheimer's disease (AD) and normal (NC) brain images from the ADNI dataset.

## 🧠 Dataset
- **Description**: The ADNI dataset is split into 'training' and 'testing' sets, each with two categories: 'NC' (Normal Control) and 'AD' (Alzheimer's Disease). The training set includes 11,120 'NC' and 10,400 'AD' images; the test set contains 4,460 'AD' and 4,540 'NC' images.

Images in the dataset are slices of a patient's brain. Each filename encodes the patient ID and the slice number: for example, '388206_78' refers to the 78th slice of patient 388206's brain.

<p align="center">
<img src="./readme-images/AD_NC_Scans.png">
</p>

<p align="center">
<em> Figure 1: Brain scans of AD and NC patient slices of ADNI dataset</em>
</p>

- **Dataset Link**: [ADNI Dataset](http://adni.loni.usc.edu/) (however, the preprocessed data provided on UQ Blackboard was used)
- **Preprocessing Steps**: Images were resized to 256x240 pixels and normalized. Data augmentations such as rotations and flips were applied during training (preprocessing is covered in more depth in the **Data Preprocessing** section).

## 🪄 Data Preprocessing

The dataset preprocessing is aimed at generating triplets of images for training Siamese networks using triplet loss. The triplets consist of an anchor, a positive, and a negative image.

### Dataset Structure

The data is organised in a directory structure with separate subdirectories for Alzheimer's Disease (AD) and Normal Control (NC) images.

```
root_dir/
├── train/
│ ├── AD/
│ └── NC/
└── test/
├── AD/
└── NC/
```

### Triplet Generation

For each anchor image from either the AD or NC class:
- A positive image is selected from the same patient but a different slice.
- A negative image is then chosen from the opposite class.

This ensures that the positive image is similar (as it's from the same patient), whereas the negative image is distinct.

### Data Augmentation

During training, data augmentation is applied to the images to introduce variability and improve generalisation. The following augmentations are applied:
- Random rotation of up to 10 degrees.
- Random horizontal flip.
- Random vertical flip.

### Patient-wise Data Split

The dataset is split patient-wise rather than image-wise. This ensures that all slices of a particular patient either belong to the training set or the testing set, preventing data leakage.

### Dataset Implementation

The `TripletDataset` class, a subclass of `torch.utils.data.Dataset`, facilitates the creation of triplets and data loading. The main components of this class include:

- `__init__`: Initialises the dataset, loads image paths, applies transformations, and splits data patient-wise.
- `__len__`: Returns the total number of images (AD + NC).
- `__getitem__`: Given an index, it returns a triplet of images along with the label of the anchor image.

The `patient_wise_split` function is a utility that splits the dataset based on unique patient IDs. It ensures that all images from a single patient are either in the training or testing set.
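
A typical way to instantiate these datasets and wrap them in data loaders is sketched below (illustrative only; the root directory, image size, and batch size are placeholders based on the values described elsewhere in this README):

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataset import TripletDataset

# Base transform applied to every image; the augmentations listed above are
# added inside TripletDataset itself when mode='train'.
base_transform = transforms.Compose([
    transforms.Resize((256, 240)),
    transforms.ToTensor(),
])

train_dataset = TripletDataset(root_dir="AD_NC", mode="train", transform=base_transform)
test_dataset = TripletDataset(root_dir="AD_NC", mode="test", transform=base_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
```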

## 🐱 Siamese Network Architecture

### Overview
- **Siamese Networks**: These networks consist of twin networks which accept distinct inputs but are joined by an energy function at the top. This energy function computes a metric between the highest-level feature representation on each side.
- **Objective**: The Siamese network's goal is to differentiate between pairs of inputs. In this context, it is used to distinguish Alzheimer's Disease (AD) images from Normal Control (NC) images.

### Model Architecture

#### Feature Extractor
- **Purpose**: This sub-network is responsible for extracting features from an input image. These features form the basis upon which differences or similarities between images are recognised.
- **Implementation**:
- Employs a Convolutional Neural Network (CNN) structure with added depth.
- The first convolutional layer expects grayscale images and outputs 32 channels with a 5x5 kernel.
- The subsequent convolutional layer takes these 32 channels as input, producing 64 channels with another 5x5 kernel.
- Two fully connected layers follow the convolutional layers. The first reduces the dimension to 256, while the final layer further reduces it to a 2-dimensional output for visualisation and analysis.
- ReLU activation functions and max-pooling operations are applied after the convolutional layers, with dropout layers introduced for regularisation (a sketch of this feature extractor is given below).
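
A minimal sketch of a feature extractor matching this description is shown below. It is illustrative only: the layer sizes and dropout placement in the actual `modules.py` may differ, and the adaptive pooling layer is added here purely so the sketch does not depend on the exact input resolution.

```python
import torch
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """CNN that maps a grayscale brain slice to a 2-D embedding (illustrative sketch)."""

    def __init__(self, dropout: float = 0.3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),   # grayscale input -> 32 channels, 5x5 kernel
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5),  # 32 -> 64 channels, 5x5 kernel
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((8, 8)),      # fixed spatial size for the sketch
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),        # first fully connected layer
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 2),                 # 2-D embedding for visualisation/analysis
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.conv(x))
```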


#### Siamese Network
- **Composition**: The Siamese network applies a single weight-shared Feature Extractor (described above) to each image of a triplet: anchor, positive, and negative. The resulting embeddings are then used to compute the triplet loss (a sketch of this wrapper follows).
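
A corresponding wrapper, assuming the `FeatureExtractor` sketch above, could look like this (the real interface in `modules.py` may differ):

```python
import torch.nn as nn

class SiameseNetwork(nn.Module):
    """Embeds anchor, positive, and negative images with one weight-shared extractor."""

    def __init__(self):
        super().__init__()
        self.extractor = FeatureExtractor()

    def forward(self, anchor, positive, negative):
        # Weight sharing: the same module produces all three embeddings.
        return self.extractor(anchor), self.extractor(positive), self.extractor(negative)
```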

### Loss Function

#### Triplet Loss
- **Purpose**: The Triplet Loss emphasises the relative distance between the anchor-positive pair and the anchor-negative pair. The objective is to ensure that the anchor and positive (both from the same class) are closer to each other in the embedding space than the anchor and negative (from different classes).
- **Implementation**:
- The Euclidean distance between the anchor and positive, as well as the anchor and negative, is computed.
- The loss is the anchor-positive distance minus the anchor-negative distance plus a margin, clamped at zero. The margin enforces a buffer between positive and negative pairs (a minimal sketch is given after the figure below).

<p align="center">
<img src="./readme-images/triplet_loss.png">
</p>

<p align="center">
<em> Figure 2: Triplet Loss Diagram [1]</em>
</p>
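
In symbols, for a triplet (a, p, n) the loss is L = max(d(a, p) - d(a, n) + margin, 0), where d is the Euclidean distance between embeddings. A minimal sketch of this (equivalent to PyTorch's built-in `nn.TripletMarginLoss`; the margin value here is illustrative):

```python
import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin: float = 1.0):
    """Triplet loss using Euclidean distances between embeddings."""
    pos_dist = F.pairwise_distance(anchor, positive)  # d(a, p)
    neg_dist = F.pairwise_distance(anchor, negative)  # d(a, n)
    return torch.clamp(pos_dist - neg_dist + margin, min=0).mean()
```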

### Additional Classifier
After the Siamese network has been trained, a simple classifier is built on top of the embeddings generated by the network.
- **Architecture**:
- Two-layer feedforward neural network.
- The first layer reduces the 2-dimensional embedding to 64 dimensions using a ReLU activation function.
- The second layer maps the 64 dimensions to 2 outputs, representing the AD and NC classes (see the sketch below).
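
A sketch of such a head, matching the `SimpleClassifier` mentioned in the commit history (the exact details in `modules.py` may differ):

```python
import torch.nn as nn

class SimpleClassifier(nn.Module):
    """Two-layer feedforward head: 2-D Siamese embedding -> AD/NC logits."""

    def __init__(self, embedding_dim: int = 2, hidden_dim: int = 64, num_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, embedding):
        return self.net(embedding)
```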


## 🏋️ Training and Evaluation

### Training Procedure
- The Siamese network undergoes training for 30 epochs with a batch size of 16 for improved generalisation.
- An early stopping mechanism halts training if the validation loss does not improve over a specified number of epochs (a sketch of such a loop is given below).
- The Adam optimizer is employed with a learning rate of 0.0005.
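
The core of such a loop, including the early-stopping check, might look like the following sketch. It reuses the `SiameseNetwork` and `triplet_loss` sketches above; the patience value and variable names are illustrative, not taken from the actual `train.py`.

```python
import copy
import torch

def train_siamese(model, train_loader, val_loader, device, epochs=30, lr=5e-4, patience=5):
    """Illustrative training loop with Adam and early stopping on validation loss."""
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    best_val, best_state, epochs_without_improvement = float("inf"), None, 0

    for epoch in range(epochs):
        model.train()
        for anchor, positive, negative, _ in train_loader:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
            optimiser.zero_grad()
            loss = triplet_loss(*model(anchor, positive, negative))
            loss.backward()
            optimiser.step()

        # Validation pass
        model.eval()
        with torch.no_grad():
            val_loss = sum(
                triplet_loss(*model(a.to(device), p.to(device), n.to(device))).item()
                for a, p, n, _ in val_loader
            ) / len(val_loader)

        # Early stopping: stop if validation loss has not improved for `patience` epochs
        if val_loss < best_val:
            best_val, best_state, epochs_without_improvement = val_loss, copy.deepcopy(model.state_dict()), 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best checkpoint
    return model
```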

### Evaluation and Visualisation

![Siamese Network Training vs Validation Loss](./readme-images/siamese_train_vs_val_loss.png)
#### *Figure 3: Siamese Network Training vs Validation Loss*

Analysis: The plot above showcases the training and validation loss for the Siamese Network across epochs. The model seems to be converging, although there's some fluctuation in the validation loss, indicating potential overfitting.


![t-SNE Visualization of Siamese Network Embeddings](./readme-images/embeddings_tsne.png)
#### *Figure 4: t-SNE Visualisation of Siamese Network Embeddings*
Analysis: The t-SNE visualisation illustrates the 2D representation of the embeddings generated by the Siamese network. There appears to be some clustering, but there's also an overlap between AD and NC embeddings.
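
For reference, a t-SNE plot like the ones shown here can be produced with scikit-learn roughly as follows. This is a sketch assuming embeddings and labels have been collected into NumPy arrays, with label 0 = AD and 1 = NC as in `dataset.py`; the actual plotting code in `train.py` may differ.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne(embeddings: np.ndarray, labels: np.ndarray, out_path: str = "embeddings_tsne.png"):
    """Project embeddings to 2-D with t-SNE and colour points by AD/NC label."""
    projected = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
    for label, name in [(0, "AD"), (1, "NC")]:
        mask = labels == label
        plt.scatter(projected[mask, 0], projected[mask, 1], s=5, label=name)
    plt.legend()
    plt.title("t-SNE of Siamese embeddings")
    plt.savefig(out_path)
    plt.close()
```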


![Classifier Training vs Validation Loss](./readme-images/classifier_train_vs_val_loss.png)

#### *Figure 5: Classifier Training vs Validation Loss*

Analysis: This plot displays the training and validation loss for the classifier built on top of the Siamese network embeddings. The model seems to converge after initial epochs, but a similar fluctuation in validation loss is observed, suggesting overfitting.

![Confusion Matrix for Classifier Performance](./readme-images/classifier_confusion_matrix.png)

#### *Figure 6: Confusion Matrix for Classifier Performance*

Analysis: The confusion matrix reveals the classifier's performance on the test set. It's evident that the classifier struggles to differentiate between AD and NC samples consistently.

![t-SNE Visualisation of Classifier Embeddings](./readme-images/classifier_embeddings_tsne.png)

#### *Figure 7: t-SNE Visualisation of Classifier Embeddings*

Analysis: The t-SNE visualisation of the classifier embeddings suggests a more distinct separation between the AD and NC categories than the original Siamese network embeddings, although a noticeable overlap remains.


## 🛠️ Setup and Installation
Ensure you have Python 3.7+ installed. Then, set up a virtual environment (optional) and install the dependencies:

| Dependency | Version |
| ------------ | ----------- |
| torch | 1.10.1+cu113 |
| torchvision | 0.11.2+cu113 |
| Pillow | 8.3.2 |
| matplotlib | 3.4.3 |
| seaborn | 0.11.2 |
| numpy | 1.21.3 |
| scikit-learn | 1.0.1 |

## 👨‍💻 Usage

### Preparation:
Before proceeding, ensure you update the paths in the code files to point to the location where you have downloaded the dataset on your device.

### 1. Train the Siamese Network:

To train the Siamese network on the provided dataset and obtain embeddings for visualisation:

`python train.py`

This will train the Siamese network and save the trained model as `siamese_model.pt`. Visualisations such as the t-SNE plot for the Siamese network embeddings will also be generated.

### 2. Predict with the Trained Models:

To predict embeddings for a pair of sample images using the trained Siamese network, modify the `predict.py` script with the paths to your sample images and run:

`python predict.py`


This will generate embeddings for the provided images and predict their classes using the trained Siamese network.
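
Under the hood, the prediction step amounts to loading the saved weights and embedding each image, roughly as in this sketch. The sample paths are placeholders, and it assumes `siamese_model.pt` stores a state dict for the `SiameseNetwork` sketch above; adjust if the whole model object was saved instead.

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((256, 240)),
    transforms.ToTensor(),        # repeat here any normalisation used during training
])

def embed_image(model, image_path, device):
    """Load one brain slice and return its Siamese embedding."""
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        return model.extractor(image)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNetwork().to(device)
model.load_state_dict(torch.load("siamese_model.pt", map_location=device))
model.eval()

embedding_ad = embed_image(model, "path/to/sample_AD_slice.jpeg", device)
embedding_nc = embed_image(model, "path/to/sample_NC_slice.jpeg", device)
```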


## ✨ Results Summary
While the Siamese Network and the subsequent classifier showed promising results during training, the classifier did not achieve the target accuracy of 0.8 on the test set. This is evident from the confusion matrix and the t-SNE visualisations. The models seem to struggle to find a clear boundary between the AD and NC categories.

## 🔮 Future Work & Improvements
1. Experiment with different architectures for the Feature Extractor to achieve more discriminative embeddings.
2. Explore other distance metrics or loss functions that might provide a better separation between the categories.
3. Incorporate more sophisticated data augmentation techniques to improve model generalisation.
4. Investigate methods to reduce overfitting, possibly through more regularisation or employing a more sophisticated early stopping mechanism.
5. Given the target accuracy of 0.8 was not achieved, more extensive hyperparameter tuning and model validation strategies should be explored.

## 📚 References

[1] [Triplet Loss — Advanced Intro](https://towardsdatascience.com/triplet-loss-advanced-intro-49a07b7d8905)

[2] Bromley, Jane, et al. "[Signature verification using a 'Siamese' time delay neural network](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf)." Advances in Neural Information Processing Systems, 1994.

[3] G. Koch, R. Zemel, and R. Salakhutdinov, "Siamese neural networks for one-shot image recognition," in ICML deep learning workshop, 2015, vol. 2, no. 1: Lille.

[4] S. Mandal, "[Power of Siamese Networks and Triplet Loss: Tackling Unbalanced Datasets](https://medium.com/@mandalsouvik/power-of-siamese-networks-and-triplet-loss-tackling-unbalanced-datasets-ebb2bb6efdb1)," Medium, 2023.


[5] R. Takahashi, T. Matsubara, and K. Uehara, "Data augmentation using random image cropping and patching for deep CNNs," IEEE Transactions on Circuits and Systems for Video Technology, vol. 30, no. 9, pp. 2917-2931, 2019.

129 changes: 129 additions & 0 deletions recognition/alzheimers_snn_s4647936/dataset.py
@@ -0,0 +1,129 @@
import os
import random
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

class TripletDataset(Dataset):
"""
Generate triplets for training Siamese networks using triplet loss.

For each anchor image from either the AD or NC class, a positive image is selected from
the same patient but a different slice. A negative image is then chosen from the opposite class.

Args:
root_dir (str): Root directory containing AD and NC image subdirectories.
mode (str): Either 'train' or 'test'.
transform (callable, optional): Transformations applied to the images.

Returns:
tuple: A triplet of images - (anchor, positive, negative).
"""

def __init__(self, root_dir, mode='train', transform=None, split_ratio=0.8):
self.root_dir = root_dir # root_dir = "/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/AD_NC"
self.mode = mode
self.transform = transform

# Directories for AD and NC images
self.ad_dir = os.path.join(root_dir, mode, 'AD')
self.nc_dir = os.path.join(root_dir, mode, 'NC')

# Load all image paths
self.ad_paths = [os.path.join(self.ad_dir, img) for img in os.listdir(self.ad_dir)]
self.nc_paths = [os.path.join(self.nc_dir, img) for img in os.listdir(self.nc_dir)]

train_ad_paths, test_ad_paths, train_nc_paths, test_nc_paths = patient_wise_split(self.ad_paths, self.nc_paths, split_ratio)

if mode == 'train':
self.ad_paths = train_ad_paths
self.nc_paths = train_nc_paths

# Integrate data augmentation if in training mode
self.transform = transforms.Compose([
transforms.RandomRotation(10),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transform
])

elif mode == 'test':
self.ad_paths = test_ad_paths
self.nc_paths = test_nc_paths
print("Sample AD paths:", self.ad_paths[:5])
print("Sample NC paths:", self.nc_paths[:5])


def __len__(self):
return len(self.ad_paths) + len(self.nc_paths) # combined length

def __getitem__(self, idx):
# Decide whether to take AD or NC as anchor based on index
if idx < len(self.ad_paths):
anchor_path = self.ad_paths[idx]
positive_paths = self.ad_paths
negative_paths = self.nc_paths
else:
anchor_path = self.nc_paths[idx - len(self.ad_paths)] # offset by length of ad_paths
positive_paths = self.nc_paths
negative_paths = self.ad_paths

# Extract patient ID from the filename
patient_id = os.path.basename(anchor_path).split('_')[0]

# Choose a positive image from the same patient
positive_path = random.choice([path for path in positive_paths if os.path.basename(path) != os.path.basename(anchor_path) and patient_id in os.path.basename(path)])

# Choose a negative image from a different patient
negative_path = random.choice([path for path in negative_paths if patient_id not in os.path.basename(path)])
anchor_image = Image.open(anchor_path)
positive_image = Image.open(positive_path)
negative_image = Image.open(negative_path)

if self.transform:
anchor_image = self.transform(anchor_image)
positive_image = self.transform(positive_image)
negative_image = self.transform(negative_image)

# Decide label based on anchor image path
label = 0 if "/AD/" in anchor_path else 1

return anchor_image, positive_image, negative_image, label


def patient_wise_split(ad_paths, nc_paths, split_ratio=0.8):
"""
Split the AD and NC data patient-wise.

Args:
- ad_paths: List of paths to AD images.
- nc_paths: List of paths to NC images.
- split_ratio: Proportion of data to use for training.

Returns:
- train_ad_paths: List of AD training paths.
- test_ad_paths: List of AD testing paths.
- train_nc_paths: List of NC training paths.
- test_nc_paths: List of NC testing paths.
"""

# Extract patient IDs
ad_patient_ids = list(set(os.path.basename(path).split('_')[0] for path in ad_paths))
nc_patient_ids = list(set(os.path.basename(path).split('_')[0] for path in nc_paths))

# Split patient IDs for training and testing
train_ad_ids = random.sample(ad_patient_ids, int(split_ratio * len(ad_patient_ids)))
train_nc_ids = random.sample(nc_patient_ids, int(split_ratio * len(nc_patient_ids)))

test_ad_ids = list(set(ad_patient_ids) - set(train_ad_ids))
test_nc_ids = list(set(nc_patient_ids) - set(train_nc_ids))

# Get paths based on split IDs
train_ad_paths = [path for path in ad_paths if os.path.basename(path).split('_')[0] in train_ad_ids]
test_ad_paths = [path for path in ad_paths if os.path.basename(path).split('_')[0] in test_ad_ids]

train_nc_paths = [path for path in nc_paths if os.path.basename(path).split('_')[0] in train_nc_ids]
test_nc_paths = [path for path in nc_paths if os.path.basename(path).split('_')[0] in test_nc_ids]

return train_ad_paths, test_ad_paths, train_nc_paths, test_nc_paths
