diff --git a/recognition/README.md b/recognition/README.md
new file mode 100644
index 0000000000..5c646231c2
--- /dev/null
+++ b/recognition/README.md
@@ -0,0 +1,10 @@
+# Recognition Tasks
+Various recognition tasks solved in deep learning frameworks.
+
+Tasks may include:
+* Image segmentation
+* Object detection
+* Graph node classification
+* Image super-resolution
+* Disease classification
+* Generative modelling with StyleGAN and Stable Diffusion
diff --git a/recognition/alzheimers_snn_s4647936/README.md b/recognition/alzheimers_snn_s4647936/README.md
new file mode 100644
index 0000000000..c2c738a73f
--- /dev/null
+++ b/recognition/alzheimers_snn_s4647936/README.md
@@ -0,0 +1,216 @@
+# Classifying Alzheimer's Disease using Siamese Networks
+
+## 📣 Introduction
+The task was to "create a classifier based on the Siamese network to classify either Alzheimer's disease (normal and AD) of the ADNI brain data set or classify the ISIC 2020 Kaggle Challenge data set (normal and melanoma) having an accuracy of around 0.8 on the test set."
+
+As such, this project aims to distinguish between Alzheimer's disease (AD) and normal control (NC) brain images from the ADNI dataset.
+
+## 🧠 Dataset
+- **Description**: The ADNI dataset is split into 'training' and 'testing' sets. Each set has two categories: 'NC' (Normal Control) and 'AD' (Alzheimer's Disease). Specifically, the training set includes 11120 'NC' and 10400 'AD' images, and the test set contains 4460 'AD' and 4540 'NC' images.
+
+  Images in the dataset are slices of a patient's brain. Each filename encodes the patient ID together with the slice number: an image named '388206_78', for example, is the 78th slice of patient 388206's brain.
+
+![Brain scans of AD and NC patients](./readme-images/AD_NC_Scans.png)
+#### *Figure 1: Brain scans of AD and NC patient slices of ADNI dataset*
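+
+This patient-and-slice naming convention is what the data pipeline keys on. As a minimal sketch (the path below is a hypothetical example), the patient ID and slice number can be recovered with the same `split('_')` convention that `dataset.py` uses:
+
+```python
+import os
+
+path = "AD_NC/train/AD/388206_78.jpeg"  # hypothetical example path
+stem = os.path.splitext(os.path.basename(path))[0]  # "388206_78"
+patient_id, slice_number = stem.split("_")
+print(patient_id, slice_number)  # -> 388206 78
+```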
+
+- **Dataset Link**: [ADNI Dataset](http://adni.loni.usc.edu/) (However, the provided preprocessed data from UQ Blackboard was used)
+- **Preprocessing Steps**: Images were resized to 256x240 pixels and normalized. Data augmentations like rotations and flips were applied during training (preprocessing is discussed in more depth in the **Data Preprocessing** section).
+
+## 🪄 Data Preprocessing
+
+The dataset preprocessing is aimed at generating triplets of images for training Siamese networks using triplet loss. The triplets consist of an anchor, a positive, and a negative image.
+
+### Dataset Structure
+
+The data is organised in a directory structure with separate subdirectories for Alzheimer's Disease (AD) and Normal Control (NC) images.
+
+```
+root_dir/
+│
+├── train/
+│   ├── AD/
+│   └── NC/
+│
+└── test/
+    ├── AD/
+    └── NC/
+```
+
+### Triplet Generation
+
+For each anchor image from either the AD or NC class:
+- A positive image is selected from the same patient but a different slice.
+- A negative image is then chosen from the opposite class.
+
+This ensures that the positive image is similar (as it's from the same patient), whereas the negative image is distinct.
+
+### Data Augmentation
+
+During training, data augmentation is applied to the images to introduce variability and improve generalisation. The following augmentations are applied:
+- Random rotation of up to 10 degrees.
+- Random horizontal flip.
+- Random vertical flip.
+
+### Patient-wise Data Split
+
+The dataset is split patient-wise rather than image-wise. This ensures that all slices of a particular patient belong to either the training set or the testing set, preventing data leakage.
+
+### Dataset Implementation
+
+The `TripletDataset` class, a subclass of `torch.utils.data.Dataset`, facilitates the creation of triplets and data loading. The main components of this class include:
+
+- `__init__`: Initialises the dataset, loads image paths, applies transformations, and splits data patient-wise.
+- `__len__`: Returns the total number of images (AD + NC).
+- `__getitem__`: Given an index, it returns a triplet of images along with the label of the anchor image.
+
+The `patient_wise_split` function is a utility that splits the dataset based on unique patient IDs. It ensures that all images from a single patient are either in the training or testing set.
+
+## 🐱 Siamese Network Architecture
+
+### Overview
+- **Siamese Networks**: A Siamese network consists of twin subnetworks which accept distinct inputs but are joined by an energy function at the top. This energy function computes a metric between the highest-level feature representation on each side.
+- **Objective**: The Siamese network's goal is to differentiate between pairs of inputs. In this context, it is used to differentiate between Alzheimer's Disease (AD) and Normal Control (NC) images.
+
+### Model Architecture
+
+#### Feature Extractor
+- **Purpose**: This sub-network is responsible for extracting features from an input image. These features form the basis upon which differences or similarities between images are recognised.
+- **Implementation**:
+  - Employs a Convolutional Neural Network (CNN) structure with added depth.
+  - The first convolutional layer expects grayscale images and outputs 32 channels with a 5x5 kernel.
+  - Two further convolutional layers take the 32 channels to 64 and then to 128, each again with a 5x5 kernel.
+  - Three fully connected layers follow: the first reduces the flattened features to 512 dimensions, the second to 256, and the final layer to a 2-dimensional output for visualisation and analysis (see the shape walkthrough below).
+  - ReLU activation functions and max-pooling operations are applied after each convolutional layer, with dropout layers introduced for regularisation.
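+
+As a quick sanity check of these dimensions (a sketch assuming the 256x240 ADNI slices described above; `FeatureExtractor` lives in `modules.py`):
+
+```python
+import torch
+from modules import FeatureExtractor
+
+# Each 5x5 conv (no padding) trims 4 pixels per dimension, each 2x2 max-pool halves:
+# 256x240 -> 252x236 -> 126x118 -> 122x114 -> 61x57 -> 57x53 -> 28x26
+x = torch.randn(1, 1, 256, 240)  # one grayscale slice
+embedding = FeatureExtractor()(x)
+print(embedding.shape)  # torch.Size([1, 2])
+```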
+
+#### Siamese Network
+- **Composition**: The Siamese network uses two copies of the above-described Feature Extractor; both images in a pair pass through these identical, weight-sharing subnetworks. The outputs from the twin networks are then used to compute the triplet loss.
+
+### Loss Function
+
+#### Triplet Loss
+- **Purpose**: The Triplet Loss emphasises the relative distance between the anchor-positive pair and the anchor-negative pair. The objective is to ensure that the anchor and positive (both from the same class) are closer to each other in the embedding space than the anchor and negative (from different classes).
+- **Implementation**:
+  - The squared Euclidean distance between the anchor and positive, as well as between the anchor and negative, is computed.
+  - The loss is the difference between these distances plus a margin, clamped at zero; the margin enforces a buffer between positive and negative pairs, as formalised below.
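+
+With embeddings $f(a)$, $f(p)$ and $f(n)$ for the anchor, positive and negative, and margin $m$ (1.0 in `modules.py`), the loss implemented in `modules.py` is
+
+$$\mathcal{L}(a, p, n) = \max\bigl(0,\ \lVert f(a) - f(p) \rVert_2^2 - \lVert f(a) - f(n) \rVert_2^2 + m\bigr)$$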
+
+![Triplet loss diagram](./readme-images/triplet_loss.png)
+#### *Figure 2: Triplet Loss Diagram [1]*
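+
+Putting the pieces together, a minimal sketch of a single triplet forward pass (mirroring how `train.py` drives `SiameseNetwork` and `TripletLoss` from `modules.py`; the random tensors stand in for real 256x240 slices):
+
+```python
+import torch
+from modules import SiameseNetwork, TripletLoss
+
+model = SiameseNetwork()
+criterion = TripletLoss(margin=1.0)
+
+# A batch of 4 grayscale slices for each role in the triplet
+anchor = torch.randn(4, 1, 256, 240)
+positive = torch.randn(4, 1, 256, 240)
+negative = torch.randn(4, 1, 256, 240)
+
+# The twin branches share weights: embed anchor/positive, then anchor/negative
+anchor_out, positive_out = model(anchor, positive)
+_, negative_out = model(anchor, negative)
+
+loss = criterion(anchor_out, positive_out, negative_out)
+print(loss.item())
+```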
+
+### Additional Classifier
+After the Siamese network has been trained, a simple classifier is built on top of the embeddings generated by the network.
+- **Architecture**:
+  - Two-layer feedforward neural network.
+  - The first layer maps the 2-dimensional embedding to 64 dimensions, followed by a ReLU activation function.
+  - The second layer maps the 64 dimensions to 2 outputs, representing the AD and NC classes.
+
+## 🏋️ Training and Evaluation
+
+### Training Procedure
+- The Siamese network undergoes training for 30 epochs with a batch size of 16 for improved generalisation.
+- An early stopping mechanism halts training if the validation loss doesn't improve for 5 consecutive epochs.
+- The Adam optimizer is employed with a learning rate of 0.0005.
+
+### Evaluation and Visualisation
+
+![Siamese Network Training vs Validation Loss](./readme-images/siamese_train_vs_val_loss.png)
+#### *Figure 3: Siamese Network Training vs Validation Loss*
+
+Analysis: The plot above showcases the training and validation loss for the Siamese Network across epochs. The model seems to be converging, although there is some fluctuation in the validation loss, indicating potential overfitting.
+
+![t-SNE Visualisation of Siamese Network Embeddings](./readme-images/embeddings_tsne.png)
+#### *Figure 4: t-SNE Visualisation of Siamese Network Embeddings*
+
+Analysis: The t-SNE visualisation illustrates the 2D representation of the embeddings generated by the Siamese network. There appears to be some clustering, but there is also an overlap between AD and NC embeddings.
+
+![Classifier Training vs Validation Loss](./readme-images/classifier_train_vs_val_loss.png)
+
+#### *Figure 5: Classifier Training vs Validation Loss*
+
+Analysis: This plot displays the training and validation loss for the classifier built on top of the Siamese network embeddings. The model seems to converge after the initial epochs, but a similar fluctuation in validation loss is observed, suggesting overfitting.
+
+![Confusion Matrix for Classifier Performance](./readme-images/classifier_confusion_matrix.png)
+
+#### *Figure 6: Confusion Matrix for Classifier Performance*
+
+Analysis: The confusion matrix reveals the classifier's performance on the test set. It is evident that the classifier struggles to differentiate between AD and NC samples consistently.
+
+![t-SNE Visualisation of Classifier Embeddings](./readme-images/classifier_embeddings_tsne.png)
+
+#### *Figure 7: t-SNE Visualisation of Classifier Embeddings*
+
+Analysis: This is a particularly interesting plot. The t-SNE visualisation of the classifier embeddings suggests a more distinct separation between AD and NC categories compared to the original Siamese network embeddings. However, there is still a noticeable overlap.
+
+## 🛠️ Setup and Installation
+Ensure you have Python 3.7+ installed. Then, set up a virtual environment (optional) and install the dependencies:
+
+| Dependency   | Version      |
+| ------------ | ------------ |
+| torch        | 1.10.1+cu113 |
+| torchvision  | 0.11.2+cu113 |
+| Pillow       | 8.3.2        |
+| matplotlib   | 3.4.3        |
+| seaborn      | 0.11.2       |
+| numpy        | 1.21.3       |
+| scikit-learn | 1.0.1        |
+
+## 👨‍💻 Usage
+
+### Preparation:
+Before proceeding, ensure you update the paths in the code files to point to the location where you have downloaded the dataset on your device.
+
+### 1. Train the Siamese Network:
+
+To train the Siamese network (and the classifier built on its embeddings) on the provided dataset:
+
+`python train.py`
+
+This will train the Siamese network and save the trained weights as `siamese_model.pth` (the classifier is saved as `classifier_model.pth`). Visualisations such as the loss curves and the t-SNE plots of the embeddings will also be generated.
+
+### 2. Predict with the Trained Models:
+
+To evaluate the trained models on the test set, update the dataset and model paths at the top of `predict.py` and run:
+
+`python predict.py`
+
+This will compute embeddings for the test images with the trained Siamese network, classify them with the trained classifier, and produce a t-SNE plot, a confusion matrix, and the overall test accuracy.
+
+## ✨ Results Summary
+While the Siamese Network and the subsequent classifier showed promising results during training, the classifier did not achieve the target accuracy of 0.8 on the test set. This is evident from the confusion matrix and the t-SNE visualisations. The models seem to struggle to find a clear boundary between the AD and NC categories.
+
+## 🔮 Future Work & Improvements
+1. Experiment with different architectures for the Feature Extractor to achieve more discriminative embeddings.
+2. Explore other distance metrics or loss functions that might provide a better separation between the categories.
+3. Incorporate more sophisticated data augmentation techniques to improve model generalisation.
+4. Investigate methods to reduce overfitting, possibly through more regularisation or a more sophisticated early stopping mechanism.
+5. Given the target accuracy of 0.8 was not achieved, more extensive hyperparameter tuning and model validation strategies should be explored.
+
+## 📚 References
+
+[1] [Triplet Loss — Advanced Intro](https://towardsdatascience.com/triplet-loss-advanced-intro-49a07b7d8905)
+
+[2] Bromley, Jane, et al. "[Signature verification using a "Siamese" time delay neural network](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf)." Advances in Neural Information Processing Systems, 1994.
+
+[3] G. Koch, R. Zemel, and R. Salakhutdinov, "Siamese neural networks for one-shot image recognition," in ICML Deep Learning Workshop, 2015, vol. 2, no. 1: Lille.
+
+[4] Mandal, S. (2023). *[Power of Siamese Networks and Triplet Loss: Tackling Unbalanced Datasets](https://medium.com/@mandalsouvik/power-of-siamese-networks-and-triplet-loss-tackling-unbalanced-datasets-ebb2bb6efdb1)*. Medium.com.
+
+[5] R. Takahashi, T. Matsubara, and K. Uehara, "Data augmentation using random image cropping and patching for deep CNNs," IEEE Transactions on Circuits and Systems for Video Technology, vol. 30, no. 9, pp. 2917-2931, 2019.
+
diff --git a/recognition/alzheimers_snn_s4647936/dataset.py b/recognition/alzheimers_snn_s4647936/dataset.py
new file mode 100644
index 0000000000..47c1d4f52d
--- /dev/null
+++ b/recognition/alzheimers_snn_s4647936/dataset.py
@@ -0,0 +1,129 @@
+import os
+import random
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.utils.data import Dataset
+
+class TripletDataset(Dataset):
+    """
+    Generate triplets for training Siamese networks using triplet loss.
+
+    For each anchor image from either the AD or NC class, a positive image is selected from
+    the same patient but a different slice. A negative image is then chosen from the opposite class.
+
+    Args:
+        root_dir (str): Root directory containing AD and NC image subdirectories.
+        mode (str): Either 'train' or 'test'.
+        transform (callable, optional): Transformations applied to the images.
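+        split_ratio (float, optional): Fraction of patients assigned to the training split (default 0.8).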
+
+    Returns:
+        tuple: (anchor, positive, negative, label) - the triplet of images plus
+        the anchor's class label (0 for AD, 1 for NC).
+    """
+
+    def __init__(self, root_dir, mode='train', transform=None, split_ratio=0.8):
+        self.root_dir = root_dir # root_dir = "/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/AD_NC"
+        self.mode = mode
+        self.transform = transform
+
+        # Directories for AD and NC images
+        self.ad_dir = os.path.join(root_dir, mode, 'AD')
+        self.nc_dir = os.path.join(root_dir, mode, 'NC')
+
+        # Load all image paths
+        self.ad_paths = [os.path.join(self.ad_dir, img) for img in os.listdir(self.ad_dir)]
+        self.nc_paths = [os.path.join(self.nc_dir, img) for img in os.listdir(self.nc_dir)]
+
+        train_ad_paths, test_ad_paths, train_nc_paths, test_nc_paths = patient_wise_split(self.ad_paths, self.nc_paths, split_ratio)
+
+        if mode == 'train':
+            self.ad_paths = train_ad_paths
+            self.nc_paths = train_nc_paths
+
+            # Integrate data augmentation if in training mode
+            self.transform = transforms.Compose([
+                transforms.RandomRotation(10),
+                transforms.RandomHorizontalFlip(),
+                transforms.RandomVerticalFlip(),
+                transform
+            ])
+
+        elif mode == 'test':
+            self.ad_paths = test_ad_paths
+            self.nc_paths = test_nc_paths
+            print("Sample AD paths:", self.ad_paths[:5])
+            print("Sample NC paths:", self.nc_paths[:5])
+
+    def __len__(self):
+        return len(self.ad_paths) + len(self.nc_paths) # combined length
+
+    def __getitem__(self, idx):
+        # Decide whether to take AD or NC as anchor based on index
+        if idx < len(self.ad_paths):
+            anchor_path = self.ad_paths[idx]
+            positive_paths = self.ad_paths
+            negative_paths = self.nc_paths
+        else:
+            anchor_path = self.nc_paths[idx - len(self.ad_paths)] # offset by length of ad_paths
+            positive_paths = self.nc_paths
+            negative_paths = self.ad_paths
+
+        # Extract patient ID from the filename
+        patient_id = os.path.basename(anchor_path).split('_')[0]
+
+        # Choose a positive image from the same patient (a different slice)
+        positive_path = random.choice([path for path in positive_paths if os.path.basename(path) != os.path.basename(anchor_path) and patient_id in os.path.basename(path)])
+
+        # Choose a negative image from a different patient
+        negative_path = random.choice([path for path in negative_paths if patient_id not in os.path.basename(path)])
+
+        anchor_image = Image.open(anchor_path)
+        positive_image = Image.open(positive_path)
+        negative_image = Image.open(negative_path)
+
+        if self.transform:
+            anchor_image = self.transform(anchor_image)
+            positive_image = self.transform(positive_image)
+            negative_image = self.transform(negative_image)
+
+        # Decide label based on anchor image path
+        label = 0 if "/AD/" in anchor_path else 1
+
+        return anchor_image, positive_image, negative_image, label
+
+
+def patient_wise_split(ad_paths, nc_paths, split_ratio=0.8):
+    """
+    Split the AD and NC data patient-wise.
+
+    Args:
+    - ad_paths: List of paths to AD images.
+    - nc_paths: List of paths to NC images.
+    - split_ratio: Proportion of data to use for training.
+
+    Returns:
+    - train_ad_paths: List of AD training paths.
+    - test_ad_paths: List of AD testing paths.
+    - train_nc_paths: List of NC training paths.
+    - test_nc_paths: List of NC testing paths.
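+
+    Note: patient IDs are drawn with `random.sample`, so seed Python's `random`
+    module beforehand if a reproducible split is required.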
+ """ + + # Extract patient IDs + ad_patient_ids = list(set(os.path.basename(path).split('_')[0] for path in ad_paths)) + nc_patient_ids = list(set(os.path.basename(path).split('_')[0] for path in nc_paths)) + + # Split patient IDs for training and testing + train_ad_ids = random.sample(ad_patient_ids, int(split_ratio * len(ad_patient_ids))) + train_nc_ids = random.sample(nc_patient_ids, int(split_ratio * len(nc_patient_ids))) + + test_ad_ids = list(set(ad_patient_ids) - set(train_ad_ids)) + test_nc_ids = list(set(nc_patient_ids) - set(train_nc_ids)) + + # Get paths based on split IDs + train_ad_paths = [path for path in ad_paths if os.path.basename(path).split('_')[0] in train_ad_ids] + test_ad_paths = [path for path in ad_paths if os.path.basename(path).split('_')[0] in test_ad_ids] + + train_nc_paths = [path for path in nc_paths if os.path.basename(path).split('_')[0] in train_nc_ids] + test_nc_paths = [path for path in nc_paths if os.path.basename(path).split('_')[0] in test_nc_ids] + + return train_ad_paths, test_ad_paths, train_nc_paths, test_nc_paths + diff --git a/recognition/alzheimers_snn_s4647936/modules.py b/recognition/alzheimers_snn_s4647936/modules.py new file mode 100644 index 0000000000..741b0494db --- /dev/null +++ b/recognition/alzheimers_snn_s4647936/modules.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +""" +Feature Extractor +- Sub-network responsible for extracting features from an input image. +- Implemented with simple convolutional neural network (CNN) structure. +""" +class FeatureExtractor(nn.Module): + def __init__(self): + super(FeatureExtractor, self).__init__() + + # Define convolutional layers + self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5) # assuming grayscale images + self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5) + self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5) + + # Define fully connected layers (adjust based on input image size) + self.fc1 = nn.Linear(28*26*128, 512) + self.dropout1 = nn.Dropout(0.5) + self.fc2 = nn.Linear(512, 256) + self.dropout2 = nn.Dropout(0.5) + self.fc3 = nn.Linear(256, 2) + + def forward(self, x): + # Apply the convolutional layers with ReLU and max pooling + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) # kernel size 2 + x = F.max_pool2d(F.relu(self.conv3(x)), 2) + + # Flatten the tensor + x = x.view(x.size(0), 28*26*128) + + # Apply the fully connected layers with ReLU + x = F.relu(self.fc1(x)) + x = self.dropout1(x) + x = F.relu(self.fc2(x)) + x = self.dropout2(x) + x = self.fc3(x) + + return x + +""" +Siamese Network +- Uses two copies of 'FeatureExtractor' to process two images +""" +class SiameseNetwork(nn.Module): + def __init__(self): + super(SiameseNetwork, self).__init__() + + # Use the same feature extractor for both inputs + self.feature_extractor = FeatureExtractor() + + def forward_one(self, x): + # Forward pass for one input + return self.feature_extractor(x) + + def forward(self, input1, input2): + # Forward pass for both inputs + output1 = self.forward_one(input1) + output2 = self.forward_one(input2) + return output1, output2 + +""" +Determine dataset image dimensions +""" +# Path to images +image_path = "/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/AD_NC/train/AD/336537_97.jpeg" # Adjust this path as per your directory structure + +# Open the image and determine its size +image = 
+    image = Image.open(image_path)
+    width, height = image.size
+
+    # print(f"Image dimensions: {width} x {height}") # Result is 256 x 240
+
+"""
+Triplet Loss implementation
+- Beneficial to choose "hard" triplets
+- The anchor and positive would be two different "slices" from the same patient
+- The negative would be a "slice" from a different patient
+"""
+class TripletLoss(nn.Module):
+    def __init__(self, margin=1.0):
+        super(TripletLoss, self).__init__()
+        self.margin = margin
+
+    def forward(self, anchor, positive, negative):
+        distance_positive = (anchor - positive).pow(2).sum(1) # squared Euclidean distance
+        distance_negative = (anchor - negative).pow(2).sum(1) # squared Euclidean distance
+        losses = nn.functional.relu(distance_positive - distance_negative + self.margin)
+        return losses.mean()
+
+class SimpleClassifier(nn.Module):
+    def __init__(self):
+        super(SimpleClassifier, self).__init__()
+        self.fc1 = nn.Linear(2, 64) # Embedding size is 2
+        self.fc2 = nn.Linear(64, 2) # Output is 2 (AD or NC)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
diff --git a/recognition/alzheimers_snn_s4647936/predict.py b/recognition/alzheimers_snn_s4647936/predict.py
new file mode 100644
index 0000000000..0f5e672c74
--- /dev/null
+++ b/recognition/alzheimers_snn_s4647936/predict.py
@@ -0,0 +1,90 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.manifold import TSNE
+from modules import SiameseNetwork, SimpleClassifier
+from torchvision import transforms
+from dataset import TripletDataset
+import seaborn as sns
+from sklearn.metrics import confusion_matrix
+import datetime
+
+# Generate a unique filename with a timestamp
+current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+def get_unique_filename(base_filename):
+    return f"{base_filename}_{current_time}.png"
+
+def plot_confusion_matrix(y_true, y_pred, classes, base_filename):
+    output_filename = get_unique_filename(base_filename)
+    cm = confusion_matrix(y_true, y_pred)
+    fig, ax = plt.subplots(figsize=(5, 5))
+    sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', cbar=False, ax=ax)
+    ax.set_xlabel('Predicted labels')
+    ax.set_ylabel('True labels')
+    ax.set_title('Confusion Matrix')
+    ax.xaxis.set_ticklabels(classes)
+    ax.yaxis.set_ticklabels(classes)
+    plt.savefig(output_filename)
+
+# Transformations for images
+transform = transforms.Compose([
+    transforms.Resize((256, 240)),
+    transforms.ToTensor(),
+])
+
+# Dataset instance for testing
+test_dataset = TripletDataset(root_dir="/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/AD_NC", mode='test', transform=transform)
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
+
+# GPU availability
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the trained Siamese Network
+siamese_model = SiameseNetwork().to(device)
+siamese_model.load_state_dict(torch.load("/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/siamese_model.pth", map_location=device))
+siamese_model.eval()
+
+# Load the trained Simple Classifier
+classifier = SimpleClassifier().to(device)
+classifier.load_state_dict(torch.load("/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/classifier_model.pth", map_location=device))
+classifier.eval()
+
+# Extract embeddings and labels for testing
+test_embeddings = []
+test_labels = []
+
+with torch.no_grad():
+    for anchor, _, _, label in test_loader:
+        anchor = anchor.to(device)
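+        # The twin branches share weights, so passing `anchor` twice yields two
+        # identical embeddings; only the first is kept.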
embedding, _ = siamese_model(anchor, anchor) + test_embeddings.append(embedding.cpu().numpy()) + test_labels.extend(label.tolist()) + +test_embeddings = np.concatenate(test_embeddings) + +# Visualise the embeddings using t-SNE +tsne = TSNE(n_components=2, random_state=42) +embeddings_2d = tsne.fit_transform(test_embeddings) + +plt.figure(figsize=(10, 7)) +plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=test_labels, cmap='jet', alpha=0.5, edgecolors='w', s=40) +plt.colorbar() +plt.title('2D t-SNE of Test Embeddings') +plt.savefig(get_unique_filename('test_embeddings_tsne')) + +# Evaluate classifier on embeddings +test_embeddings_tensor = torch.tensor(test_embeddings).float().to(device) +test_labels_tensor = torch.tensor(test_labels).to(device) + +outputs = classifier(test_embeddings_tensor) +_, predicted = torch.max(outputs, 1) + +# Plot confusion matrix for the classifier +class_names = ["AD", "NC"] +plot_confusion_matrix(test_labels, predicted.cpu().numpy(), class_names, "test_classifier_confusion_matrix") + +correct = (predicted == test_labels_tensor).sum().item() +total = test_labels_tensor.size(0) + +print(f"Accuracy of the classifier on test embeddings: {100 * correct / total}%") diff --git a/recognition/alzheimers_snn_s4647936/readme-images/AD_NC_Scans.png b/recognition/alzheimers_snn_s4647936/readme-images/AD_NC_Scans.png new file mode 100644 index 0000000000..4408ce4a46 Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/AD_NC_Scans.png differ diff --git a/recognition/alzheimers_snn_s4647936/readme-images/classifier_confusion_matrix.png b/recognition/alzheimers_snn_s4647936/readme-images/classifier_confusion_matrix.png new file mode 100644 index 0000000000..16b6c37dd1 Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/classifier_confusion_matrix.png differ diff --git a/recognition/alzheimers_snn_s4647936/readme-images/classifier_embeddings_tsne.png b/recognition/alzheimers_snn_s4647936/readme-images/classifier_embeddings_tsne.png new file mode 100644 index 0000000000..81dec2b1a2 Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/classifier_embeddings_tsne.png differ diff --git a/recognition/alzheimers_snn_s4647936/readme-images/classifier_train_vs_val_loss.png b/recognition/alzheimers_snn_s4647936/readme-images/classifier_train_vs_val_loss.png new file mode 100644 index 0000000000..6a04c91ff3 Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/classifier_train_vs_val_loss.png differ diff --git a/recognition/alzheimers_snn_s4647936/readme-images/embeddings_tsne.png b/recognition/alzheimers_snn_s4647936/readme-images/embeddings_tsne.png new file mode 100644 index 0000000000..36764bbe7b Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/embeddings_tsne.png differ diff --git a/recognition/alzheimers_snn_s4647936/readme-images/siamese_train_vs_val_loss.png b/recognition/alzheimers_snn_s4647936/readme-images/siamese_train_vs_val_loss.png new file mode 100644 index 0000000000..33bd5944ee Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/siamese_train_vs_val_loss.png differ diff --git a/recognition/alzheimers_snn_s4647936/readme-images/triplet_loss.png b/recognition/alzheimers_snn_s4647936/readme-images/triplet_loss.png new file mode 100644 index 0000000000..c7ec72ded3 Binary files /dev/null and b/recognition/alzheimers_snn_s4647936/readme-images/triplet_loss.png differ diff --git 
a/recognition/alzheimers_snn_s4647936/train.py b/recognition/alzheimers_snn_s4647936/train.py new file mode 100644 index 0000000000..cfdb660593 --- /dev/null +++ b/recognition/alzheimers_snn_s4647936/train.py @@ -0,0 +1,327 @@ +import os +import torch +import matplotlib.pyplot as plt +from dataset import TripletDataset +from modules import SiameseNetwork, TripletLoss, SimpleClassifier +from torchvision import transforms +import torch.optim as optim +from sklearn.manifold import TSNE +import numpy as np +import torch.nn as nn +import seaborn as sns +from sklearn.metrics import confusion_matrix +import datetime + +# Generate a unique filename with a timestamp +current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + +def get_unique_filename(base_filename): + return f"{base_filename}_{current_time}.png" + +def plot_confusion_matrix(y_true, y_pred, classes, base_filename): + output_filename = get_unique_filename(base_filename) + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(5, 5)) + sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', cbar=False, ax=ax) + ax.set_xlabel('Predicted labels') + ax.set_ylabel('True labels') + ax.set_title('Confusion Matrix') + ax.xaxis.set_ticklabels(classes) + ax.yaxis.set_ticklabels(classes) + plt.savefig(output_filename) + +# Transformations for images +transform = transforms.Compose([ + transforms.Resize((256, 240)), + transforms.ToTensor(), +]) + +# Dataset instances +train_dataset = TripletDataset(root_dir="/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/AD_NC", mode='train', transform=transform) +test_dataset = TripletDataset(root_dir="/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/AD_NC", mode='test', transform=transform) + +# Parameters +learning_rate = 0.0005 +num_epochs = 30 +batch_size = 16 + +# GPU availability +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if not torch.cuda.is_available(): + print("No CUDA Found. 
Using CPU") + +# Initialise the Siamese Network and Triplet Loss +model = SiameseNetwork().to(device) +criterion = TripletLoss(margin=1.0).to(device) +optimizer = optim.Adam(model.parameters(), lr=learning_rate) + +# DataLoader setup +train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) +test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + +# Lists to store training and validation losses +training_losses = [] +validation_losses = [] + +# Early stopping parameters +patience = 5 +best_val_loss = float('inf') +epochs_without_improvement = 0 + +# Training loop for Siamese Network +for epoch in range(num_epochs): + model.train() + running_loss = 0.0 + + for batch_idx, (anchor, positive, negative, labels) in enumerate(train_loader): + anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device) + + # Zero the parameter gradients + optimizer.zero_grad() + + # Forward pass + anchor_out, positive_out = model(anchor, positive) + _, negative_out = model(anchor, negative) + + loss = criterion(anchor_out, positive_out, negative_out) + loss.backward() + optimizer.step() + + running_loss += loss.item() + + # Calculate average training loss for the epoch + epoch_loss = running_loss / len(train_loader) + training_losses.append(epoch_loss) + + # Validation step + model.eval() # set the model to evaluation mode + val_running_loss = 0.0 + with torch.no_grad(): # deactivate autograd engine to reduce memory usage and speed up computations + for val_anchor, val_positive, val_negative, _ in test_loader: + val_anchor, val_positive, val_negative = val_anchor.to(device), val_positive.to(device), val_negative.to(device) + val_anchor_out, val_positive_out = model(val_anchor, val_positive) + _, val_negative_out = model(val_anchor, val_negative) + val_loss = criterion(val_anchor_out, val_positive_out, val_negative_out) + val_running_loss += val_loss.item() + + # Calculate average validation loss for the epoch + val_epoch_loss = val_running_loss / len(test_loader) + validation_losses.append(val_epoch_loss) + + # Early stopping logic + if val_epoch_loss < best_val_loss: + best_val_loss = val_epoch_loss + epochs_without_improvement = 0 + else: + epochs_without_improvement += 1 + if epochs_without_improvement == patience: + print("Early stopping due to no validation loss improvement.") + break + + print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_epoch_loss:.4f}") + +print("Finished Training Siamese Network") + +# After training, save the Siamese Network model weights +torch.save(model.state_dict(), "/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/siamese_model.pth") +print("Saved Siamese Network model weights") + +""" +Save and visualise results +""" + +# Save the loss curve +plt.figure() +plt.plot(training_losses, label='Training Loss') +plt.plot(validation_losses, label='Validation Loss') +plt.xlabel('Epochs') +plt.ylabel('Loss') +plt.title('Siamese Network Training vs Validation Loss') +plt.legend() +current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') +losses_file = f'siamese_train_vs_val_loss_{current_time}.png' +plt.savefig(losses_file) + +# Function to save images +def save_image(img, base_filename): + # Generate a unique filename with a timestamp + current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + filename = f"{base_filename}_{current_time}.png" + + # Select the first image from the batch + img 
= img[0] + + # Move tensor to CPU and convert to numpy + img = img.cpu().numpy() + + # Transpose from [channels, height, width] to [height, width, channels] + img = img.transpose((1, 2, 0)) + + # Convert to float and normalize if necessary + if img.max() > 1: + img = img.astype(float) / 255 + + plt.figure() + plt.imshow(img) + plt.axis('off') # Hide axes + plt.savefig(filename) + +# Save sample images after training +# save_image(anchor, 'anchor_sample.png') +# save_image(positive, 'positive_sample.png') +# save_image(negative, 'negative_sample.png') + +# --------- Visualize Embeddings using t-SNE --------- +# Extract embeddings and labels +all_embeddings = [] +all_labels = [] + +# Assuming you have two classes: AD and NC. Let's assign them numeric labels. +# AD: 0, NC: 1 +with torch.no_grad(): + for anchor, _, _, label in train_loader: + anchor = anchor.to(device) + embedding, _ = model(anchor, anchor) + all_embeddings.append(embedding.cpu().numpy()) + all_labels.extend(label.tolist()) + +print(f"Number of AD labels: {all_labels.count(0)}") +print(f"Number of NC labels: {all_labels.count(1)}") + +all_embeddings = np.concatenate(all_embeddings) + +# Reduce dimensionality using t-SNE +tsne = TSNE(n_components=2, random_state=42) +embeddings_2d = tsne.fit_transform(all_embeddings) + +# Plot +current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') +plt.figure(figsize=(10, 7)) +plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=all_labels, cmap='jet', alpha=0.5, edgecolors='w', s=40) +plt.colorbar() +plt.title('2D t-SNE of Embeddings') +plt.savefig(f'embeddings_tsne_{current_time}.png') + +# --------- Extract Embeddings for the Entire Dataset --------- +train_embeddings = [] +train_labels = [] +test_embeddings = [] +test_labels = [] + +with torch.no_grad(): + for anchor, _, _, label in train_loader: + anchor = anchor.to(device) + embedding, _ = model(anchor, anchor) + train_embeddings.append(embedding.cpu().numpy()) + train_labels.extend(label.tolist()) + + for anchor, _, _, label in test_loader: + anchor = anchor.to(device) + embedding, _ = model(anchor, anchor) + test_embeddings.append(embedding.cpu().numpy()) + test_labels.extend(label.tolist()) + +train_embeddings = np.concatenate(train_embeddings) +test_embeddings = np.concatenate(test_embeddings) + +# --------- Train Simple Classifier --------- +classifier = SimpleClassifier().to(device) +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(classifier.parameters(), lr=0.001) + +classifier_training_losses = [] +classifier_validation_losses = [] + +# Convert embeddings list to tensor +train_embeddings_tensor = torch.tensor(train_embeddings).float().to(device) +test_embeddings_tensor = torch.tensor(test_embeddings).float().to(device) +train_labels_tensor = torch.tensor(train_labels).to(device) +test_labels_tensor = torch.tensor(test_labels).to(device) + + +for epoch in range(num_epochs): + # Train with embeddings + running_loss = 0.0 + + for embeddings, labels in zip(train_embeddings, train_labels): + embeddings_tensor = torch.tensor(embeddings).float().to(device) + optimizer.zero_grad() + outputs = classifier(embeddings_tensor) + loss = criterion(outputs, torch.tensor(labels).to(device)) + loss.backward() + optimizer.step() + running_loss += loss.item() + + # Average training loss for the epoch + epoch_loss = running_loss / len(train_labels) + classifier_training_losses.append(epoch_loss) + + # Validation loss + val_running_loss = 0.0 + with torch.no_grad(): + for embeddings, labels in zip(test_embeddings, 
test_labels):
+            embeddings_tensor = torch.tensor(embeddings).float().to(device)
+            outputs = classifier(embeddings_tensor)
+            val_loss = criterion(outputs, torch.tensor(labels).to(device))
+            val_running_loss += val_loss.item()
+
+    # Average validation loss for the epoch
+    val_epoch_loss = val_running_loss / len(test_labels)
+    classifier_validation_losses.append(val_epoch_loss)
+
+    print(f"Classifier Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_epoch_loss:.4f}")
+
+
+# After training, save the classifier model weights
+torch.save(classifier.state_dict(), "/home/Student/s4647936/PatternAnalysis-2023/recognition/alzheimers_snn_s4647936/classifier_model.pth")
+print("Saved classifier model weights")
+
+# Plotting the classifier training vs validation losses
+plt.figure()
+plt.plot(classifier_training_losses, label='Classifier Training Loss')
+plt.plot(classifier_validation_losses, label='Classifier Validation Loss')
+plt.xlabel('Epochs')
+plt.ylabel('Loss')
+plt.title('Classifier Training vs Validation Loss')
+plt.legend()
+plt.savefig(f'classifier_train_vs_val_loss_{current_time}.png')
+
+# --------- Extract embeddings for visualization after classifier ---------
+all_classifier_embeddings = []
+all_labels = []
+
+with torch.no_grad():
+    for embeddings, labels in zip(test_embeddings, test_labels):
+        # Convert each test embedding to a tensor before classifying it
+        embeddings_tensor = torch.tensor(embeddings).float().to(device)
+        outputs = classifier(embeddings_tensor)
+        all_classifier_embeddings.append(outputs.cpu().numpy())
+        all_labels.append(labels)
+
+all_classifier_embeddings = np.array(all_classifier_embeddings)
+all_classifier_embeddings = all_classifier_embeddings.reshape(-1, all_classifier_embeddings.shape[-1])
+
+# Reduce dimensionality using t-SNE
+tsne = TSNE(n_components=2, random_state=42)
+embeddings_2d = tsne.fit_transform(all_classifier_embeddings)
+
+# Plot
+plt.figure(figsize=(10, 7))
+plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=all_labels, cmap='jet', alpha=0.5, edgecolors='w', s=40)
+plt.colorbar()
+plt.title('2D t-SNE of Classifier Embeddings')
+plt.savefig(get_unique_filename('classifier_embeddings_tsne'))
+
+# --------- Evaluate Classifier ---------
+test_embeddings_tensor = torch.tensor(test_embeddings).to(device)
+test_labels_tensor = torch.tensor(test_labels).to(device)
+
+outputs = classifier(test_embeddings_tensor)
+_, predicted = torch.max(outputs, 1)
+
+# Plot confusion matrix for the classifier
+class_names = ["AD", "NC"] # AD is labeled as 0 and NC as 1
+plot_confusion_matrix(test_labels, predicted.cpu().numpy(), class_names, "classifier_confusion_matrix")
+
+correct = (predicted == test_labels_tensor).sum().item()
+total = test_labels_tensor.size(0)
+
+print(f"Accuracy of the classifier on test embeddings: {100 * correct / total}%")