This repository contains the PyTorch implementation of the ACB-MSE loss function, short for Automatic Class Balanced Mean Squared Error. It was originally developed for the DEEPCLEAN3D Denoiser to combat class imbalance and to stabilise the loss-gradient fluctuations caused by dramatically varying class frequencies.
Available on PyPI:

```bash
pip install acb_mse
```
- Python 3.x
- PyTorch (tested with version 2.0.1)
- `zero_weighting` (float, optional): Weighting coefficient for the MSE loss of zero-valued pixels. Default is 1.
- `nonzero_weighting` (float, optional): Weighting coefficient for the MSE loss of non-zero pixels. Default is 1.
- Input (torch.Tensor): $(*)$, where $(*)$ means any number of dimensions.
- Target (torch.Tensor): $(*)$, same shape as the input.
- Output (float): Calculated loss value.
```python
import torch
from acb_mse import ACBLoss

# Select weighting for each class if not wanting to use the default 1:1 weighting
zero_weighting = 1.0
nonzero_weighting = 1.2

# Create an instance of the ACB-MSE loss function with the specified weighting coefficients
loss_function = ACBLoss(zero_weighting, nonzero_weighting)

# Dummy target image and reconstructed image tensors (assuming B=10, C=3, H=256, W=256).
# Note: torch.rand produces essentially no exact zeros, so this dummy target
# exercises almost only the non-zero class.
target_image = torch.rand(10, 3, 256, 256)
reconstructed_image = torch.rand(10, 3, 256, 256)

# Calculate the ACB-MSE loss
loss = loss_function(reconstructed_image, target_image)
print("ACB-MSE Loss:", loss)
```
- Two masks are created from the target (label) image:
  - `zero_mask`: a boolean mask whose elements are `True` for zero-valued pixels in the target image.
  - `nonzero_mask`: a boolean mask whose elements are `True` for non-zero-valued pixels in the target image.
- The pixel values of both the target image and the reconstructed image are extracted at the positions selected by each mask.
- The mean squared error loss is calculated between the target and the input for each mask.
- The two loss values are then multiplied by the corresponding weighting coefficients (`zero_weighting` and `nonzero_weighting`), allowing the user to adjust the balance away from the default 1:1.
- The weighted, class-balanced MSE loss is returned as the final value.
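The function relies on knowing the indices of all hits and non-hits in the true label image; the values at those indices are then compared with the values at the same indices in the recovered image. ACB-MSE therefore requires labelled targets and is unsuitable for unsupervised learning tasks. The ACB-MSE loss function is given by:

$$\text{ACB-MSE} = A \cdot \frac{1}{N_0} \sum_{i \,:\, y_i = 0} \left(y_i - \hat{y}_i\right)^2 + B \cdot \frac{1}{N_1} \sum_{i \,:\, y_i \neq 0} \left(y_i - \hat{y}_i\right)^2$$

where $y_i$ is the value of pixel $i$ in the true (target) image, $\hat{y}_i$ is the value of the same pixel in the recovered (input) image, $N_0$ and $N_1$ are the numbers of zero-valued and non-zero-valued pixels in the target image, and $A$ and $B$ are the weighting coefficients `zero_weighting` and `nonzero_weighting`.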
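As a small worked example of the ACB-MSE formula (values chosen purely for illustration), take a four-pixel target $y = (0, 0, 0, 2)$, a recovered image $\hat{y} = (0, 1, 0, 1)$ and equal weights $A = B = 1$. The three zero-valued pixels form one class and the single hit forms the other:

$$
\begin{aligned}
\text{zero-pixel term} &= \tfrac{1}{3}\left[(0-0)^2 + (0-1)^2 + (0-0)^2\right] = \tfrac{1}{3} \\
\text{hit-pixel term} &= \tfrac{1}{1}\,(2-1)^2 = 1 \\
\text{ACB-MSE} &= 1 \cdot \tfrac{1}{3} + 1 \cdot 1 = \tfrac{4}{3} \approx 1.333
\end{aligned}
$$

For comparison, plain MSE over all four pixels is $(0 + 1 + 0 + 1)/4 = 0.5$. In plain MSE the hit pixel's error is divided by all four pixels, whereas in ACB-MSE it is divided only by the single hit pixel, so it carries 3/4 of the total loss rather than 1/2.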
The ACB-MSE loss function was designed for data taken from particle detectors, where the majority of 'pixels' are typically unlit and the lit pixels form a very sparse pattern. In this scenario ACB-MSE provides two main benefits: it addresses the class imbalance between lit and unlit pixels, and it stabilises the loss gradient during training. Additional parameters, 'A' and 'B', are provided to allow the user to set a custom balance between the classes.
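Fluctuations in the number of hit pixels across images during training can disrupt loss stability. ACB-MSE remedies this by dynamically adjusting the loss weights to reflect the class frequencies in each target: because each class's squared error is averaged over that class's own pixel count, a class's contribution to the loss does not grow or shrink simply because it contains more or fewer pixels.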
The above plot demonstrates how each of the loss functions (ACB-MSE, MSE and MAE) behaves as the number of hits in the true signal changes. Two dummy images were created: the first contains a simulated signal, and the second (the recovered image) has 50% of that signal correctly identified, simulating a 50% signal recovery by the network. To generate the plot, the first image was filled in two-pixel increments while the second image followed at a constant 50% recovery, and at each iteration the loss was calculated for the pair of images. The MSE and MAE losses vary as the size of the signal increases, even though the recovery percentage is fixed at 50%, whereas the ACB-MSE loss stays constant regardless of the frequency of the signal class.
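The 50%-recovery experiment behind the plot can be approximated numerically. The sketch below is not the original plotting script: it assumes a flattened 1,000-pixel image (a hypothetical size), unit-valued hits and $A = B = 1$, and records MSE, MAE and ACB-MSE as the signal grows in two-pixel increments with exactly half of it recovered.

```python
# Hypothetical re-creation of the 50%-recovery experiment (not the original script).

def mse(pred, target):
    """Mean squared error over all pixels."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def mae(pred, target):
    """Mean absolute error over all pixels."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def acb_mse(pred, target, a=1.0, b=1.0):
    """Per-class MSE over zero and non-zero target pixels, weighted by a and b."""
    zero_errs = [(p - t) ** 2 for p, t in zip(pred, target) if t == 0]
    hit_errs = [(p - t) ** 2 for p, t in zip(pred, target) if t != 0]
    zero_term = sum(zero_errs) / len(zero_errs) if zero_errs else 0.0
    hit_term = sum(hit_errs) / len(hit_errs) if hit_errs else 0.0
    return a * zero_term + b * hit_term


N_PIXELS = 1000  # hypothetical flattened image size
results = []
for n_hits in range(2, 42, 2):  # grow the signal in two-pixel increments
    target = [1.0] * n_hits + [0.0] * (N_PIXELS - n_hits)
    # Recovered image: the first half of the hits is found, the other half is missed.
    recovered = [1.0] * (n_hits // 2) + [0.0] * (N_PIXELS - n_hits // 2)
    results.append(
        (n_hits, mse(recovered, target), mae(recovered, target), acb_mse(recovered, target))
    )

for n_hits, mse_val, mae_val, acb_val in results:
    print(f"hits={n_hits:3d}  MSE={mse_val:.4f}  MAE={mae_val:.4f}  ACB-MSE={acb_val:.4f}")
```

Because every error here is either 0 or 1, MSE and MAE coincide and both grow linearly with the number of hits, while ACB-MSE stays at 0.5: the missed half of the hits always averages to 0.5 within the hit class, and the correctly zeroed background adds nothing.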
Class imbalance is an issue that can arise when the interesting features are contained in the minority class. In the case of the DEEPCLEAN3D data, the input images contained 11,264 pixels in total, with only around 200 of them being hits. For the network, guessing that all pixels are non-hits (zero-valued) yields a very respectable reconstruction loss and is a simple transfer function to learn; this local minimum proved hard for the network to escape. Class balancing based on class frequency is a simple solution to this problem: it shifts the loss landscape, making it less favorable for the network to guess all pixels as non-hits. This enabled the DEEPCLEAN3D network to escape the local minimum and begin to learn a useful transfer function for the input features.
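To make the local-minimum argument concrete, the arithmetic below compares plain MSE and ACB-MSE for an all-zero prediction using the pixel counts quoted above. It assumes unit-valued hits (the real hit values are not stated) and equal weighting $A = B = 1$, so the numbers are illustrative only.

```python
N_PIXELS = 11_264  # total pixels per DEEPCLEAN3D input image (from the text)
N_HITS = 200       # approximate number of hit pixels (from the text)
HIT_VALUE = 1.0    # hypothetical: hits assumed to be unit-valued

target = [HIT_VALUE] * N_HITS + [0.0] * (N_PIXELS - N_HITS)
all_zero_guess = [0.0] * N_PIXELS

# Plain MSE: the 200 missed hits are diluted by 11,064 perfectly predicted zeros.
plain_mse = sum((p - t) ** 2 for p, t in zip(all_zero_guess, target)) / N_PIXELS

# ACB-MSE with A = B = 1: each class is averaged over its own pixel count.
zero_errs = [(p - t) ** 2 for p, t in zip(all_zero_guess, target) if t == 0]
hit_errs = [(p - t) ** 2 for p, t in zip(all_zero_guess, target) if t != 0]
acb = sum(zero_errs) / len(zero_errs) + sum(hit_errs) / len(hit_errs)

print(f"plain MSE: {plain_mse:.4f}")  # 200 / 11264, about 0.0178
print(f"ACB-MSE:   {acb:.4f}")        # 0 + 1 = 1.0000
```

Under these assumptions the all-zero guess scores about 0.018 with plain MSE but 1.0 with ACB-MSE, the same as missing every hit in an image of any size, so the shortcut no longer looks cheap to the optimiser.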
This project is licensed under the MIT License - see the LICENSE.md file for details.
Contributions to this codebase are welcome! If you encounter any issues or have suggestions for improvements, please open an issue or a pull request on the GitHub repository.
For any inquiries, feel free to reach out to me at [email protected].