Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Add more Optical Flow baselines #71

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/

.idea/
satflow/logs/
*.png
# C extensions
*.so

Expand Down
45 changes: 28 additions & 17 deletions satflow/baseline/optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,40 @@ def warp_flow(img, flow):
channel_baseline_losses = np.array([baseline_losses for _ in range(12)])

for data in dataset:
tmp_loss = 0
tmp_base = 0
count += 1
past_frames, next_frames = data
prev_frame = past_frames[1]
curr_frame = past_frames[0]
# 6 past frames
# Get average of the past ones pairs
past_frames = past_frames.astype(np.float32)
next_frames = next_frames.astype(np.float32)
prev_frame = past_frames[0]
curr_frame = past_frames[1]
# Do it for each of the 12 channels
for ch in range(12):
# prev_frame = np.moveaxis(prev_frame, [0], [2])
# curr_frame = np.moveaxis(curr_frame, [0], [2])
flow = cv2.calcOpticalFlowFarneback(
past_frames[1][ch], past_frames[0][ch], None, 0.5, 3, 15, 3, 5, 1.2, 0
past_frames[-2][ch].astype(np.float32),
past_frames[-1][ch].astype(np.float32),
None,
0.5,
3,
15,
3,
5,
1.2,
0,
)
warped_frame = warp_flow(curr_frame[ch].astype(np.float32), flow)
warped_frame = warp_flow(past_frames[-1][ch].astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame),
torch.from_numpy(np.expand_dims(next_frames[0][ch], axis=-1)),
torch.from_numpy(np.expand_dims(next_frames[0][ch].astype(np.float32), axis=-1)),
)
channel_total_losses[ch][0] += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame[ch].astype(np.float32)),
torch.from_numpy(next_frames[0][ch]),
torch.from_numpy(past_frames[-1][ch].astype(np.float32)),
torch.from_numpy(next_frames[0][ch].astype(np.float32)),
)
channel_baseline_losses[ch][0] += loss.item()

Expand All @@ -70,22 +81,22 @@ def warp_flow(img, flow):
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame),
torch.from_numpy(np.expand_dims(next_frames[i][ch], axis=-1)),
torch.from_numpy(np.expand_dims(next_frames[i][ch].astype(np.float32), axis=-1)),
)
channel_total_losses[ch][i] += loss.item()
tmp_loss += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame[ch].astype(np.float32)),
torch.from_numpy(next_frames[i][ch]),
torch.from_numpy(past_frames[-1][ch].astype(np.float32)),
torch.from_numpy(next_frames[i][ch].astype(np.float32)),
)
channel_baseline_losses[ch][i] += loss.item()
print(
f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}"
)
if count % 100 == 0:
np.save("optical_flow_mse_loss_channels_reverse.npy", channel_total_losses / count)
np.save("optical_flow_mse_loss_channels_two_frames.npy", channel_total_losses / count)
np.save(
"baseline_current_image_mse_loss_channels_reverse.npy", channel_baseline_losses / count
"baseline_current_image_mse_loss_channels_two_frames.npy",
channel_baseline_losses / count,
)
np.save("optical_flow_mse_loss_reverse.npy", channel_total_losses / count)
np.save("baseline_current_image_mse_loss_reverse.npy", channel_baseline_losses / count)
np.save("optical_flow_mse_loss_two_frames.npy", channel_total_losses / count)
np.save("baseline_current_image_mse_loss_two_frames.npy", channel_baseline_losses / count)
175 changes: 175 additions & 0 deletions satflow/baseline/optical_flow_avg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import cv2
from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset
import webdataset as wds
import yaml
import torch.nn.functional as F
import numpy as np
from satflow.models.losses import SSIMLoss


def load_config(config_file):
with open(config_file, "r") as cfg:
return yaml.load(cfg, Loader=yaml.FullLoader)["config"]


used_loss = SSIMLoss(data_range=1.0, nonnegative_ssim=True, channel=3)

config = load_config(
"/home/jacob/Development/satflow/satflow/configs/datamodule/optical_flow_datamodule.yaml"
)
dset = wds.WebDataset("/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar")

dataset = SatFlowDataset([dset], config=config)
import matplotlib.pyplot as plt
import torch


