From 7a00a4c565a086c535e4ec1d7a9a4873b6e15b60 Mon Sep 17 00:00:00 2001
From: cssprad1
Date: Mon, 18 Nov 2024 13:33:48 -0500
Subject: [PATCH] pep8 formatting

---
 pytorch_caney/configs/config.py               | 16 +++++-----
 pytorch_caney/datamodules/__init__.py         |  4 ++-
 .../datamodules/abi_3dcloud_datamodule.py     | 32 ++++++++++++++++---
 .../datamodules/modis_toa_mim_datamodule.py   |  4 +--
 pytorch_caney/datasets/abi_3dcloud_dataset.py | 26 ++++++++++++---
 pytorch_caney/datasets/sharded_dataset.py     |  8 ++---
 pytorch_caney/models/model_factory.py         | 10 +++---
 pytorch_caney/ptc_cli.py                      |  8 ++---
 8 files changed, 75 insertions(+), 33 deletions(-)

diff --git a/pytorch_caney/configs/config.py b/pytorch_caney/configs/config.py
index 6beecb9..f633293 100644
--- a/pytorch_caney/configs/config.py
+++ b/pytorch_caney/configs/config.py
@@ -49,7 +49,7 @@
 # Encoder type for fine-tuning
 _C.MODEL.ENCODER = ''
 # Decoder type for fine-tuning
-_C.MODEL.DECODER = '' 
+_C.MODEL.DECODER = ''
 # Model name
 _C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
 # Pretrained weight from checkpoint, could be from previous pre-training
@@ -104,8 +104,8 @@
 _C.TRAIN = CN()
 _C.TRAIN.ACCELERATOR = 'gpu'
 _C.TRAIN.STRATEGY = 'deepspeed'
-_C.TRAIN.LIMIT_TRAIN_BATCHES = True 
-_C.TRAIN.NUM_TRAIN_BATCHES = None 
+_C.TRAIN.LIMIT_TRAIN_BATCHES = True
+_C.TRAIN.NUM_TRAIN_BATCHES = None
 _C.TRAIN.START_EPOCH = 0
 _C.TRAIN.EPOCHS = 300
 _C.TRAIN.WARMUP_EPOCHS = 20
@@ -120,7 +120,7 @@
 _C.TRAIN.AUTO_RESUME = True
 # Gradient accumulation steps
 # could be overwritten by command line argument
-_C.TRAIN.ACCUMULATION_STEPS = 1 
+_C.TRAIN.ACCUMULATION_STEPS = 1
 # Whether to use gradient checkpointing to save memory
 # could be overwritten by command line argument
 _C.TRAIN.USE_CHECKPOINT = False
@@ -160,8 +160,8 @@
 _C.DEEPSPEED.STAGE = 2
 _C.DEEPSPEED.REDUCE_BUCKET_SIZE = 5e8
 _C.DEEPSPEED.ALLGATHER_BUCKET_SIZE = 5e8
-_C.DEEPSPEED.CONTIGUOUS_GRADIENTS = True 
-_C.DEEPSPEED.OVERLAP_COMM = True 
+_C.DEEPSPEED.CONTIGUOUS_GRADIENTS = True
+_C.DEEPSPEED.OVERLAP_COMM = True


 # -----------------------------------------------------------------------------
@@ -175,7 +175,7 @@
 # Misc
 # -----------------------------------------------------------------------------
 # Whether to enable pytorch amp, overwritten by command line argument
-_C.PRECISION = '32' 
+_C.PRECISION = '32'
 # Enable Pytorch automatic mixed precision (amp).
 _C.AMP_ENABLE = True
 # Path to output folder, overwritten by command line argument
@@ -196,7 +196,7 @@
 _C.PIPELINE = 'satvisiontoapretrain'
 # Data module
 _C.DATAMODULE = 'abitoa3dcloud'
-# Fast dev run 
+# Fast dev run
 _C.FAST_DEV_RUN = False


diff --git a/pytorch_caney/datamodules/__init__.py b/pytorch_caney/datamodules/__init__.py
index ad540c3..b5633d2 100644
--- a/pytorch_caney/datamodules/__init__.py
+++ b/pytorch_caney/datamodules/__init__.py
@@ -1,10 +1,12 @@
 from .abi_3dcloud_datamodule import AbiToa3DCloudDataModule
 from .modis_toa_mim_datamodule import ModisToaMimDataModule

+
 DATAMODULES = {
     'abitoa3dcloud': AbiToa3DCloudDataModule,
     'modistoamimpretrain': ModisToaMimDataModule,
 }

+
 def get_available_datamodules():
