diff --git a/graph_weather/models/gencast/train.py b/graph_weather/models/gencast/train.py index 0ff730c..f36dcae 100644 --- a/graph_weather/models/gencast/train.py +++ b/graph_weather/models/gencast/train.py @@ -159,7 +159,7 @@ def plot_sample(self, prev_inputs, target_residuals): preds = sampler.sample(self.model, prev_inputs) fig1, ax = plt.subplots(2) - im = ax[0].imshow(preds[0, :, :, 78].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) + ax[0].imshow(preds[0, :, :, 78].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) ax[0].set_xticks([]) ax[0].set_yticks([]) ax[0].set_title("Diffusion sampling prediction") @@ -170,7 +170,7 @@ def plot_sample(self, prev_inputs, target_residuals): ax[1].set_title("Ground truth") fig2, ax = plt.subplots(2) - im = ax[0].imshow(preds[0, :, :, 12].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) + ax[0].imshow(preds[0, :, :, 12].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) ax[0].set_xticks([]) ax[0].set_yticks([]) ax[0].set_title("Diffusion sampling prediction")