Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
ready to start testing exp 29. #139
Browse files Browse the repository at this point in the history
  • Loading branch information
JackKelly committed Jun 17, 2022
1 parent 7a0c8e9 commit 585298a
Show file tree
Hide file tree
Showing 20 changed files with 231 additions and 415 deletions.
2 changes: 1 addition & 1 deletion experiments/020_transformer_over_time_without_for_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
BatchKey.pv_time_utc_fourier,
BatchKey.hrvsatellite_solar_azimuth,
BatchKey.hrvsatellite_solar_elevation,
BatchKey.hrvsatellite,
BatchKey.hrvsatellite_actual,
BatchKey.hrvsatellite_time_utc_fourier,
):
x[batch_key] = einops.rearrange(x[batch_key], "example time ... -> (example time) ...")
Expand Down
4 changes: 2 additions & 2 deletions experiments/021_predict_future_imagery_climatehack.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __post_init__(self):
self.save_hyperparameters()

def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
historical_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
historical_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = self.satellite_predictor(historical_sat)
return dict(predicted_sat=predicted_sat)

Expand All @@ -87,7 +87,7 @@ def validation_step(
tag = "train" if self.training else "validation"
network_out = self(batch)
predicted_sat = network_out["predicted_sat"]
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]
sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{tag}/sat_mse", sat_mse_loss)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def __post_init__(self):
self.save_hyperparameters()

def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
if self.use_coord_conv:
osgb_coords = get_osgb_coords_for_coord_conv(x)
Expand Down Expand Up @@ -208,7 +210,7 @@ def validation_step(
tag = "train" if self.training else "validation"
network_out = self(batch)
predicted_sat = network_out["predicted_sat"]
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]
sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{tag}/sat_mse", sat_mse_loss)

Expand Down
10 changes: 6 additions & 4 deletions experiments/023_unet_and_power_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def __post_init__(self):
self.save_hyperparameters()

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == IMAGE_SIZE_PIXELS, f"{height=}"
assert width == IMAGE_SIZE_PIXELS, f"{width=}"
Expand Down Expand Up @@ -245,7 +247,7 @@ def validation_step(
tag = "train" if self.training else "validation"
network_out = self(batch)
predicted_sat = network_out
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]
sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{tag}/sat_mse", sat_mse_loss)

Expand Down Expand Up @@ -373,7 +375,7 @@ def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
BatchKey.pv_time_utc_fourier,
BatchKey.hrvsatellite_solar_azimuth,
BatchKey.hrvsatellite_solar_elevation,
BatchKey.hrvsatellite,
BatchKey.hrvsatellite_actual,
BatchKey.hrvsatellite_time_utc_fourier,
):
original_x[batch_key] = x[batch_key]
Expand Down Expand Up @@ -494,7 +496,7 @@ def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:

# Replace the "actual" future satellite images with predicted images
# shape: (batch_size, time, channels, y, x, n_pixels_per_patch)
x[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0] = predicted_sat
x[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0] = predicted_sat

sat_trans_out = self.satellite_transformer(x)
pv_attn_out = sat_trans_out["pv_attn_out"] # Shape: (example time n_pv_systems d_model)
Expand Down
10 changes: 6 additions & 4 deletions experiments/024_raw_dataset_unet_and_power_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def __post_init__(self):
self.save_hyperparameters()

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == SATELLITE_PREDICTOR_IMAGE_HEIGHT_PIXELS, f"{height=}"
assert width == SATELLITE_PREDICTOR_IMAGE_WIDTH_PIXELS, f"{width=}"
Expand Down Expand Up @@ -505,15 +507,15 @@ def __post_init__(self):
def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
# Predict future satellite images. The SatellitePredictor always gets every timestep.
if self.cheat:
predicted_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
else:
predicted_sat = self.satellite_predictor(x=x) # Shape: example, time, y, x

if self.stop_gradients_before_unet:
predicted_sat = predicted_sat.detach()