-    return {name: cls for name, cls in DATAMODULES.items()}
\ No newline at end of file
+    return {name: cls for name, cls in DATAMODULES.items()}
diff --git a/pytorch_caney/datamodules/abi_3dcloud_datamodule.py b/pytorch_caney/datamodules/abi_3dcloud_datamodule.py
index c44c342..2b23f03 100644
--- a/pytorch_caney/datamodules/abi_3dcloud_datamodule.py
+++ b/pytorch_caney/datamodules/abi_3dcloud_datamodule.py
@@ -5,9 +5,15 @@
 from pytorch_caney.transforms.abi_toa import AbiToaTransform


+# -----------------------------------------------------------------------------
+# AbiToa3DCloudDataModule
+# -----------------------------------------------------------------------------
 class AbiToa3DCloudDataModule(L.LightningDataModule):
     """NonGeo ABI TOA 3D cloud data module implementation"""

+    # -------------------------------------------------------------------------
+    # __init__
+    # -------------------------------------------------------------------------
     def __init__(
         self,
         config,
@@ -21,6 +27,9 @@ def __init__(
         self.batch_size = config.DATA.BATCH_SIZE
         self.num_workers = config.DATA.NUM_WORKERS

+    # -------------------------------------------------------------------------
+    # setup
+    # -------------------------------------------------------------------------
     def setup(self, stage: str) -> None:
         if stage in ["fit"]:
             self.train_dataset = AbiToa3DCloudDataset(
@@ -40,12 +49,27 @@ def setup(self, stage: str) -> None:
                 self.test_data_paths,
                 self.transform,
             )
-    
+
+    # -------------------------------------------------------------------------
+    # train_dataloader
+    # -------------------------------------------------------------------------
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.train_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)

+    # -------------------------------------------------------------------------
+    # val_dataloader
+    # -------------------------------------------------------------------------
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.val_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)

+    # -------------------------------------------------------------------------
+    # test_dataloader
+    # -------------------------------------------------------------------------
     def test_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.val_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)
diff --git a/pytorch_caney/datamodules/modis_toa_mim_datamodule.py b/pytorch_caney/datamodules/modis_toa_mim_datamodule.py
index b52b8e2..e77064e 100644
--- a/pytorch_caney/datamodules/modis_toa_mim_datamodule.py
+++ b/pytorch_caney/datamodules/modis_toa_mim_datamodule.py
@@ -26,7 +26,7 @@ def __init__(self, config,) -> None:
         self.pin_memory = config.DATA.PIN_MEMORY

     # -------------------------------------------------------------------------
-    # setup 
+    # setup
     # -------------------------------------------------------------------------
     def setup(self, stage: str) -> None:
         if stage in ["fit"]:
@@ -40,7 +40,7 @@ def setup(self, stage: str) -> None:
             batch_size=self.batch_size).dataset()

     # -------------------------------------------------------------------------
