Skip to content

Commit

Permalink
Reorganize bounding box utilities and namespaces (#439)
Browse files Browse the repository at this point in the history
* Reorganize bounding box utilities and namespaces

* regorganize bbox api entrypoint
  • Loading branch information
LukeWood authored May 17, 2022
1 parent aa6a747 commit d4d7fb5
Show file tree
Hide file tree
Showing 16 changed files with 123 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def produce_random_data(include_confidence=False, num_images=128, num_classes=20
)

images = [
keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape(
x, [25, images[0].shape[1]]
)
keras_cv.bounding_box.pad_batch_to_shape(x, [25, images[0].shape[1]])
for x in images
]
return tf.stack(images, axis=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def produce_random_data(include_confidence=False, num_images=128, num_classes=20
)

images = [
keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape(
x, [25, images[0].shape[1]]
)
keras_cv.bounding_box.pad_batch_to_shape(x, [25, images[0].shape[1]])
for x in images
]
return tf.stack(images, axis=0)
Expand Down
4 changes: 1 addition & 3 deletions benchmarks/metrics/coco/recall_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def produce_random_data(include_confidence=False, num_images=128, num_classes=20
)

images = [
keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape(
x, [25, images[0].shape[1]]
)
keras_cv.bounding_box.pad_batch_to_shape(x, [25, images[0].shape[1]])
for x in images
]
return tf.stack(images, axis=0)
Expand Down
25 changes: 25 additions & 0 deletions keras_cv/bounding_box/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.bounding_box.convert_to_corners import convert_to_corners
from keras_cv.bounding_box.pad_batch_to_shape import pad_batch_to_shape

# These are the indexes used in Tensors to represent each corresponding side.
LEFT, TOP, RIGHT, BOTTOM = 0, 1, 2, 3

# Regardless of format these constants are consistent.
# Class is held in the 5th index
CLASS = 4
# Confidence exists only on y_pred, and is in the 6th index.
CONFIDENCE = 5
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import tensorflow as tf

from keras_cv.utils import bounding_box
from keras_cv import bounding_box


class BBOXTestCase(tf.test.TestCase):
Expand Down Expand Up @@ -97,25 +97,21 @@ def test_yolo_to_corner(self):
def test_bounding_box_padding(self):
bounding_boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
target_shape = [3, 4]
result = bounding_box.pad_bounding_box_batch_to_shape(
bounding_boxes, target_shape
)
result = bounding_box.pad_batch_to_shape(bounding_boxes, target_shape)
self.assertAllClose(result, [[1, 2, 3, 4], [5, 6, 7, 8], [-1, -1, -1, -1]])

target_shape = [2, 5]
result = bounding_box.pad_bounding_box_batch_to_shape(
bounding_boxes, target_shape
)
result = bounding_box.pad_batch_to_shape(bounding_boxes, target_shape)
self.assertAllClose(result, [[1, 2, 3, 4, -1], [5, 6, 7, 8, -1]])

# Make sure to raise error if the rank is different between bounding_box and
# target shape
with self.assertRaisesRegex(ValueError, "Target shape should have same rank"):
bounding_box.pad_bounding_box_batch_to_shape(bounding_boxes, [1, 2, 3])
bounding_box.pad_batch_to_shape(bounding_boxes, [1, 2, 3])

# Make sure raise error if the target shape is smaller
target_shape = [3, 2]
with self.assertRaisesRegex(
ValueError, "Target shape should be larger than bounding box shape"
):
bounding_box.pad_bounding_box_batch_to_shape(bounding_boxes, target_shape)
bounding_box.pad_batch_to_shape(bounding_boxes, target_shape)
75 changes: 75 additions & 0 deletions keras_cv/bounding_box/convert_to_corners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared utility functions for working with bounding boxes.
Usually bounding boxes is a 2D Tensor with shape [batch, 4]. The second dimension
will contain 4 numbers based on 2 different formats. In KerasCV, we will use the
`corners` format, which is [LEFT, TOP, RIGHT, BOTTOM].
In this file, provide utility functions for manipulating bounding boxes and converting
their formats.
"""

import tensorflow as tf


def convert_to_corners(bounding_boxes, format):
"""Converts bounding_boxes to corners format.
Converts bounding boxes from the provided format to corners format, which is:
`[left, top, right, bottom]`.
args:
format: one of "coco" or "yolo". The formats are as follows-
coco=[x_min, y_min, width, height]
yolo=[x_center, y_center, width, height]
"""
if format == "coco":
return _coco_to_corners(bounding_boxes)
elif format == "yolo":
return _yolo_to_corners(bounding_boxes)
else:
raise ValueError(
"Unsupported format passed to convert_to_corners(). "
f"Want one 'coco' or 'yolo', got format=={format}"
)


def _yolo_to_corners(bounding_boxes):
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
return tf.concat(
[
x - width / 2.0,
y - height / 2.0,
x + width / 2.0,
y + height / 2.0,
rest, # In case there is any more index after the HEIGHT.
],
axis=-1,
)


def _coco_to_corners(bounding_boxes):
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
return tf.concat(
[
x,
y,
x + width,
y + height,
rest, # In case there is any more index after the HEIGHT.
],
axis=-1,
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,80 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared utility functions for working with bounding boxes.
Usually bounding boxes is a 2D Tensor with shape [batch, 4]. The second dimension
will contain 4 numbers based on 2 different formats. In KerasCV, we will use the
`corners` format, which is [LEFT, TOP, RIGHT, BOTTOM].
In this file, provide utility functions for manipulating bounding boxes and converting
their formats.
"""

import tensorflow as tf

# These are the indexes used in Tensors to represent each corresponding side.
LEFT, TOP, RIGHT, BOTTOM = 0, 1, 2, 3

# Regardless of format these constants are consistent.
# Class is held in the 5th index
CLASS = 4
# Confidence exists only on y_pred, and is in the 6th index.
CONFIDENCE = 5


def convert_to_corners(bounding_boxes, format):
"""Converts bounding_boxes to corners format.
Converts bounding boxes from the provided format to corners format, which is:
`[left, top, right, bottom]`.
args:
format: one of "coco" or "yolo". The formats are as follows-
coco=[x_min, y_min, width, height]
yolo=[x_center, y_center, width, height]
"""
if format == "coco":
return _coco_to_corners(bounding_boxes)
elif format == "yolo":
return _yolo_to_corners(bounding_boxes)
else:
raise ValueError(
"Unsupported format passed to convert_to_corners(). "
f"Want one 'coco' or 'yolo', got format=={format}"
)


def _yolo_to_corners(bounding_boxes):
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
return tf.concat(
[
x - width / 2.0,
y - height / 2.0,
x + width / 2.0,
y + height / 2.0,
rest, # In case there is any more index after the HEIGHT.
],
axis=-1,
)


def _coco_to_corners(bounding_boxes):
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
return tf.concat(
[
x,
y,
x + width,
y + height,
rest, # In case there is any more index after the HEIGHT.
],
axis=-1,
)


def pad_bounding_box_batch_to_shape(bounding_boxes, target_shape, padding_values=-1):
def pad_batch_to_shape(bounding_boxes, target_shape, padding_values=-1):
"""Pads a list of bounding boxes with -1s.
Boxes represented by all -1s are ignored by COCO metrics.
Expand All @@ -93,17 +23,17 @@ def pad_bounding_box_batch_to_shape(bounding_boxes, target_shape, padding_values
bounding_box = [[1, 2, 3, 4], [5, 6, 7, 8]] # 2 bounding_boxes with with xywh or
corners format.
target_shape = [3, 4] # Add 1 more dummy bounding_box
result = pad_bounding_box_batch_to_shape(bounding_box, target_shape)
result = pad_batch_to_shape(bounding_box, target_shape)
# result == [[1, 2, 3, 4], [5, 6, 7, 8], [-1, -1, -1, -1]]
target_shape = [2, 5] # Add 1 more index after the current 4 coordinates.
result = pad_bounding_box_batch_to_shape(bounding_box, target_shape)
result = pad_batch_to_shape(bounding_box, target_shape)
# result == [[1, 2, 3, 4, -1], [5, 6, 7, 8, -1]]
Args:
bounding_boxes: tf.Tensor of bounding boxes in any format.
target_shape: Target shape to pad bounding box to. This should have the same
rank as the bbounding_boxs. Note that if the target_shape contains any
rank as the bounding_boxes. Note that if the target_shape contains any
dimension that is smaller than the bounding box shape, then no value will be
padded.
padding_values: value to pad, defaults to -1 to mask out in coco metrics.
Expand Down
4 changes: 2 additions & 2 deletions keras_cv/metrics/coco/mean_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.metrics.coco import utils
from keras_cv.utils import bounding_box
from keras_cv.utils import iou as iou_lib


Expand Down Expand Up @@ -63,7 +63,7 @@ class COCOMeanAveragePrecision(tf.keras.metrics.Metric):
account for this, you may either pass a `tf.RaggedTensor`, or pad Tensors
with `-1`s to indicate unused boxes. A utility function to perform this
padding is available at
`keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape()`.
`keras_cv.bounding_box.pad_batch_to_shape()`.
```python
coco_map = keras_cv.metrics.COCOMeanAveragePrecision(
Expand Down
4 changes: 2 additions & 2 deletions keras_cv/metrics/coco/mean_average_precision_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.metrics import COCOMeanAveragePrecision
from keras_cv.utils import bounding_box as bounding_box_utils


class COCOMeanAveragePrecisionTest(tf.test.TestCase):
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_bounding_box_counting(self):
y_true = tf.constant([[[0, 0, 100, 100, 1]]], dtype=tf.float64)
y_pred = tf.constant([[[0, 50, 100, 150, 1, 1.0]]], dtype=tf.float32)

y_true = bounding_box_utils.pad_bounding_box_batch_to_shape(y_true, (1, 20, 5))
y_true = bounding_box.pad_batch_to_shape(y_true, (1, 20, 5))

metric = COCOMeanAveragePrecision(
iou_thresholds=[0.15],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import numpy as np
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.metrics.coco import COCOMeanAveragePrecision
from keras_cv.utils import bounding_box

SAMPLE_FILE = os.path.dirname(os.path.abspath(__file__)) + "/sample_boxes.npz"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import numpy as np
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.metrics import COCORecall
from keras_cv.utils import bounding_box

SAMPLE_FILE = os.path.dirname(os.path.abspath(__file__)) + "/sample_boxes.npz"

Expand Down
4 changes: 2 additions & 2 deletions keras_cv/metrics/coco/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import tensorflow.keras as keras
import tensorflow.keras.initializers as initializers

from keras_cv import bounding_box
from keras_cv.metrics.coco import utils
from keras_cv.utils import bounding_box
from keras_cv.utils import iou as iou_lib


Expand Down Expand Up @@ -54,7 +54,7 @@ class COCORecall(keras.metrics.Metric):
account for this, you may either pass a `tf.RaggedTensor`, or pad Tensors
with `-1`s to indicate unused boxes. A utility function to perform this
padding is available at
`keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape`.
`keras_cv.bounding_box.pad_batch_to_shape`.
```python
coco_recall = keras_cv.metrics.COCORecall(
Expand Down
2 changes: 1 addition & 1 deletion keras_cv/metrics/coco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Contains shared utilities for Keras COCO metrics."""
import tensorflow as tf

from keras_cv.utils import bounding_box
from keras_cv import bounding_box


def filter_boxes_by_area_range(boxes, min_area, max_area):
Expand Down
2 changes: 1 addition & 1 deletion keras_cv/metrics/coco/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.metrics.coco import utils
from keras_cv.utils import bounding_box
from keras_cv.utils import iou as iou_lib


Expand Down
2 changes: 0 additions & 2 deletions keras_cv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.utils import bounding_box
from keras_cv.utils.bounding_box import pad_bounding_box_batch_to_shape
from keras_cv.utils.fill_utils import fill_rectangle
from keras_cv.utils.iou import compute_ious_for_image
from keras_cv.utils.preprocessing import blend
Expand Down
2 changes: 1 addition & 1 deletion keras_cv/utils/fill_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import tensorflow as tf

from keras_cv.utils import bounding_box
from keras_cv import bounding_box


def _axis_mask(starts, ends, mask_len):
Expand Down

0 comments on commit d4d7fb5

Please sign in to comment.