hrvsatellite = torch.concat(
(x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
(x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
)
assert hrvsatellite.isfinite().all()

Expand Down Expand Up @@ -759,7 +761,7 @@ def validation_step(

# SATELLITE PREDICTOR LOSS ################
predicted_sat = network_out["predicted_sat"]
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]

sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{self.tag}/sat_mse", sat_mse_loss)
Expand Down
10 changes: 6 additions & 4 deletions experiments/025_RNN_raw_dataset_unet_and_power_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def __post_init__(self):
)

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == SATELLITE_PREDICTOR_IMAGE_HEIGHT_PIXELS, f"{height=}"
assert width == SATELLITE_PREDICTOR_IMAGE_WIDTH_PIXELS, f"{width=}"
Expand Down Expand Up @@ -567,15 +569,15 @@ def __post_init__(self):
def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
# Predict future satellite images. The SatellitePredictor always gets every timestep.
if self.cheat:
predicted_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
else:
predicted_sat = self.satellite_predictor(x=x) # Shape: example, time, y, x

if self.stop_gradients_before_unet:
predicted_sat = predicted_sat.detach()

hrvsatellite = torch.concat(
(x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
(x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
)
assert hrvsatellite.isfinite().all()

Expand Down Expand Up @@ -890,7 +892,7 @@ def validation_step(

# SATELLITE PREDICTOR LOSS ################
predicted_sat = network_out["predicted_sat"]
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]

sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{self.tag}/sat_mse", sat_mse_loss)
Expand Down
10 changes: 6 additions & 4 deletions experiments/025_extract_sat_predictor_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def __post_init__(self):
self.save_hyperparameters()

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == SATELLITE_PREDICTOR_IMAGE_HEIGHT_PIXELS, f"{height=}"
assert width == SATELLITE_PREDICTOR_IMAGE_WIDTH_PIXELS, f"{width=}"
Expand Down Expand Up @@ -382,15 +384,15 @@ def __post_init__(self):
def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
# Predict future satellite images. The SatellitePredictor always gets every timestep.
if self.cheat:
predicted_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
else:
predicted_sat = self.satellite_predictor(x=x) # Shape: example, time, y, x

if self.stop_gradients_before_unet:
predicted_sat = predicted_sat.detach()

hrvsatellite = torch.concat(
(x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
(x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
)
assert hrvsatellite.isfinite().all()

Expand Down Expand Up @@ -698,7 +700,7 @@ def validation_step(

# SATELLITE PREDICTOR LOSS ################
predicted_sat = network_out["predicted_sat"]
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]

sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{self.tag}/sat_mse", sat_mse_loss)
Expand Down
12 changes: 7 additions & 5 deletions experiments/026_dont_train_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ def __post_init__(self):
)

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == SATELLITE_PREDICTOR_IMAGE_HEIGHT_PIXELS, f"{height=}"
assert width == SATELLITE_PREDICTOR_IMAGE_WIDTH_PIXELS, f"{width=}"
Expand Down Expand Up @@ -360,7 +362,7 @@ def validation_step(
self, batch: dict[BatchKey, torch.Tensor], batch_idx: int
) -> dict[str, object]:
predicted_sat = self(batch)
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]

sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{self.tag}/sat_mse", sat_mse_loss)
Expand Down Expand Up @@ -648,13 +650,13 @@ def __post_init__(self):
def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
# Predict future satellite images. The SatellitePredictor always gets every timestep.
if self.cheat:
predicted_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
else:
predicted_sat = self.satellite_predictor(x=x) # Shape: example, time, y, x

hrvsatellite = torch.concat(
(
x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0],
x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0],
# Detach because it looks like it hurts performance to let the gradients go
# backwards from here.
predicted_sat.detach(),
Expand Down Expand Up @@ -949,7 +951,7 @@ def validation_step(
"predicted_pv_power": predicted_pv_power,
"predicted_pv_power_mean": get_distribution(predicted_pv_power).mean,
"predicted_sat": network_out["predicted_sat"], # Shape: example, time, y, x
"actual_sat": batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0],
"actual_sat": batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0],
"pv_power_from_sat_transformer": network_out["pv_power_from_sat_transformer"],
"gsp_power_from_sat_transformer": network_out["gsp_power_from_sat_transformer"],
"random_timestep_indexes": network_out["random_timestep_indexes"],
Expand Down
12 changes: 7 additions & 5 deletions experiments/027_longer_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ def __post_init__(self):
)

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == SATELLITE_PREDICTOR_IMAGE_HEIGHT_PIXELS, f"{height=}"
assert width == SATELLITE_PREDICTOR_IMAGE_WIDTH_PIXELS, f"{width=}"
Expand Down Expand Up @@ -344,7 +346,7 @@ def validation_step(
self, batch: dict[BatchKey, torch.Tensor], batch_idx: int
) -> dict[str, object]:
predicted_sat = self(batch)
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]

sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{self.tag}/sat_mse", sat_mse_loss)
Expand Down Expand Up @@ -632,12 +634,12 @@ def __post_init__(self):
def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
# Predict future satellite images. The SatellitePredictor always gets every timestep.
if self.cheat:
predicted_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
else:
predicted_sat = self.satellite_predictor(x=x) # Shape: example, time, y, x

hrvsatellite = torch.concat(
(x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
(x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
)
assert hrvsatellite.isfinite().all()

Expand Down Expand Up @@ -931,7 +933,7 @@ def validation_step(
"predicted_pv_power": predicted_pv_power,
"predicted_pv_power_mean": get_distribution(predicted_pv_power).mean,
"predicted_sat": network_out["predicted_sat"], # Shape: example, time, y, x
"actual_sat": batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0],
"actual_sat": batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0],
"pv_power_from_sat_transformer": network_out["pv_power_from_sat_transformer"],
"gsp_power_from_sat_transformer": network_out["gsp_power_from_sat_transformer"],
"random_timestep_indexes": network_out["random_timestep_indexes"],
Expand Down
12 changes: 7 additions & 5 deletions experiments/028_remove_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ def __post_init__(self):
)

def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
data = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0] # Shape: (example, time, y, x)
data = x[BatchKey.hrvsatellite_actual][
:, :NUM_HIST_SAT_IMAGES, 0
] # Shape: (example, time, y, x)
height, width = data.shape[2:]
assert height == SATELLITE_PREDICTOR_IMAGE_HEIGHT_PIXELS, f"{height=}"
assert width == SATELLITE_PREDICTOR_IMAGE_WIDTH_PIXELS, f"{width=}"
Expand Down Expand Up @@ -344,7 +346,7 @@ def validation_step(
self, batch: dict[BatchKey, torch.Tensor], batch_idx: int
) -> dict[str, object]:
predicted_sat = self(batch)
actual_sat = batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0]
actual_sat = batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0]

sat_mse_loss = F.mse_loss(predicted_sat, actual_sat)
self.log(f"{self.tag}/sat_mse", sat_mse_loss)
Expand Down Expand Up @@ -609,12 +611,12 @@ def __post_init__(self):
def forward(self, x: dict[BatchKey, torch.Tensor]) -> dict[str, torch.Tensor]:
# Predict future satellite images. The SatellitePredictor always gets every timestep.
if self.cheat:
predicted_sat = x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0]
predicted_sat = x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0]
else:
predicted_sat = self.satellite_predictor(x=x) # Shape: example, time, y, x

hrvsatellite = torch.concat(
(x[BatchKey.hrvsatellite][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
(x[BatchKey.hrvsatellite_actual][:, :NUM_HIST_SAT_IMAGES, 0], predicted_sat), dim=1
)
assert hrvsatellite.isfinite().all()

Expand Down Expand Up @@ -878,7 +880,7 @@ def validation_step(
"predicted_pv_power": predicted_pv_power,
"predicted_pv_power_mean": get_distribution(predicted_pv_power).mean,
"predicted_sat": network_out["predicted_sat"], # Shape: example, time, y, x
"actual_sat": batch[BatchKey.hrvsatellite][:, NUM_HIST_SAT_IMAGES:, 0],
"actual_sat": batch[BatchKey.hrvsatellite_actual][:, NUM_HIST_SAT_IMAGES:, 0],
"pv_power_from_sat_transformer": network_out["pv_power_from_sat_transformer"],
"gsp_power_from_sat_transformer": network_out["gsp_power_from_sat_transformer"],
"random_timestep_indexes": network_out["random_timestep_indexes"],
Expand Down
Loading

0 comments on commit 585298a

Please sign in to comment.