-    # train_dataloader 
+    # train_dataloader
     # -------------------------------------------------------------------------
     def train_dataloader(self) -> DataLoader:
         return DataLoader(self.train_dataset,
diff --git a/pytorch_caney/datasets/abi_3dcloud_dataset.py b/pytorch_caney/datasets/abi_3dcloud_dataset.py
index 3f37eae..85056fc 100644
--- a/pytorch_caney/datasets/abi_3dcloud_dataset.py
+++ b/pytorch_caney/datasets/abi_3dcloud_dataset.py
@@ -8,14 +8,20 @@
 from torchgeo.datasets import NonGeoDataset


+# -----------------------------------------------------------------------------
+# AbiToa3DCloudDataModule
+# -----------------------------------------------------------------------------
 class AbiToa3DCloudDataset(NonGeoDataset):
-    
+
+    # -------------------------------------------------------------------------
+    # __init__
+    # -------------------------------------------------------------------------
     def __init__(self, config, data_paths: list, transform=None) -> None:
         super().__init__()

         self.config = config
-        self.data_paths = data_paths 
+        self.data_paths = data_paths
         self.transform = transform
         self.img_size = config.DATA.IMG_SIZE
@@ -27,9 +33,15 @@ def __init__(self, config, data_paths: list, transform=None) -> None:

         self.rgb_indices = [0, 1, 2]

+    # -------------------------------------------------------------------------
+    # __len__
+    # -------------------------------------------------------------------------
     def __len__(self) -> int:
         return len(self.image_list)

+    # -------------------------------------------------------------------------
+    # __getitem__
+    # -------------------------------------------------------------------------
     def __getitem__(self, index: int) -> Dict[str, Any]:

         npz_array = self._load_file(self.image_list[index])
@@ -39,16 +51,22 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         if self.transform is not None:
             image = self.transform(image)

-        return image, mask 
+        return image, mask

+    # -------------------------------------------------------------------------
+    # _load_file
+    # -------------------------------------------------------------------------
     def _load_file(self, path: Path):
         if Path(path).suffix == '.npy' or Path(path).suffix == '.npz':
             return np.load(path, allow_pickle=True)
         elif Path(path).suffix == '.tif':
             return rxr.open_rasterio(path)
         else:
-            raise RuntimeError('Non-recognized dataset format. Expects npy or tif.')
+            raise RuntimeError('Non-recognized dataset format. Expects npy or tif.')  # noqa: E501

+    # -------------------------------------------------------------------------
+    # get_filenames
+    # -------------------------------------------------------------------------
     def get_filenames(self, path):
         """
         Returns a list of absolute paths to images inside given `path`
diff --git a/pytorch_caney/datasets/sharded_dataset.py b/pytorch_caney/datasets/sharded_dataset.py
index 185a3b8..8cec063 100644
--- a/pytorch_caney/datasets/sharded_dataset.py
+++ b/pytorch_caney/datasets/sharded_dataset.py
@@ -23,7 +23,7 @@ def nodesplitter(src, group=None):
             if i % size == rank:
                 yield item
                 count += 1
-        logging.info(f"nodesplitter: rank={rank} size={size} " + \
+        logging.info(f"nodesplitter: rank={rank} size={size} " +
                      f"count={count} DONE")
     else:
         yield from src
@@ -34,7 +34,7 @@ def nodesplitter(src, group=None):
 # -----------------------------------------------------------------------------
 class ShardedDataset(object):
     """
-    Base pre-training webdataset 
+    Base pre-training webdataset
     """

     SHARD_PATH = os.path.join("shards")
@@ -53,7 +53,7 @@ def __init__(
         batch_size=64,
     ):

-        self.random_state = 1000 
+        self.random_state = 1000
         self.config = config
         self.img_size = img_size
         self.transform = transform
@@ -87,4 +87,4 @@ def dataset(self):
             .with_length(self.length)
         )

-        return dataset
\ No newline at end of file
+        return dataset
diff --git a/pytorch_caney/models/model_factory.py b/pytorch_caney/models/model_factory.py
index f12b3fd..e888ae4 100644
--- a/pytorch_caney/models/model_factory.py
+++ b/pytorch_caney/models/model_factory.py
@@ -1,5 +1,3 @@
-import torch.nn as nn
-
 # -----------------------------------------------------------------------------
 # ModelFactory
 # -----------------------------------------------------------------------------
@@ -38,7 +36,7 @@ def register_head(cls, name: str, head_cls):
     # -------------------------------------------------------------------------
     @classmethod
     def get_component(cls, component_type: str, name: str, **kwargs):
-        """Public method to retrieve and instantiate a component by type and name."""
+        """Public method to retrieve and instantiate a component by type and name."""  # noqa: E501
         print(cls.backbones)
         print(cls.decoders)
         print(cls.heads)
@@ -49,8 +47,8 @@ def get_component(cls, component_type: str, name: str, **kwargs):
         }.get(component_type)

         if registry is None or name not in registry:
-            raise ValueError(f"{component_type.capitalize()} '{name}' not found in registry.")
-            
+            raise ValueError(f"{component_type.capitalize()} '{name}' not found in registry.")  # noqa: E501
+
         return registry[name](**kwargs)

     # -------------------------------------------------------------------------
@@ -84,4 +82,4 @@ def head(cls, name):
         def decorator(head_cls):
             cls.register_head(name, head_cls)
             return head_cls
-        return decorator
\ No newline at end of file
+        return decorator
diff --git a/pytorch_caney/ptc_cli.py b/pytorch_caney/ptc_cli.py
index 424623b..d41ed96 100644
--- a/pytorch_caney/ptc_cli.py
+++ b/pytorch_caney/ptc_cli.py
@@ -1,7 +1,7 @@
 import argparse
 import os

-from lightning.pytorch import Trainer 
+from lightning.pytorch import Trainer

 from pytorch_caney.configs.config import _C, _update_config_from_file
 from pytorch_caney.utils import get_strategy, get_distributed_train_batches
@@ -42,7 +42,8 @@ def main(config, output_dir):
     )

     if config.TRAIN.LIMIT_TRAIN_BATCHES:
-        trainer.limit_train_batches = get_distributed_train_batches(config, trainer)
+        trainer.limit_train_batches = get_distributed_train_batches(
+            config, trainer)

     if config.DATA.DATAMODULE:
         available_datamodules = get_available_datamodules()
@@ -53,13 +54,12 @@ def main(config, output_dir):
         trainer.fit(model=ptlPipeline, datamodule=datamodule)

     else:
-        print(f'Training without datamodule, assuming data is set in pipeline: {ptlPipeline}')
+        print(f'Training without datamodule, assuming data is set in pipeline: {ptlPipeline}')  # noqa: E501
         trainer.fit(model=ptlPipeline)


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()

     parser.add_argument(