Skip to content

Commit

Permalink
Adding some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Sep 18, 2024
1 parent b10278d commit cc10e00
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions docs/tutorials/custom_segmentation_trainer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install torchgeo[datasets]"
"%pip install torchgeo"
]
},
{
Expand Down Expand Up @@ -198,7 +198,7 @@
"source": [
"## Train model\n",
"\n",
"The remainder of the turial is straightforward and follows the typical [PyTorch Lightning](https://lightning.ai/) training routine. We instantiate a `DataModule` for the LandCover.AI dataset, instantiate a `CustomSemanticSegmentationTask` with a U-Net and ResNet-50 backbone, then train the model using a Lightning trainer."
"The remainder of the turial is straightforward and follows the typical [PyTorch Lightning](https://lightning.ai/) training routine. We instantiate a `DataModule` for the LandCover.AI 100 dataset (a small version of the LandCover.AI dataset for notebook testing), instantiate a `CustomSemanticSegmentationTask` with a U-Net and ResNet-50 backbone, then train the model using a Lightning trainer."
]
},
{
Expand All @@ -207,7 +207,11 @@
"metadata": {},
"outputs": [],
"source": [
"dm = LandCoverAI100DataModule(root='data/', batch_size=64, num_workers=8, download=True)"
"dm = LandCoverAI100DataModule(root='data/', batch_size=64, num_workers=8, download=True)\n",
"\n",
"# You can use the following for actual training runs\n",
"# from torchgeo.datamodules import LandCoverAIDataModule\n",
"# dm = LandCoverAIDataModule(root='data/', batch_size=64, num_workers=8, download=True)"
]
},
{
Expand Down Expand Up @@ -389,6 +393,8 @@
"outputs": [],
"source": [
"# You can load directly from a saved checkpoint with `.load_from_checkpoint(...)`\n",
"# Note that you can also just call `trainer.test(task, dm)` if you've already trained\n",
"# the model in the current notebook session.\n",
"task = CustomSemanticSegmentationTask.load_from_checkpoint(\n",
" 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt'\n",
")"
Expand Down

0 comments on commit cc10e00

Please sign in to comment.