Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Update optical flow to look at all sat channels
Browse files Browse the repository at this point in the history
Relates to #47
  • Loading branch information
jacobbieker committed Jul 29, 2021
1 parent bb5375b commit a2ed86f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 47 deletions.
82 changes: 42 additions & 40 deletions satflow/baseline/optical_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import cv2
from satflow.data.datasets import OpticalFlowDataset
from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset
import webdataset as wds
import yaml
import torch.nn.functional as F
Expand All @@ -16,7 +16,7 @@ def load_config(config_file):
)
dset = wds.WebDataset("/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar")

dataset = OpticalFlowDataset([dset], config=config)
dataset = SatFlowDataset([dset], config=config)

import matplotlib.pyplot as plt
import torch
Expand All @@ -33,57 +33,59 @@ def warp_flow(img, flow):

debug = False
total_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep
channel_total_losses = np.array([total_losses for _ in range(12)])
count = 0
baseline_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep
overall_loss = 0.0
overall_baseline = 0.0
channel_baseline_losses = np.array([baseline_losses for _ in range(12)])

for data in dataset:
tmp_loss = 0
tmp_base = 0
count += 1
prev_frame, curr_frame, next_frames, image, prev_image = data
prev_frame = np.moveaxis(prev_frame, [0], [2])
curr_frame = np.moveaxis(curr_frame, [0], [2])
flow = cv2.calcOpticalFlowFarneback(prev_image, image, None, 0.5, 3, 15, 3, 5, 1.2, 0)
warped_frame = warp_flow(curr_frame.astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame), torch.from_numpy(np.expand_dims(next_frames[0], axis=-1))
)
total_losses[0] += loss.item()
tmp_loss += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame.astype(np.float32)),
torch.from_numpy(np.expand_dims(next_frames[0], axis=-1)),
)
baseline_losses[0] += loss.item()
tmp_base += loss.item()

for i in range(1, 48):
warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
past_frames, next_frames = data
prev_frame = past_frames[1]
curr_frame = past_frames[0]
# Do it for each of the 12 channels
for ch in range(12):
# prev_frame = np.moveaxis(prev_frame, [0], [2])
# curr_frame = np.moveaxis(curr_frame, [0], [2])
flow = cv2.calcOpticalFlowFarneback(
past_frames[1][ch], past_frames[0][ch], None, 0.5, 3, 15, 3, 5, 1.2, 0
)
warped_frame = warp_flow(curr_frame[ch].astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame),
torch.from_numpy(np.expand_dims(next_frames[i], axis=-1)),
torch.from_numpy(np.expand_dims(next_frames[0][ch], axis=-1)),
)
total_losses[i] += loss.item()
tmp_loss += loss.item()
channel_total_losses[ch][0] += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame.astype(np.float32)),
torch.from_numpy(np.expand_dims(next_frames[i], axis=-1)),
torch.from_numpy(curr_frame[ch].astype(np.float32)),
torch.from_numpy(next_frames[0][ch]),
)
baseline_losses[i] += loss.item()
tmp_base += loss.item()
tmp_base /= 48
tmp_loss /= 48
overall_loss += tmp_loss
overall_baseline += tmp_base
channel_baseline_losses[ch][0] += loss.item()

for i in range(1, 48):
warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame),
torch.from_numpy(np.expand_dims(next_frames[i][ch], axis=-1)),
)
channel_total_losses[ch][i] += loss.item()
tmp_loss += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame[ch].astype(np.float32)),
torch.from_numpy(next_frames[i][ch]),
)
channel_baseline_losses[ch][i] += loss.item()
print(
f"Avg Total Loss: {np.mean(total_losses) / count} Avg Baseline Loss: {np.mean(baseline_losses) / count} \n Overall Loss: {overall_loss / count} Baseline: {overall_baseline / count}"
f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}"
)
if count % 100 == 0:
np.save("optical_flow_mse_loss.npy", total_losses / count)
np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)
np.save("optical_flow_mse_loss.npy", total_losses / count)
np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)
np.save("optical_flow_mse_loss_channels_reverse.npy", channel_total_losses / count)
np.save(
"baseline_current_image_mse_loss_channels_reverse.npy", channel_baseline_losses / count
)
np.save("optical_flow_mse_loss_reverse.npy", channel_total_losses / count)
np.save("baseline_current_image_mse_loss_reverse.npy", channel_baseline_losses / count)
21 changes: 15 additions & 6 deletions satflow/configs/datamodule/optical_flow_datamodule.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,24 @@ config:
use_latlon: False
use_time: False
time_aux: False
use_mask: True
use_image: False
use_mask: False
use_image: True
add_pixel_coords: False
time_as_channels: False
# NIR1.6, VIS0.8 and VIS0.6 RGB for near normal view
bands: [
bands:
[
"HRV",
#"IR016",
#"VIS006",
#"VIS008",
"IR016",
"IR039",
"IR087",
"IR097",
"IR108",
"IR120",
"IR134",
"VIS006",
"VIS008",
"WV062",
"WV073",
]
transforms: {}
2 changes: 1 addition & 1 deletion satflow/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def __iter__(self):
if not self.use_image:
yield image, target_mask
else:
yield image, target_image, target_mask
yield image, target_image

def get_topo_latlon(self, sample: dict) -> None:
if self.use_topo:
Expand Down

0 comments on commit a2ed86f

Please sign in to comment.