diff --git a/.gitignore b/.gitignore index 070b2f61..d3f359f9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ .idea/ satflow/logs/ +*.png # C extensions *.so diff --git a/satflow/baseline/optical_flow.py b/satflow/baseline/optical_flow.py index 76254d8c..2c3c0993 100644 --- a/satflow/baseline/optical_flow.py +++ b/satflow/baseline/optical_flow.py @@ -39,29 +39,40 @@ def warp_flow(img, flow): channel_baseline_losses = np.array([baseline_losses for _ in range(12)]) for data in dataset: - tmp_loss = 0 - tmp_base = 0 count += 1 past_frames, next_frames = data - prev_frame = past_frames[1] - curr_frame = past_frames[0] + # 6 past frames + # Get average of the past ones pairs + past_frames = past_frames.astype(np.float32) + next_frames = next_frames.astype(np.float32) + prev_frame = past_frames[0] + curr_frame = past_frames[1] # Do it for each of the 12 channels for ch in range(12): # prev_frame = np.moveaxis(prev_frame, [0], [2]) # curr_frame = np.moveaxis(curr_frame, [0], [2]) flow = cv2.calcOpticalFlowFarneback( - past_frames[1][ch], past_frames[0][ch], None, 0.5, 3, 15, 3, 5, 1.2, 0 + past_frames[-2][ch].astype(np.float32), + past_frames[-1][ch].astype(np.float32), + None, + 0.5, + 3, + 15, + 3, + 5, + 1.2, + 0, ) - warped_frame = warp_flow(curr_frame[ch].astype(np.float32), flow) + warped_frame = warp_flow(past_frames[-1][ch].astype(np.float32), flow) warped_frame = np.expand_dims(warped_frame, axis=-1) loss = F.mse_loss( torch.from_numpy(warped_frame), - torch.from_numpy(np.expand_dims(next_frames[0][ch], axis=-1)), + torch.from_numpy(np.expand_dims(next_frames[0][ch].astype(np.float32), axis=-1)), ) channel_total_losses[ch][0] += loss.item() loss = F.mse_loss( - torch.from_numpy(curr_frame[ch].astype(np.float32)), - torch.from_numpy(next_frames[0][ch]), + torch.from_numpy(past_frames[-1][ch].astype(np.float32)), + torch.from_numpy(next_frames[0][ch].astype(np.float32)), ) channel_baseline_losses[ch][0] += loss.item() @@ -70,22 +81,22 @@ def warp_flow(img, flow): warped_frame = np.expand_dims(warped_frame, axis=-1) loss = F.mse_loss( torch.from_numpy(warped_frame), - torch.from_numpy(np.expand_dims(next_frames[i][ch], axis=-1)), + torch.from_numpy(np.expand_dims(next_frames[i][ch].astype(np.float32), axis=-1)), ) channel_total_losses[ch][i] += loss.item() - tmp_loss += loss.item() loss = F.mse_loss( - torch.from_numpy(curr_frame[ch].astype(np.float32)), - torch.from_numpy(next_frames[i][ch]), + torch.from_numpy(past_frames[-1][ch].astype(np.float32)), + torch.from_numpy(next_frames[i][ch].astype(np.float32)), ) channel_baseline_losses[ch][i] += loss.item() print( f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}" ) if count % 100 == 0: - np.save("optical_flow_mse_loss_channels_reverse.npy", channel_total_losses / count) + np.save("optical_flow_mse_loss_channels_two_frames.npy", channel_total_losses / count) np.save( - "baseline_current_image_mse_loss_channels_reverse.npy", channel_baseline_losses / count + "baseline_current_image_mse_loss_channels_two_frames.npy", + channel_baseline_losses / count, ) -np.save("optical_flow_mse_loss_reverse.npy", channel_total_losses / count) -np.save("baseline_current_image_mse_loss_reverse.npy", channel_baseline_losses / count) +np.save("optical_flow_mse_loss_two_frames.npy", channel_total_losses / count) +np.save("baseline_current_image_mse_loss_two_frames.npy", channel_baseline_losses / count) diff --git a/satflow/baseline/optical_flow_avg.py b/satflow/baseline/optical_flow_avg.py new file mode 100644 index 00000000..9ca80e6b --- /dev/null +++ b/satflow/baseline/optical_flow_avg.py @@ -0,0 +1,175 @@ +import cv2 +from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset +import webdataset as wds +import yaml +import torch.nn.functional as F +import numpy as np +from satflow.models.losses import SSIMLoss + + +def load_config(config_file): + with open(config_file, "r") as cfg: + return yaml.load(cfg, Loader=yaml.FullLoader)["config"] + + +used_loss = SSIMLoss(data_range=1.0, nonnegative_ssim=True, channel=3) + +config = load_config( + "/home/jacob/Development/satflow/satflow/configs/datamodule/optical_flow_datamodule.yaml" +) +dset = wds.WebDataset("/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar") + +dataset = SatFlowDataset([dset], config=config) +import matplotlib.pyplot as plt +import torch + + +def warp_flow(img, flow): + h, w = flow.shape[:2] + flow = -flow + flow[:, :, 0] += np.arange(w) + flow[:, :, 1] += np.arange(h)[:, np.newaxis] + res = cv2.remap(img, flow, None, cv2.INTER_LINEAR) + return res + + +debug = False +total_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep +channel_total_losses = np.array([total_losses for _ in range(12)]) +count = 0 +baseline_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep +channel_baseline_losses = np.array([baseline_losses for _ in range(12)]) + +for data in dataset: + count += 1 + past_frames, next_frames = data + # 6 past frames + # Get average of the past ones pairs + past_frames = past_frames.astype(np.float32) + next_frames = next_frames.astype(np.float32) + prev_frame = past_frames[0] + curr_frame = past_frames[1] + # Do it for each of the 12 channels + for ch in range(12): + flow = cv2.calcOpticalFlowFarneback( + past_frames[0][ch].astype(np.float32), + past_frames[1][ch].astype(np.float32), + None, + 0.5, + 3, + 15, + 3, + 5, + 1.2, + 0, + ) + flow += cv2.calcOpticalFlowFarneback( + past_frames[1][ch].astype(np.float32), + past_frames[2][ch].astype(np.float32), + None, + 0.5, + 3, + 15, + 3, + 5, + 1.2, + 0, + ) + flow += cv2.calcOpticalFlowFarneback( + past_frames[2][ch].astype(np.float32), + past_frames[3][ch].astype(np.float32), + None, + 0.5, + 3, + 15, + 3, + 5, + 1.2, + 0, + ) + flow += cv2.calcOpticalFlowFarneback( + past_frames[3][ch].astype(np.float32), + past_frames[4][ch].astype(np.float32), + None, + 0.5, + 3, + 15, + 3, + 5, + 1.2, + 0, + ) + flow += cv2.calcOpticalFlowFarneback( + past_frames[4][ch].astype(np.float32), + past_frames[5][ch].astype(np.float32), + None, + 0.5, + 3, + 15, + 3, + 5, + 1.2, + 0, + ) + flow /= 5 + warped_frame = warp_flow(past_frames[5][ch].astype(np.float32), flow) + warped_frame2 = np.expand_dims(warped_frame, axis=0) + warped_frame2 = np.stack((warped_frame2,) * 3, axis=1) + next_frame = np.expand_dims(next_frames[0][ch].astype(np.float32), axis=0) + next_frame = np.stack((next_frame,) * 3, axis=1) + # Force between 0 and 1 + next_frame = (next_frame - np.min(next_frame)) / (np.max(next_frame) - np.min(next_frame)) + warped_frame2 = (warped_frame2 - np.min(warped_frame2)) / ( + np.max(warped_frame2) - np.min(warped_frame2) + ) + loss = used_loss( + torch.from_numpy(warped_frame2), + torch.from_numpy(next_frame), + ) + channel_total_losses[ch][0] += loss.item() + current_frame = np.expand_dims(past_frames[-1][ch].astype(np.float32), axis=0) + current_frame = np.stack((current_frame,) * 3, axis=1) + current_frame = (current_frame - np.min(current_frame)) / ( + np.max(current_frame) - np.min(current_frame) + ) + loss = used_loss( + torch.from_numpy(current_frame), + torch.from_numpy(next_frame), + ) + channel_baseline_losses[ch][0] += loss.item() + warped_frame = np.expand_dims(warped_frame, axis=-1) + for i in range(1, 48): + warped_frame = warp_flow(warped_frame.astype(np.float32), flow) + warped_frame2 = np.expand_dims(warped_frame, axis=0) + warped_frame2 = np.stack((warped_frame2,) * 3, axis=1) + next_frame = np.expand_dims(next_frames[0][ch].astype(np.float32), axis=0) + next_frame = np.stack((next_frame,) * 3, axis=1) + # Force between 0 and 1 + next_frame = (next_frame - np.min(next_frame)) / ( + np.max(next_frame) - np.min(next_frame) + ) + warped_frame2 = (warped_frame2 - np.min(warped_frame2)) / ( + np.max(warped_frame2) - np.min(warped_frame2) + ) + loss = used_loss( + torch.from_numpy(warped_frame2), + torch.from_numpy(next_frame), + ) + channel_total_losses[ch][i] += loss.item() + loss = used_loss( + torch.from_numpy(current_frame), + torch.from_numpy(next_frame), + ) + channel_baseline_losses[ch][i] += loss.item() + warped_frame = np.expand_dims(warped_frame, axis=-1) + print( + f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}" + ) + if count % 10 == 0: + np.save("optical_flow_ssim_loss_avg_pairs.npy", channel_total_losses / count) + np.save( + "baseline_ssim_loss_channels_avg_pairs.npy", + channel_baseline_losses / count, + ) +# np.save("optical_flow_mse_loss_avg_pairs.npy", channel_total_losses / count) +# np.save("baseline_current_image_mse_loss_avg_pairs.npy", channel_baseline_losses / count) diff --git a/satflow/baseline/plot_flow.py b/satflow/baseline/plot_flow.py new file mode 100644 index 00000000..2aee1a87 --- /dev/null +++ b/satflow/baseline/plot_flow.py @@ -0,0 +1,90 @@ +import numpy as np +import matplotlib.pyplot as plt + +bands = [ + "HRV", + "IR016", + "IR039", + "IR087", + "IR097", + "IR108", + "IR120", + "IR134", + "VIS006", + "VIS008", + "WV062", + "WV073", +] + +baseline = np.load("baseline_current_image_mse_loss_channels_avg_base_only.npy") +flow = np.load("optical_flow_mse_loss_channels_avg_pairs.npy") +ssim = np.load("optical_flow_ssim_loss_channels_two_frames.npy") +ssim_base = np.load("baseline_current_image_ssim_loss_channels_baseline.npy") + + +# Now slice it up by channel and by timestep +per_channel_base = np.mean(baseline, axis=1) +per_channel_total = np.mean(flow, axis=1) +plt.plot(per_channel_base, label="Current Image") +plt.plot(per_channel_total, label="Optical Flow") +plt.legend(loc="best") +plt.xlabel("Channel #") +plt.ylabel("MSE") +plt.title("Mean MSE per channel") +plt.savefig("mse_per_channel.png", dpi=300) +plt.show() +per_channel_base = 1.0 - np.mean(ssim_base, axis=1) +per_channel_total = 1.0 - np.mean(ssim, axis=1) +plt.plot(per_channel_base, label="Current Image") +plt.plot(per_channel_total, label="Optical Flow") +plt.legend(loc="best") +plt.xlabel("Channel #") +plt.ylabel("SSIM") +plt.title("Mean SSIM per channel") +plt.savefig("ssim_per_channel.png", dpi=300) +plt.show() + +per_channel_base = np.mean(baseline, axis=0) +per_channel_total = np.mean(flow, axis=0) +plt.plot(per_channel_base, label="Current Image") +plt.plot(per_channel_total, label="Optical Flow") +plt.legend(loc="best") +plt.xlabel("Timestep") +plt.ylabel("MSE") +plt.title("Mean MSE per timestep") +plt.savefig("mse_per_timestep.png", dpi=300) +plt.show() +per_channel_base = 1.0 - np.mean(ssim_base, axis=0) +per_channel_total = 1.0 - np.mean(ssim, axis=0) +plt.plot(per_channel_base, label="Current Image") +plt.plot(per_channel_total, label="Optical Flow") +plt.legend(loc="best") +plt.xlabel("Timestep") +plt.ylabel("SSIM") +plt.title("Mean SSIM per timestep") +plt.savefig("ssim_per_timestep.png", dpi=300) +plt.show() + +for i in range(12): + per_channel_base = baseline[i] + per_channel_total = flow[i] + plt.plot(per_channel_base, label="Current Image") + plt.plot(per_channel_total, label="Optical Flow") + plt.legend(loc="best") + plt.xlabel("Timestep") + plt.ylabel("MSE") + plt.title(f"Mean MSE per timestep ({bands[i]})") + plt.savefig(f"mse_{bands[i]}.png", dpi=300) + plt.show() + +for i in range(12): + per_channel_base = 1.0 - ssim_base[i] + per_channel_total = 1.0 - ssim[i] + plt.plot(per_channel_base, label="Current Image") + plt.plot(per_channel_total, label="Optical Flow") + plt.legend(loc="best") + plt.xlabel("Timestep") + plt.ylabel("SSIM") + plt.title(f"Mean SSIM per timestep ({bands[i]})") + plt.savefig(f"ssim_{bands[i]}.png", dpi=300) + plt.show() diff --git a/satflow/configs/datamodule/optical_flow_datamodule.yaml b/satflow/configs/datamodule/optical_flow_datamodule.yaml index 98001b9e..827f02aa 100644 --- a/satflow/configs/datamodule/optical_flow_datamodule.yaml +++ b/satflow/configs/datamodule/optical_flow_datamodule.yaml @@ -11,7 +11,7 @@ num_workers: 1 pin_memory: False config: visualize: False - num_timesteps: 1 + num_timesteps: 5 skip_timesteps: 1 forecast_times: 48 output_shape: 400