Skip to content

Commit

Permalink
feat: add balanced batch sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 11, 2024
1 parent 4b954bb commit 8d2b048
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions neuralnetlib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,48 @@ def shuffle(x: np.ndarray, y: np.ndarray = None, random_state: int = None) -> tu
return shuffled_x


def balanced_batch_sampling(n_classes: int, real_samples: np.ndarray, labels: np.ndarray, batch_size: int, rng: np.random.Generator):
"""Generates a balanced batch of samples by selecting a fixed number of samples from each class.
Args:
n_classes (int): The number of classes
real_samples (np.ndarray): The real samples
labels (np.ndarray): The labels of the samples in one-hot encoding
batch_size (int): The total number of samples to select
rng (np.random.Generator): The random number generator
Raises:
ValueError: If the batch size is less than the number of classes
Returns:
tuple: A tuple of (real_samples, labels) where each array has the selected samples
"""
samples_per_class = batch_size // n_classes
if samples_per_class == 0:
raise ValueError(f"batch_size ({batch_size}) doit être au moins égal au nombre de classes ({n_classes})")

selected_indices = []

class_indices = [np.nonzero(labels[:, class_idx] == 1)[0] for class_idx in range(n_classes)]

empty_classes = [i for i, indices in enumerate(class_indices) if len(indices) == 0]
if empty_classes:
raise ValueError(f"Les classes {empty_classes} n'ont aucun échantillon dans le dataset")

for class_idx in range(n_classes):
selected_class_indices = rng.choice(
class_indices[class_idx],
size=samples_per_class,
replace=True
)
selected_indices.extend(selected_class_indices)

selected_indices = np.array(selected_indices)
rng.shuffle(selected_indices)

return real_samples[selected_indices], labels[selected_indices]


def progress_bar(current: int, total: int, width: int = 30, message: str = "") -> None:
"""
Prints a progress bar to the console.
Expand Down

0 comments on commit 8d2b048

Please sign in to comment.