def warp_flow(img, flow):
h, w = flow.shape[:2]
flow = -flow
flow[:, :, 0] += np.arange(w)
flow[:, :, 1] += np.arange(h)[:, np.newaxis]
res = cv2.remap(img, flow, None, cv2.INTER_LINEAR)
return res


debug = False
total_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep
channel_total_losses = np.array([total_losses for _ in range(12)])
count = 0
baseline_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep
channel_baseline_losses = np.array([baseline_losses for _ in range(12)])

for data in dataset:
count += 1
past_frames, next_frames = data
# 6 past frames
# Get average of the past ones pairs
past_frames = past_frames.astype(np.float32)
next_frames = next_frames.astype(np.float32)
prev_frame = past_frames[0]
curr_frame = past_frames[1]
# Do it for each of the 12 channels
for ch in range(12):
flow = cv2.calcOpticalFlowFarneback(
past_frames[0][ch].astype(np.float32),
past_frames[1][ch].astype(np.float32),
None,
0.5,
3,
15,
3,
5,
1.2,
0,
)
flow += cv2.calcOpticalFlowFarneback(
past_frames[1][ch].astype(np.float32),
past_frames[2][ch].astype(np.float32),
None,
0.5,
3,
15,
3,
5,
1.2,
0,
)
flow += cv2.calcOpticalFlowFarneback(
past_frames[2][ch].astype(np.float32),
past_frames[3][ch].astype(np.float32),
None,
0.5,
3,
15,
3,
5,
1.2,
0,
)
flow += cv2.calcOpticalFlowFarneback(
past_frames[3][ch].astype(np.float32),
past_frames[4][ch].astype(np.float32),
None,
0.5,
3,
15,
3,
5,
1.2,
0,
)
flow += cv2.calcOpticalFlowFarneback(
past_frames[4][ch].astype(np.float32),
past_frames[5][ch].astype(np.float32),
None,
0.5,
3,
15,
3,
5,
1.2,
0,
)
flow /= 5
warped_frame = warp_flow(past_frames[5][ch].astype(np.float32), flow)
warped_frame2 = np.expand_dims(warped_frame, axis=0)
warped_frame2 = np.stack((warped_frame2,) * 3, axis=1)
next_frame = np.expand_dims(next_frames[0][ch].astype(np.float32), axis=0)
next_frame = np.stack((next_frame,) * 3, axis=1)
# Force between 0 and 1
next_frame = (next_frame - np.min(next_frame)) / (np.max(next_frame) - np.min(next_frame))
warped_frame2 = (warped_frame2 - np.min(warped_frame2)) / (
np.max(warped_frame2) - np.min(warped_frame2)
)
loss = used_loss(
torch.from_numpy(warped_frame2),
torch.from_numpy(next_frame),
)
channel_total_losses[ch][0] += loss.item()
current_frame = np.expand_dims(past_frames[-1][ch].astype(np.float32), axis=0)
current_frame = np.stack((current_frame,) * 3, axis=1)
current_frame = (current_frame - np.min(current_frame)) / (
np.max(current_frame) - np.min(current_frame)
)
loss = used_loss(
torch.from_numpy(current_frame),
torch.from_numpy(next_frame),
)
channel_baseline_losses[ch][0] += loss.item()
warped_frame = np.expand_dims(warped_frame, axis=-1)
for i in range(1, 48):
warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
warped_frame2 = np.expand_dims(warped_frame, axis=0)
warped_frame2 = np.stack((warped_frame2,) * 3, axis=1)
next_frame = np.expand_dims(next_frames[0][ch].astype(np.float32), axis=0)
next_frame = np.stack((next_frame,) * 3, axis=1)
# Force between 0 and 1
next_frame = (next_frame - np.min(next_frame)) / (
np.max(next_frame) - np.min(next_frame)
)
warped_frame2 = (warped_frame2 - np.min(warped_frame2)) / (
np.max(warped_frame2) - np.min(warped_frame2)
)
loss = used_loss(
torch.from_numpy(warped_frame2),
torch.from_numpy(next_frame),
)
channel_total_losses[ch][i] += loss.item()
loss = used_loss(
torch.from_numpy(current_frame),
torch.from_numpy(next_frame),
)
channel_baseline_losses[ch][i] += loss.item()
warped_frame = np.expand_dims(warped_frame, axis=-1)
print(
f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}"
)
if count % 10 == 0:
np.save("optical_flow_ssim_loss_avg_pairs.npy", channel_total_losses / count)
np.save(
"baseline_ssim_loss_channels_avg_pairs.npy",
channel_baseline_losses / count,
)
# np.save("optical_flow_mse_loss_avg_pairs.npy", channel_total_losses / count)
# np.save("baseline_current_image_mse_loss_avg_pairs.npy", channel_baseline_losses / count)
90 changes: 90 additions & 0 deletions satflow/baseline/plot_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as np
import matplotlib.pyplot as plt

