From 31e3ffae605aff1b4db65715056d951a64de3855 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Mar 2023 19:55:36 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 2 +- satflow/data/datamodules.py | 1 - satflow/data/utils/utils.py | 3 ++- satflow/models/conv_lstm.py | 2 -- satflow/models/gan/generators.py | 1 - satflow/models/layers/Attention.py | 5 ----- satflow/models/layers/Discriminator.py | 1 - satflow/models/layers/GResBlock.py | 2 -- satflow/models/layers/Generator.py | 3 --- satflow/models/layers/Normalization.py | 1 - satflow/models/layers/RUnetLayers.py | 1 - satflow/run.py | 1 - 12 files changed, 3 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 8bd50657..74d57a81 100644 --- a/README.md +++ b/README.md @@ -46,4 +46,4 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d -This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! \ No newline at end of file +This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! diff --git a/satflow/data/datamodules.py b/satflow/data/datamodules.py index 7b3ba46f..26417337 100644 --- a/satflow/data/datamodules.py +++ b/satflow/data/datamodules.py @@ -186,7 +186,6 @@ def per_worker_init(self, worker_id: int): pass def __getitem__(self, idx): - x = { SATELLITE_DATA: torch.randn( self.batch_size, self.seq_length, self.width, self.height, self.number_sat_channels diff --git a/satflow/data/utils/utils.py b/satflow/data/utils/utils.py index bf896524..d320a0d1 100644 --- a/satflow/data/utils/utils.py +++ b/satflow/data/utils/utils.py @@ -33,7 +33,8 @@ def eumetsat_name_to_datetime(filename: str): def retrieve_pixel_value(geo_coord, data_source): """Return floating-point value that corresponds to given point. - Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal""" + Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal + """ x, y = geo_coord[0], geo_coord[1] forward_transform = affine.Affine.from_gdal(*data_source.GetGeoTransform()) reverse_transform = ~forward_transform diff --git a/satflow/models/conv_lstm.py b/satflow/models/conv_lstm.py index 7373fb77..15d46a8b 100644 --- a/satflow/models/conv_lstm.py +++ b/satflow/models/conv_lstm.py @@ -169,7 +169,6 @@ def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "s ) def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4): - outputs = [] # encoder @@ -203,7 +202,6 @@ def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, return outputs def forward(self, x, forecast_steps=0, hidden_state=None): - """ Parameters ---------- diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py index 2e2ded7e..73953ddd 100644 --- a/satflow/models/gan/generators.py +++ b/satflow/models/gan/generators.py @@ -146,7 +146,6 @@ def __init__( mult = 2**n_downsampling for i in range(n_blocks): # add ResNet blocks - model += [ ResnetBlock( ngf * mult, diff --git a/satflow/models/layers/Attention.py b/satflow/models/layers/Attention.py index ef543152..677a8c0b 100644 --- a/satflow/models/layers/Attention.py +++ b/satflow/models/layers/Attention.py @@ -16,7 +16,6 @@ def __init__( ) def forward(self, x): - return self.model(x) @@ -56,7 +55,6 @@ def init_conv(self, conv, glu=True): conv.bias.data.zero_() def forward(self, x): - batch_size, C, T, W, H = x.size() assert T % 2 == 0 and W % 2 == 0 and H % 2 == 0, "T, W, H is not even" @@ -111,7 +109,6 @@ def forward(self, x): class SelfAttention(nn.Module): def __init__(self, in_dim, activation=F.relu, pooling_factor=2): # TODO for better compability - super(SelfAttention, self).__init__() self.activation = activation @@ -134,7 +131,6 @@ def init_conv(self, conv, glu=True): conv.bias.data.zero_() def forward(self, x): - if len(x.size()) == 4: batch_size, C, W, H = x.size() T = 1 @@ -224,7 +220,6 @@ def forward(self, x): if __name__ == "__main__": - self_attn = SelfAttention(16) # no less than 8 print(self_attn) diff --git a/satflow/models/layers/Discriminator.py b/satflow/models/layers/Discriminator.py index 097fa5d3..5ab23df9 100644 --- a/satflow/models/layers/Discriminator.py +++ b/satflow/models/layers/Discriminator.py @@ -651,7 +651,6 @@ def forward(self, x, class_id): if __name__ == "__main__": - batch_size = 6 n_frames = 8 n_class = 4 diff --git a/satflow/models/layers/GResBlock.py b/satflow/models/layers/GResBlock.py index 492ccea7..9432ca41 100644 --- a/satflow/models/layers/GResBlock.py +++ b/satflow/models/layers/GResBlock.py @@ -52,7 +52,6 @@ def __init__( self.CBNorm2 = ConditionalNorm(out_channel, n_class) def forward(self, x, condition=None): - # The time dimension is combined with the batch dimension here, so each frame proceeds # through the blocks independently BT, C, W, H = x.size() @@ -100,7 +99,6 @@ def forward(self, x, condition=None): if __name__ == "__main__": - n_class = 96 batch_size = 4 n_frames = 20 diff --git a/satflow/models/layers/Generator.py b/satflow/models/layers/Generator.py index 2a860bb9..34d17383 100644 --- a/satflow/models/layers/Generator.py +++ b/satflow/models/layers/Generator.py @@ -68,7 +68,6 @@ def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hier self.colorize = SpectralNorm(nn.Conv2d(2 * ch, 3, kernel_size=(3, 3), padding=1)) def forward(self, x, class_id): - if self.hierar_flag is True: noise_emb = torch.split(x, self.in_dim, dim=1) else: @@ -87,7 +86,6 @@ def forward(self, x, class_id): for k, conv in enumerate(self.conv): if isinstance(conv, ConvGRU): - if k > 0: _, C, W, H = y.size() y = y.view(-1, self.n_frames, C, W, H).contiguous() @@ -132,7 +130,6 @@ def forward(self, x, class_id): if __name__ == "__main__": - batch_size = 5 in_dim = 120 n_class = 4 diff --git a/satflow/models/layers/Normalization.py b/satflow/models/layers/Normalization.py index ddee3564..eab7ae18 100644 --- a/satflow/models/layers/Normalization.py +++ b/satflow/models/layers/Normalization.py @@ -87,7 +87,6 @@ def forward(self, x, class_id): if __name__ == "__main__": - cn = ConditionalNorm(3, 2) x = torch.rand([4, 3, 64, 64]) class_id = torch.rand([4, 2]) diff --git a/satflow/models/layers/RUnetLayers.py b/satflow/models/layers/RUnetLayers.py index f6c910af..9fcf1c6b 100644 --- a/satflow/models/layers/RUnetLayers.py +++ b/satflow/models/layers/RUnetLayers.py @@ -80,7 +80,6 @@ def __init__(self, ch_out, t=2, conv_type: str = "standard"): def forward(self, x): for i in range(self.t): - if i == 0: x1 = self.conv(x) diff --git a/satflow/run.py b/satflow/run.py index 828112d9..060af94c 100644 --- a/satflow/run.py +++ b/satflow/run.py @@ -12,7 +12,6 @@ @hydra.main(config_path="configs/", config_name="config.yaml") def main(config: DictConfig): - # Imports should be nested inside @hydra.main to optimize tab completion # Read more here: https://github.com/facebookresearch/hydra/issues/934 from satflow.core import utils