Skip to content

Commit

Permalink
Move processing bounding boxes to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jul 4, 2024
1 parent b9a9db7 commit c973b8b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 38 deletions.
14 changes: 5 additions & 9 deletions lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do

require Logger

import ExVision.Utils

alias ExVision.Types.BBoxWithMask

@type output_t() :: [BBoxWithMask.t()]
Expand Down Expand Up @@ -46,16 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y]))

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = unbatch(scores)
labels = unbatch(labels)

masks =
masks
Expand Down
30 changes: 10 additions & 20 deletions lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do

require Logger

import ExVision.Utils

alias ExVision.Types.BBoxWithKeypoints

@typep output_t() :: [BBoxWithKeypoints.t()]
Expand Down Expand Up @@ -67,26 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()

keypoints_list =
keypoints_list
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, 1]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list()
bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y]))

scores = unbatch(scores)
labels = unbatch(labels)

keypoints_list = process_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))

keypoints_scores_list = unbatch(keypoints_scores_list)

[bboxes, scores, labels, keypoints_list, keypoints_scores_list]
|> Enum.zip()
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/object_detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule ExVision.ObjectDetection.GenericDetector do

require Logger

import ExVision.Utils

alias ExVision.Types.{BBox, ImageMetadata}

@typep output_t() :: [BBox.t()]
Expand All @@ -29,16 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y]))

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = unbatch(scores)
labels = unbatch(labels)

[bboxes, scores, labels]
|> Enum.zip()
Expand Down
13 changes: 13 additions & 0 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,17 @@ defmodule ExVision.Utils do
def batched_run(process_name, input) do
process_name |> batched_run([input]) |> hd()
end

defp process_bbox(bbox, scales, axes \\ [0]) do
bbox
|> Nx.squeeze(axes: axes)
|> Nx.multiply(scales)
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
end

defp unbatch(batched_value, axes \\ [0]) do
batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list()
end
end

0 comments on commit c973b8b

Please sign in to comment.