bands = [
"HRV",
"IR016",
"IR039",
"IR087",
"IR097",
"IR108",
"IR120",
"IR134",
"VIS006",
"VIS008",
"WV062",
"WV073",
]

baseline = np.load("baseline_current_image_mse_loss_channels_avg_base_only.npy")
flow = np.load("optical_flow_mse_loss_channels_avg_pairs.npy")
ssim = np.load("optical_flow_ssim_loss_channels_two_frames.npy")
ssim_base = np.load("baseline_current_image_ssim_loss_channels_baseline.npy")


# Now slice it up by channel and by timestep
per_channel_base = np.mean(baseline, axis=1)
per_channel_total = np.mean(flow, axis=1)
plt.plot(per_channel_base, label="Current Image")
plt.plot(per_channel_total, label="Optical Flow")
plt.legend(loc="best")
plt.xlabel("Channel #")
plt.ylabel("MSE")
plt.title("Mean MSE per channel")
plt.savefig("mse_per_channel.png", dpi=300)
plt.show()
per_channel_base = 1.0 - np.mean(ssim_base, axis=1)
per_channel_total = 1.0 - np.mean(ssim, axis=1)
plt.plot(per_channel_base, label="Current Image")
plt.plot(per_channel_total, label="Optical Flow")
plt.legend(loc="best")
plt.xlabel("Channel #")
plt.ylabel("SSIM")
plt.title("Mean SSIM per channel")
plt.savefig("ssim_per_channel.png", dpi=300)
plt.show()

per_channel_base = np.mean(baseline, axis=0)
per_channel_total = np.mean(flow, axis=0)
plt.plot(per_channel_base, label="Current Image")
plt.plot(per_channel_total, label="Optical Flow")
plt.legend(loc="best")
plt.xlabel("Timestep")
plt.ylabel("MSE")
plt.title("Mean MSE per timestep")
plt.savefig("mse_per_timestep.png", dpi=300)
plt.show()
per_channel_base = 1.0 - np.mean(ssim_base, axis=0)
per_channel_total = 1.0 - np.mean(ssim, axis=0)
plt.plot(per_channel_base, label="Current Image")
plt.plot(per_channel_total, label="Optical Flow")
plt.legend(loc="best")
plt.xlabel("Timestep")
plt.ylabel("SSIM")
plt.title("Mean SSIM per timestep")
plt.savefig("ssim_per_timestep.png", dpi=300)
plt.show()

for i in range(12):
per_channel_base = baseline[i]
per_channel_total = flow[i]
plt.plot(per_channel_base, label="Current Image")
plt.plot(per_channel_total, label="Optical Flow")
plt.legend(loc="best")
plt.xlabel("Timestep")
plt.ylabel("MSE")
plt.title(f"Mean MSE per timestep ({bands[i]})")
plt.savefig(f"mse_{bands[i]}.png", dpi=300)
plt.show()

for i in range(12):
per_channel_base = 1.0 - ssim_base[i]
per_channel_total = 1.0 - ssim[i]
plt.plot(per_channel_base, label="Current Image")
plt.plot(per_channel_total, label="Optical Flow")
plt.legend(loc="best")
plt.xlabel("Timestep")
plt.ylabel("SSIM")
plt.title(f"Mean SSIM per timestep ({bands[i]})")
plt.savefig(f"ssim_{bands[i]}.png", dpi=300)
plt.show()
2 changes: 1 addition & 1 deletion satflow/configs/datamodule/optical_flow_datamodule.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ num_workers: 1
pin_memory: False
config:
visualize: False
num_timesteps: 1
num_timesteps: 5
skip_timesteps: 1
forecast_times: 48
output_shape: 400
Expand Down