test: add test for class weights (py_dataset adapter) #20638

Merged · 5 commits · Dec 17, 2024
22 changes: 14 additions & 8 deletions keras/src/trainers/data_adapters/data_adapter_utils.py
@@ -1,6 +1,7 @@
 import numpy as np
 
 from keras.src import backend
+from keras.src import ops
 from keras.src import tree
 from keras.src.api_export import keras_export

@@ -115,15 +116,20 @@ def check_data_cardinality(data):


 def class_weight_to_sample_weights(y, class_weight):
-    sample_weight = np.ones(shape=(y.shape[0],), dtype=backend.floatx())
-    if len(y.shape) > 1:
-        if y.shape[-1] != 1:
-            y = np.argmax(y, axis=-1)
+    # Convert to numpy to ensure consistent handling of operations
+    # (e.g., np.round()) across frameworks like TensorFlow, JAX, and PyTorch
+
+    y_numpy = ops.convert_to_numpy(y)
+    sample_weight = np.ones(shape=(y_numpy.shape[0],), dtype=backend.floatx())
+    if len(y_numpy.shape) > 1:
+        if y_numpy.shape[-1] != 1:
+            y_numpy = np.argmax(y_numpy, axis=-1)
         else:
-            y = np.squeeze(y, axis=-1)
-    y = np.round(y).astype("int32")
-    for i in range(y.shape[0]):
-        sample_weight[i] = class_weight.get(int(y[i]), 1.0)
+            y_numpy = np.squeeze(y_numpy, axis=-1)
+    y_numpy = np.round(y_numpy).astype("int32")
+
+    for i in range(y_numpy.shape[0]):
+        sample_weight[i] = class_weight.get(int(y_numpy[i]), 1.0)
     return sample_weight


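For context, a minimal usage sketch of the updated helper. Any backend tensor is accepted, since it is converted up front via ops.convert_to_numpy; the values mirror the new tests below:

    import numpy as np
    from keras.src.trainers.data_adapters.data_adapter_utils import (
        class_weight_to_sample_weights,
    )

    # One-hot labels take the argmax path: classes become [0, 1, 0, 2].
    y = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
    weights = class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0})
    print(weights)  # [1. 2. 1. 3.]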
107 changes: 107 additions & 0 deletions keras/src/trainers/data_adapters/data_adapter_utils_test.py
@@ -0,0 +1,107 @@
import numpy as np
import pytest
from absl.testing import parameterized

from keras.src import backend
from keras.src import testing
from keras.src.trainers.data_adapters.data_adapter_utils import (
    class_weight_to_sample_weights,
)


class TestClassWeightToSampleWeights(testing.TestCase):
    @parameterized.named_parameters(
        [
            # Simple case: y is a flat array of class labels.
            (
                "simple_class_labels",
                np.array([0, 1, 0, 2]),
                {0: 1.0, 1: 2.0, 2: 3.0},
                np.array([1.0, 2.0, 1.0, 3.0]),
            ),
            # One-hot encoded labels, exercising the argmax path.
            (
                "one_hot_encoded_labels",
                np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]),
                {0: 1.0, 1: 2.0, 2: 3.0},
                np.array([1.0, 2.0, 1.0, 3.0]),
            ),
            # Class 3 is not mapped, so it gets the default weight (1.0).
            (
                "unmapped_class",
                np.array([0, 3, 0, 2]),
                {0: 1.0, 1: 2.0, 2: 3.0},
                np.array([1.0, 1.0, 1.0, 3.0]),
            ),
            (
                "multi_dimensional_input",
                np.array([[0], [1], [0], [2]]),
                {0: 1.0, 1: 2.0, 2: 3.0},
                np.array([1.0, 2.0, 1.0, 3.0]),
            ),
            (
                "all_unmapped",
                np.array([0, 1, 0, 2]),
                {},
                np.array([1.0, 1.0, 1.0, 1.0]),
            ),
        ]
    )
    def test_class_weight_to_sample_weights(self, y, class_weight, expected):
        self.assertAllClose(
            class_weight_to_sample_weights(y, class_weight), expected
        )

    @pytest.mark.skipif(backend.backend() != "torch", reason="torch only")
    def test_class_weight_to_sample_weights_torch_specific(self):
        import torch

        y = torch.from_numpy(np.array([0, 1, 0, 2]))
        self.assertAllClose(
            class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),
            np.array([1.0, 2.0, 1.0, 3.0]),
        )
        y_one_hot = torch.from_numpy(
            np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
        )
        self.assertAllClose(
            class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),
            np.array([1.0, 2.0, 1.0, 3.0]),
        )

    @pytest.mark.skipif(backend.backend() != "jax", reason="jax only")
    def test_class_weight_to_sample_weights_jax_specific(self):
        import jax

        y = jax.numpy.asarray(np.array([0, 1, 0, 2]))
        self.assertAllClose(
            class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),
            np.array([1.0, 2.0, 1.0, 3.0]),
        )
        y_one_hot = jax.numpy.asarray(
            np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
        )
        self.assertAllClose(
            class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),
            np.array([1.0, 2.0, 1.0, 3.0]),
        )

    @pytest.mark.skipif(
        backend.backend() != "tensorflow", reason="tensorflow only"
    )
    def test_class_weight_to_sample_weights_tf_specific(self):
        import tensorflow as tf

        y = tf.convert_to_tensor(np.array([0, 1, 0, 2]))
        self.assertAllClose(
            class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),
            np.array([1.0, 2.0, 1.0, 3.0]),
        )
        y_one_hot = tf.convert_to_tensor(
            np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
        )
        self.assertAllClose(
            class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),
            np.array([1.0, 2.0, 1.0, 3.0]),
        )
8 changes: 4 additions & 4 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -236,7 +236,7 @@ def _standardize_batch(self, batch):

     def _infinite_generator(self):
         for i in itertools.count():
-            yield self.py_dataset[i]
+            yield self._standardize_batch(self.py_dataset[i])
 
     def _finite_generator(self):
         indices = range(self.py_dataset.num_batches)
@@ -245,18 +245,18 @@ def _finite_generator(self):
             random.shuffle(indices)
 
         for i in indices:
-            yield self.py_dataset[i]
+            yield self._standardize_batch(self.py_dataset[i])
 
     def _infinite_enqueuer_generator(self):
         self.enqueuer.start()
         for batch in self.enqueuer.get():
-            yield batch
+            yield self._standardize_batch(batch)
 
     def _finite_enqueuer_generator(self):
         self.enqueuer.start()
         num_batches = self.py_dataset.num_batches
         for i, batch in enumerate(self.enqueuer.get()):
-            yield batch
+            yield self._standardize_batch(batch)
             if i >= num_batches - 1:
                 self.enqueuer.stop()
                 return
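The net effect of routing every yield through _standardize_batch is that class weights are applied uniformly across all four generator paths (finite/infinite, with and without the enqueuer). A simplified, hypothetical sketch of the relevant logic — not the actual Keras implementation, which also performs other batch normalization:

    from keras.src.trainers.data_adapters import data_adapter_utils

    def standardize_batch(batch, class_weight=None):
        # When fit() received a class_weight mapping and the batch is
        # (x, y), append per-sample weights derived from that mapping.
        if class_weight is not None and len(batch) == 2:
            x, y = batch
            sample_weight = data_adapter_utils.class_weight_to_sample_weights(
                y, class_weight
            )
            batch = (x, y, sample_weight)
        return batch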
33 changes: 32 additions & 1 deletion keras/src/trainers/data_adapters/py_dataset_adapter_test.py
@@ -217,10 +217,41 @@ def test_basic_flow(
        else:
            self.assertAllClose(sample_order, expected_order)

-        # TODO: test class_weight
        # TODO: test sample weights
        # TODO: test inference mode (single output)

    def test_class_weight(self):
        x = np.random.randint(1, 100, (4, 5))
        y = np.array([0, 1, 2, 1])
        class_w = {0: 2, 1: 1, 2: 3}
        py_dataset = ExamplePyDataset(x, y, batch_size=2)
        adapter = py_dataset_adapter.PyDatasetAdapter(
            py_dataset, shuffle=False, class_weight=class_w
        )
        if backend.backend() == "numpy":
            gen = adapter.get_numpy_iterator()
        elif backend.backend() == "tensorflow":
            gen = adapter.get_tf_dataset()
        elif backend.backend() == "jax":
            gen = adapter.get_jax_iterator()
        elif backend.backend() == "torch":
            gen = adapter.get_torch_dataloader()

        for index, batch in enumerate(gen):
            # Batch is a tuple of (x, y, sample_weight), where the sample
            # weights are derived from the class_weight mapping.
            self.assertLen(batch, 3)
            # Verify that the data and the derived sample weights match
            # for each element of the batch (2 elements per batch).
            for sub_elem in range(2):
                self.assertTrue(
                    np.array_equal(batch[0][sub_elem], x[index * 2 + sub_elem])
                )
                self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem])
                class_key = np.int32(batch[1][sub_elem])
                self.assertEqual(batch[2][sub_elem], class_w[class_key])

        self.assertEqual(index, 1)  # 2 batches

    def test_speedup(self):
        x = np.random.random((40, 4))
        y = np.random.random((40, 2))
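test_class_weight relies on ExamplePyDataset, which is defined earlier in this test file and not shown in the diff. A minimal subclass along these lines would satisfy the test — a sketch of the documented keras.utils.PyDataset pattern, not the file's exact helper:

    import math

    from keras.utils import PyDataset

    class ExamplePyDataset(PyDataset):
        def __init__(self, x, y, batch_size, **kwargs):
            super().__init__(**kwargs)
            self.x, self.y, self.batch_size = x, y, batch_size

        def __len__(self):
            # Number of batches per epoch.
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, index):
            # Return one complete batch of (x, y).
            low = index * self.batch_size
            high = low + self.batch_size
            return self.x[low:high], self.y[low:high]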
26 changes: 26 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -522,17 +522,37 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
            "testcase_name": "py_dataset",
            "dataset_type": "py_dataset",
        },
        {
            "testcase_name": "py_dataset_cw",
            "dataset_type": "py_dataset",
            "fit_kwargs": {"class_weight": {0: 1, 1: 2}},
        },
        {
            "testcase_name": "py_dataset_infinite",
            "dataset_type": "py_dataset",
            "dataset_kwargs": {"infinite": True},
            "fit_kwargs": {"steps_per_epoch": 20},
        },
        {
            "testcase_name": "py_dataset_infinite_cw",
            "dataset_type": "py_dataset",
            "dataset_kwargs": {"infinite": True},
            "fit_kwargs": {
                "steps_per_epoch": 20,
                "class_weight": {0: 1, 1: 2},
            },
        },
        {
            "testcase_name": "py_dataset_multithreading",
            "dataset_type": "py_dataset",
            "dataset_kwargs": {"workers": 2},
        },
        {
            "testcase_name": "py_dataset_multithreading_cw",
            "dataset_type": "py_dataset",
            "dataset_kwargs": {"workers": 2},
            "fit_kwargs": {"class_weight": {0: 1, 1: 2}},
        },
        {
            "testcase_name": "py_dataset_multithreading_infinite",
            "dataset_type": "py_dataset",
@@ -544,6 +564,12 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
            "dataset_type": "py_dataset",
            "dataset_kwargs": {"workers": 2, "use_multiprocessing": True},
        },
        {
            "testcase_name": "py_dataset_multiprocessing_cw",
            "dataset_type": "py_dataset",
            "dataset_kwargs": {"workers": 2, "use_multiprocessing": True},
            "fit_kwargs": {"class_weight": {0: 1, 1: 2}},
        },
        {
            "testcase_name": "py_dataset_multiprocessing_infinite",
            "dataset_type": "py_dataset",
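These parameterized cases exercise class_weight end to end through model.fit across the single-threaded, multithreaded, multiprocessing, and infinite py_dataset paths. For context, a minimal end-user sketch of the kwarg they cover; the model and data here are illustrative assumptions, not taken from the test file:

    import numpy as np
    import keras

    # Toy binary-classification setup (illustrative only).
    x = np.random.random((32, 4))
    y = np.random.randint(0, 2, (32,))

    model = keras.Sequential([keras.layers.Dense(2, activation="softmax")])
    model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy")

    # Errors on class 1 count twice as much as errors on class 0. The data
    # adapter converts this mapping into per-sample weights via
    # class_weight_to_sample_weights; this PR ensures the same happens for
    # PyDataset inputs.
    model.fit(x, y, class_weight={0: 1, 1: 2}, epochs=1)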