pep8 formatting #67

Merged · 1 commit · Nov 18, 2024
16 changes: 8 additions & 8 deletions pytorch_caney/configs/config.py
@@ -49,7 +49,7 @@
 # Encoder type for fine-tuning
 _C.MODEL.ENCODER = ''
 # Decoder type for fine-tuning
-_C.MODEL.DECODER = ''
+_C.MODEL.DECODER = ''
 # Model name
 _C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
 # Pretrained weight from checkpoint, could be from previous pre-training
@@ -104,8 +104,8 @@
 _C.TRAIN = CN()
 _C.TRAIN.ACCELERATOR = 'gpu'
 _C.TRAIN.STRATEGY = 'deepspeed'
-_C.TRAIN.LIMIT_TRAIN_BATCHES = True
-_C.TRAIN.NUM_TRAIN_BATCHES = None
+_C.TRAIN.LIMIT_TRAIN_BATCHES = True
+_C.TRAIN.NUM_TRAIN_BATCHES = None
 _C.TRAIN.START_EPOCH = 0
 _C.TRAIN.EPOCHS = 300
 _C.TRAIN.WARMUP_EPOCHS = 20
@@ -120,7 +120,7 @@
 _C.TRAIN.AUTO_RESUME = True
 # Gradient accumulation steps
 # could be overwritten by command line argument
-_C.TRAIN.ACCUMULATION_STEPS = 1
+_C.TRAIN.ACCUMULATION_STEPS = 1
 # Whether to use gradient checkpointing to save memory
 # could be overwritten by command line argument
 _C.TRAIN.USE_CHECKPOINT = False
@@ -160,8 +160,8 @@
 _C.DEEPSPEED.STAGE = 2
 _C.DEEPSPEED.REDUCE_BUCKET_SIZE = 5e8
 _C.DEEPSPEED.ALLGATHER_BUCKET_SIZE = 5e8
-_C.DEEPSPEED.CONTIGUOUS_GRADIENTS = True
-_C.DEEPSPEED.OVERLAP_COMM = True
+_C.DEEPSPEED.CONTIGUOUS_GRADIENTS = True
+_C.DEEPSPEED.OVERLAP_COMM = True
 
 
 # -----------------------------------------------------------------------------
@@ -175,7 +175,7 @@
 # Misc
 # -----------------------------------------------------------------------------
 # Whether to enable pytorch amp, overwritten by command line argument
-_C.PRECISION = '32'
+_C.PRECISION = '32'
 # Enable Pytorch automatic mixed precision (amp).
 _C.AMP_ENABLE = True
 # Path to output folder, overwritten by command line argument
@@ -196,7 +196,7 @@
 _C.PIPELINE = 'satvisiontoapretrain'
 # Data module
 _C.DATAMODULE = 'abitoa3dcloud'
-# Fast dev run
+# Fast dev run
 _C.FAST_DEV_RUN = False
 
 
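For reviewers unfamiliar with this module: the defaults above are yacs CfgNode entries, so downstream code overrides them from YAML. A minimal sketch using only the standard yacs API (clone, merge_from_file, freeze); the YAML path is hypothetical:

from pytorch_caney.configs.config import _C

# Clone so the module-level defaults stay untouched.
config = _C.clone()

# merge_from_file is standard yacs CfgNode API; the path is hypothetical.
config.merge_from_file('configs/3dcloud_example.yaml')
config.freeze()

print(config.MODEL.NAME)   # 'swinv2_base_patch4_window7_224' by default
print(config.PRECISION)    # '32' unless overridden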
4 changes: 3 additions & 1 deletion pytorch_caney/datamodules/__init__.py
@@ -1,10 +1,12 @@
 from .abi_3dcloud_datamodule import AbiToa3DCloudDataModule
 from .modis_toa_mim_datamodule import ModisToaMimDataModule
 
+
 DATAMODULES = {
     'abitoa3dcloud': AbiToa3DCloudDataModule,
     'modistoamimpretrain': ModisToaMimDataModule,
 }
 
+
 def get_available_datamodules():
-    return {name: cls for name, cls in DATAMODULES.items()}
+    return {name: cls for name, cls in DATAMODULES.items()}
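A quick sketch of how this registry is consumed (mirroring the lookup ptc_cli.py performs further down; the config object is assumed to come from pytorch_caney.configs.config):

from pytorch_caney.datamodules import get_available_datamodules

available_datamodules = get_available_datamodules()

# 'abitoa3dcloud' is the default _C.DATAMODULE value in configs/config.py.
datamodule_cls = available_datamodules['abitoa3dcloud']

# Each datamodule takes the yacs config, per the __init__ signatures below.
# datamodule = datamodule_cls(config)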
32 changes: 28 additions & 4 deletions pytorch_caney/datamodules/abi_3dcloud_datamodule.py
@@ -5,9 +5,15 @@
 from pytorch_caney.transforms.abi_toa import AbiToaTransform
 
 
+# -----------------------------------------------------------------------------
+# AbiToa3DCloudDataModule
+# -----------------------------------------------------------------------------
 class AbiToa3DCloudDataModule(L.LightningDataModule):
     """NonGeo ABI TOA 3D cloud data module implementation"""
 
+    # -------------------------------------------------------------------------
+    # __init__
+    # -------------------------------------------------------------------------
     def __init__(
         self,
         config,
@@ -21,6 +21,9 @@ def __init__(
         self.batch_size = config.DATA.BATCH_SIZE
         self.num_workers = config.DATA.NUM_WORKERS
 
+    # -------------------------------------------------------------------------
+    # setup
+    # -------------------------------------------------------------------------
     def setup(self, stage: str) -> None:
         if stage in ["fit"]:
             self.train_dataset = AbiToa3DCloudDataset(
@@ -40,12 +49,27 @@ def setup(self, stage: str) -> None:
                 self.test_data_paths,
                 self.transform,
             )
 
+    # -------------------------------------------------------------------------
+    # train_dataloader
+    # -------------------------------------------------------------------------
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.train_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)
 
+    # -------------------------------------------------------------------------
+    # val_dataloader
+    # -------------------------------------------------------------------------
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.val_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)
 
+    # -------------------------------------------------------------------------
+    # test_dataloader
+    # -------------------------------------------------------------------------
     def test_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.val_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)
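One thing a reviewer may want to flag: test_dataloader wraps self.val_dataset in both the old and new versions, so the test split is never served. A minimal usage sketch, with the caveat that the full __init__ signature (the data-path arguments) is truncated in the hunk above, so the constructor call is an assumption:

import lightning as L

from pytorch_caney.configs.config import _C
from pytorch_caney.datamodules.abi_3dcloud_datamodule import AbiToa3DCloudDataModule

config = _C.clone()

# Constructor call is an assumption; __init__ takes config plus the
# truncated data-path arguments shown above.
datamodule = AbiToa3DCloudDataModule(config)

trainer = L.Trainer(fast_dev_run=True)  # mirrors the FAST_DEV_RUN config key
# trainer.fit(model, datamodule=datamodule)  # model comes from the pipeline registry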
4 changes: 2 additions & 2 deletions pytorch_caney/datamodules/modis_toa_mim_datamodule.py
@@ -26,7 +26,7 @@ def __init__(self, config,) -> None:
         self.pin_memory = config.DATA.PIN_MEMORY
 
     # -------------------------------------------------------------------------
-    # setup
+    # setup
     # -------------------------------------------------------------------------
     def setup(self, stage: str) -> None:
         if stage in ["fit"]:
@@ -40,7 +40,7 @@ def setup(self, stage: str) -> None:
                 batch_size=self.batch_size).dataset()
 
     # -------------------------------------------------------------------------
-    # train_dataloader
+    # train_dataloader
     # -------------------------------------------------------------------------
     def train_dataloader(self) -> DataLoader:
         return DataLoader(self.train_dataset,
26 changes: 22 additions & 4 deletions pytorch_caney/datasets/abi_3dcloud_dataset.py
@@ -8,14 +8,20 @@
 from torchgeo.datasets import NonGeoDataset
 
 
+# -----------------------------------------------------------------------------
+# AbiToa3DCloudDataModule
+# -----------------------------------------------------------------------------
 class AbiToa3DCloudDataset(NonGeoDataset):
 
+    # -------------------------------------------------------------------------
+    # __init__
+    # -------------------------------------------------------------------------
     def __init__(self, config, data_paths: list, transform=None) -> None:
 
         super().__init__()
 
         self.config = config
-        self.data_paths = data_paths
+        self.data_paths = data_paths
         self.transform = transform
         self.img_size = config.DATA.IMG_SIZE
 
@@ -27,9 +33,15 @@ def __init__(self, config, data_paths: list, transform=None) -> None:
 
         self.rgb_indices = [0, 1, 2]
 
+    # -------------------------------------------------------------------------
+    # __len__
+    # -------------------------------------------------------------------------
     def __len__(self) -> int:
         return len(self.image_list)
 
+    # -------------------------------------------------------------------------
+    # __getitem__
+    # -------------------------------------------------------------------------
     def __getitem__(self, index: int) -> Dict[str, Any]:
 
         npz_array = self._load_file(self.image_list[index])
@@ -39,16 +51,22 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         if self.transform is not None:
             image = self.transform(image)
 
-        return image, mask
+        return image, mask
 
+    # -------------------------------------------------------------------------
+    # _load_file
+    # -------------------------------------------------------------------------
     def _load_file(self, path: Path):
         if Path(path).suffix == '.npy' or Path(path).suffix == '.npz':
             return np.load(path, allow_pickle=True)
         elif Path(path).suffix == '.tif':
             return rxr.open_rasterio(path)
         else:
-            raise RuntimeError('Non-recognized dataset format. Expects npy or tif.')
+            raise RuntimeError('Non-recognized dataset format. Expects npy or tif.')  # noqa: E501
 
+    # -------------------------------------------------------------------------
+    # get_filenames
+    # -------------------------------------------------------------------------
     def get_filenames(self, path):
         """
         Returns a list of absolute paths to images inside given `path`
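Minor review note: the class banner added above reads AbiToa3DCloudDataModule although the class is a Dataset, which looks like a copy-paste slip. For completeness, a short sketch of exercising the dataset directly; the directory path is hypothetical, and __getitem__ unpacks to an (image, mask) pair per the return above:

from pytorch_caney.configs.config import _C
from pytorch_caney.datasets.abi_3dcloud_dataset import AbiToa3DCloudDataset

config = _C.clone()

# Hypothetical path; get_filenames() expands each entry into image files,
# and _load_file() accepts .npy/.npz or .tif inputs.
dataset = AbiToa3DCloudDataset(config, ['/data/abi_3dcloud/train'], transform=None)

print(len(dataset))        # number of discovered files
image, mask = dataset[0]   # (image, mask) pair from __getitem__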
8 changes: 4 additions & 4 deletions pytorch_caney/datasets/sharded_dataset.py
@@ -23,7 +23,7 @@ def nodesplitter(src, group=None):
             if i % size == rank:
                 yield item
                 count += 1
-        logging.info(f"nodesplitter: rank={rank} size={size} " + \
+        logging.info(f"nodesplitter: rank={rank} size={size} " +
                      f"count={count} DONE")
     else:
         yield from src
@@ -34,7 +34,7 @@
 # -----------------------------------------------------------------------------
 class ShardedDataset(object):
     """
-    Base pre-training webdataset
+    Base pre-training webdataset
     """
 
     SHARD_PATH = os.path.join("shards")
@@ -53,7 +53,7 @@ def __init__(
         batch_size=64,
     ):
 
-        self.random_state = 1000
+        self.random_state = 1000
         self.config = config
         self.img_size = img_size
         self.transform = transform
@@ -87,4 +87,4 @@ def dataset(self):
             .with_length(self.length)
         )
 
-        return dataset
+        return dataset
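Context for the line-continuation fix above: nodesplitter is the hook webdataset calls to divide shards among distributed workers, so each rank streams a disjoint subset. A minimal wiring sketch; the shard URL pattern is hypothetical, and ShardedDataset.dataset() builds its own richer pipeline internally:

import webdataset as wds

from pytorch_caney.datasets.sharded_dataset import nodesplitter

# Hypothetical shard naming; the class derives real paths from SHARD_PATH.
urls = 'shards/modis-toa-{000000..000099}.tar'

# Each rank keeps shard i when i % world_size == rank, matching the
# modulo logic in nodesplitter above.
dataset = wds.WebDataset(urls, nodesplitter=nodesplitter)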
10 changes: 4 additions & 6 deletions pytorch_caney/models/model_factory.py
@@ -1,5 +1,3 @@
-import torch.nn as nn
-
 # -----------------------------------------------------------------------------
 # ModelFactory
 # -----------------------------------------------------------------------------
@@ -38,7 +36,7 @@ def register_head(cls, name: str, head_cls):
     # -------------------------------------------------------------------------
     @classmethod
     def get_component(cls, component_type: str, name: str, **kwargs):
-        """Public method to retrieve and instantiate a component by type and name."""
+        """Public method to retrieve and instantiate a component by type and name."""  # noqa: E501
         print(cls.backbones)
         print(cls.decoders)
         print(cls.heads)
@@ -49,8 +47,8 @@ def get_component(cls, component_type: str, name: str, **kwargs):
         }.get(component_type)
 
         if registry is None or name not in registry:
-            raise ValueError(f"{component_type.capitalize()} '{name}' not found in registry.")
+            raise ValueError(f"{component_type.capitalize()} '{name}' not found in registry.")  # noqa: E501
 
         return registry[name](**kwargs)
 
     # -------------------------------------------------------------------------
@@ -84,4 +82,4 @@ def head(cls, name):
         def decorator(head_cls):
             cls.register_head(name, head_cls)
             return head_cls
-        return decorator
+        return decorator
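The factory pairs decorator registration with get_component lookup. A toy registration sketch, with hedges: the name 'toy_head' is illustrative, and the registry key 'head' is an assumption since the key mapping is truncated in the hunk above:

import torch.nn as nn

from pytorch_caney.models.model_factory import ModelFactory


# Register a toy segmentation head under a hypothetical name.
@ModelFactory.head('toy_head')
class ToyHead(nn.Module):
    def __init__(self, in_channels: int = 128, num_classes: int = 2):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.proj(x)


# Look the class back up and instantiate it through the registry;
# kwargs are forwarded to the constructor via registry[name](**kwargs).
head = ModelFactory.get_component('head', 'toy_head', num_classes=2)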
8 changes: 4 additions & 4 deletions pytorch_caney/ptc_cli.py
@@ -1,7 +1,7 @@
 import argparse
 import os
 
-from lightning.pytorch import Trainer
+from lightning.pytorch import Trainer
 
 from pytorch_caney.configs.config import _C, _update_config_from_file
 from pytorch_caney.utils import get_strategy, get_distributed_train_batches
@@ -42,7 +42,8 @@ def main(config, output_dir):
     )
 
     if config.TRAIN.LIMIT_TRAIN_BATCHES:
-        trainer.limit_train_batches = get_distributed_train_batches(config, trainer)
+        trainer.limit_train_batches = get_distributed_train_batches(
+            config, trainer)
 
     if config.DATA.DATAMODULE:
         available_datamodules = get_available_datamodules()
@@ -53,13 +54,12 @@ def main(config, output_dir):
         trainer.fit(model=ptlPipeline, datamodule=datamodule)
 
     else:
-        print(f'Training without datamodule, assuming data is set in pipeline: {ptlPipeline}')
+        print(f'Training without datamodule, assuming data is set in pipeline: {ptlPipeline}')  # noqa: E501
         trainer.fit(model=ptlPipeline)
 
 
 if __name__ == "__main__":
 
-
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
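Finally, a condensed programmatic sketch of the flow main() implements, restricted to calls visible in this diff; the YAML path is hypothetical, the argparse flags are truncated above, and _update_config_from_file is assumed to take (config, path), which this diff does not confirm:

from lightning.pytorch import Trainer

from pytorch_caney.configs.config import _C, _update_config_from_file
from pytorch_caney.datamodules import get_available_datamodules

config = _C.clone()
_update_config_from_file(config, 'configs/example.yaml')  # signature assumed

trainer = Trainer(accelerator=config.TRAIN.ACCELERATOR)   # wiring simplified
datamodule = get_available_datamodules()[config.DATAMODULE](config)
# trainer.fit(model=ptlPipeline, datamodule=datamodule)   # pipeline construction elided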