Skip to content

Commit

Permalink
move flow normalization to fwd pass
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Nov 12, 2024
1 parent 737489f commit f6c7bba
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
x = self._backbone(x)
heatmaps = tuple(self._post(x))
if self.config.compute_flow:
flow = self._flow(x[0])
flow = F.normalize(self._flow(x[0]), dim=1)
return dict(heatmaps=heatmaps, flow=flow)
else:
return dict(heatmaps=heatmaps)
Expand Down Expand Up @@ -853,15 +853,15 @@ def predict(
if subpix_radius >= 0:
if not self.config.is_3d:
flow = (
F.normalize(out["flow"], dim=1)[0]
out["flow"][0]
.permute(1, 2, 0)
.detach()
.cpu()
.numpy()
) # HW(3*C')
else:
flow = (
F.normalize(out["flow"], dim=1)[0]
out["flow"][0]
.permute(1, 2, 3, 0)
.detach()
.cpu()
Expand Down Expand Up @@ -1049,7 +1049,7 @@ def predict(
(1, 2, 0) if not self.config.is_3d else (1, 2, 3, 0)
)
flow_tile = (
F.normalize(out["flow"], dim=1)[0]
out["flow"][0]
.permute(*permute_dims)
.detach()
.cpu()
Expand Down Expand Up @@ -1306,15 +1306,15 @@ def optimize_threshold(
for flow in out["flow"]:
if not self.config.is_3d:
curr_flow_preds += [
F.normalize(flow, dim=1)
flow
.permute(1, 2, 0)
.detach()
.cpu()
.numpy()
]
else:
curr_flow_preds += [
F.normalize(flow, dim=1)
flow
.permute(1, 2, 3, 0)
.detach()
.cpu()
Expand Down

0 comments on commit f6c7bba

Please sign in to comment.