Merge pull request #223 from kahst/binary-classification
Binary classification
max-mauermann authored Jan 8, 2024
2 parents ffb5ac6 + 4eca0d3 commit 6ecb086
Showing 4 changed files with 30 additions and 6 deletions.
3 changes: 3 additions & 0 deletions config.py
@@ -163,6 +163,9 @@
# Multiple executions will be averaged, so the evaluation is more consistent
AUTOTUNE_EXECUTIONS_PER_TRIAL: int = 1

# If a binary classification model is trained, this value will be detected automatically in the training script
BINARY_CLASSIFICATION: bool = False

#####################
# Misc runtime vars #
#####################
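For context, the new flag is an ordinary module-level attribute with a False default, so the training script can flip it at runtime once it has inspected the training data (see the train.py hunk below). A minimal sketch of that pattern, using the `import config as cfg` alias visible elsewhere in this diff:

import config as cfg  # module-level settings object, as used in the changed files

# train.py overwrites the default after scanning the label folders:
cfg.BINARY_CLASSIFICATION = True  # e.g. exactly one positive class was found

# downstream code (mixup, label smoothing, metric choice) can then branch on the flag
if cfg.BINARY_CLASSIFICATION:
    print("Training a single-output binary classifier")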
8 changes: 4 additions & 4 deletions model.py
@@ -225,12 +225,12 @@ def on_epoch_end(self, epoch, logs=None):
x_train, y_train = utils.upsampling(x_train, y_train, upsampling_ratio, upsampling_mode)
print(f"Upsampled training data to {x_train.shape[0]} samples.", flush=True)

# Apply mixup to training data
- if train_with_mixup:
+ if train_with_mixup and not cfg.BINARY_CLASSIFICATION:
x_train, y_train = utils.mixup(x_train, y_train)

# Apply label smoothing
- if train_with_label_smoothing:
+ if train_with_label_smoothing and not cfg.BINARY_CLASSIFICATION:
y_train = utils.label_smoothing(y_train)

# Early stopping
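For context, both augmentations rewrite the label matrix, which is presumably why the new cfg.BINARY_CLASSIFICATION guard skips them. A simplified sketch of what utils.mixup and utils.label_smoothing typically do — illustrative stand-ins only, not the exact implementations in utils.py:

import numpy as np

def mixup_sketch(x, y, alpha=0.2):
    # x: (n, features), y: (n, classes) -- blend every sample and its label row
    # with a randomly chosen partner sample.
    lam = np.random.beta(alpha, alpha, size=(x.shape[0], 1))
    idx = np.random.permutation(x.shape[0])
    return lam * x + (1 - lam) * x[idx], lam * y + (1 - lam) * y[idx]

def label_smoothing_sketch(y, smoothing=0.1):
    # Pull hard 0/1 targets slightly toward a uniform distribution over classes.
    return y * (1 - smoothing) + smoothing / y.shape[1]

Both transforms produce soft targets spread over the class dimension, which no longer makes sense when only a single positive class is trained.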
@@ -252,7 +252,7 @@ def on_epoch_end(self, epoch, logs=None):
classifier.compile(
optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
loss=custom_loss,
- metrics=[keras.metrics.AUC(curve="PR", multi_label=False, name="AUPRC")],
+ metrics=[keras.metrics.AUC(curve="PR", multi_label=False, name="AUPRC")], # TODO: Use AUROCC
)

# Train model
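The new TODO on the compile call points at a possible switch from precision-recall AUC to ROC AUC. In Keras both variants come from the same metric class and differ only in the curve argument — a small sketch for reference (the actual switch is not part of this commit):

import keras

# As currently compiled: area under the precision-recall curve.
auprc = keras.metrics.AUC(curve="PR", multi_label=False, name="AUPRC")

# As suggested by the TODO: area under the ROC curve.
auroc = keras.metrics.AUC(curve="ROC", multi_label=False, name="AUROC")

# classifier.compile(optimizer=..., loss=custom_loss, metrics=[auprc, auroc])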
10 changes: 9 additions & 1 deletion train.py
@@ -41,7 +41,15 @@ def _loadTrainingData(cache_mode="none", cache_file=""):
labels = list(sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH)))

# Get valid labels
valid_labels = [l for l in labels if not l.lower() in cfg.NON_EVENT_CLASSES and not l.startswith("-")]

cfg.BINARY_CLASSIFICATION = len(valid_labels) == 1

if cfg.BINARY_CLASSIFICATION:
if len([l for l in labels if l.startswith("-")]) > 0:
raise Exception("negative labels cant be used with binary classification")
if len([l for l in labels if l in cfg.NON_EVENT_CLASSES]) == 0:
raise Exception("non-event samples are required for binary classification")

# Load training data
x_train = []
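To illustrate the new detection logic: binary mode is enabled exactly when the training data contains a single positive class, and the two checks then forbid negative-labelled folders and require a non-event folder. A tiny walk-through with made-up folder names (the contents of cfg.NON_EVENT_CLASSES are assumed here):

# Hypothetical label folders found under cfg.TRAIN_DATA_PATH
labels = ["noise", "Turdus merula_Common Blackbird"]
non_event_classes = ["noise", "background", "silence", "other"]  # assumed stand-in for cfg.NON_EVENT_CLASSES

valid_labels = [l for l in labels if not l.lower() in non_event_classes and not l.startswith("-")]
# valid_labels == ["Turdus merula_Common Blackbird"] -> exactly one positive class
binary_classification = len(valid_labels) == 1  # True, so binary mode is enabled

# A folder like "-Corvus corone_Carrion Crow" would trigger the first exception,
# and removing the "noise" folder would trigger the second one.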
15 changes: 14 additions & 1 deletion utils.py
@@ -98,7 +98,6 @@ def random_split(x, y, val_ratio=0.2):
train_indices = positive_indices[:num_samples_train]
val_indices = positive_indices[num_samples_train:num_samples_train + num_samples_val]


# Append samples to training and validation data
x_train.append(x[train_indices])
y_train.append(y[train_indices])
@@ -109,6 +108,20 @@ def random_split(x, y, val_ratio=0.2):
x_train.append(x[negative_indices])
y_train.append(y[negative_indices])

# Add samples for non-event classes to training and validation data
non_event_indices = np.where(y[:,:] == 0)[0]
num_samples = len(non_event_indices)
num_samples_train = max(1, int(num_samples * (1 - val_ratio)))
num_samples_val = max(0, num_samples - num_samples_train)
np.random.shuffle(non_event_indices)
train_indices = non_event_indices[:num_samples_train]
val_indices = non_event_indices[num_samples_train:num_samples_train + num_samples_val]
x_train.append(x[train_indices])
y_train.append(y[train_indices])
x_val.append(x[val_indices])
y_val.append(y[val_indices])


# Concatenate data
x_train = np.concatenate(x_train)
y_train = np.concatenate(y_train)
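The effect of the added block is that rows whose labels are zero (non-event samples) are now divided between the training and validation sets as well, instead of only the positive and negative rows handled above. A tiny example of the indexing idea, assuming a single-column binary label matrix:

import numpy as np

y = np.array([[1], [0], [1], [0], [0]])          # made-up labels: 1 = target class, 0 = non-event
non_event_indices = np.where(y[:, :] == 0)[0]    # rows containing a zero label -> array([1, 3, 4])
np.random.shuffle(non_event_indices)

val_ratio = 0.2
num_samples_train = max(1, int(len(non_event_indices) * (1 - val_ratio)))
train_indices = non_event_indices[:num_samples_train]   # 2 of the 3 non-event rows
val_indices = non_event_indices[num_samples_train:]     # the remaining non-event row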
