From c33a73ae63b7dbf09bb330cd57c37b288cae9098 Mon Sep 17 00:00:00 2001 From: max-mauermann <40059289+max-mauermann@users.noreply.github.com> Date: Fri, 5 Apr 2024 15:16:43 +0200 Subject: [PATCH] mixup only uses positive samples and does not mixup the same samples multiple times (#280) --- utils.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/utils.py b/utils.py index 3d67ee9e..86ffc386 100644 --- a/utils.py +++ b/utils.py @@ -248,18 +248,32 @@ def mixup(x, y, augmentation_ratio=0.25, alpha=0.2): # Set numpy random seed np.random.seed(cfg.RANDOM_SEED) + # Get indices of all positive samples + positive_indices = np.unique(np.where(y[:, :] == 1)[0]) + # Calculate the number of samples to augment based on the ratio - num_samples_to_augment = int(len(x) * augmentation_ratio) + num_samples_to_augment = int(len(positive_indices) * augmentation_ratio) + + # Indices of samples, that are already mixed up + mixed_up_indices = [] for _ in range(num_samples_to_augment): - # Randomly choose one instance from the dataset - index = np.random.choice(len(x)) + + # Randomly choose one instance from the positive samples + index = np.random.choice(positive_indices) + + # Choose another one, when the chosen one was already mixed up + while index in mixed_up_indices: + index = np.random.choice(positive_indices) + x1, y1 = x[index], y[index] # Randomly choose a different instance from the dataset - second_index = np.random.choice(len(x)) - while second_index == index: - second_index = np.random.choice(len(x)) + second_index = np.random.choice(positive_indices) + + # Choose again, when the same or an already mixed up sample was selected + while second_index == index or second_index in mixed_up_indices: + second_index = np.random.choice(positive_indices) x2, y2 = x[second_index], y[second_index] # Generate a random mixing coefficient (lambda) @@ -273,6 +287,9 @@ def mixup(x, y, augmentation_ratio=0.25, alpha=0.2): x[index] = mixed_x y[index] = mixed_y + # Mark the sample as already mixed up + mixed_up_indices.append(index) + del mixed_x del mixed_y