From cc10e00120d84795fa34260ca128b67bc8e49a47 Mon Sep 17 00:00:00 2001 From: davrob Date: Wed, 18 Sep 2024 20:00:16 +0000 Subject: [PATCH] Adding some comments --- docs/tutorials/custom_segmentation_trainer.ipynb | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/custom_segmentation_trainer.ipynb b/docs/tutorials/custom_segmentation_trainer.ipynb index f7ecd6b94e3..b7ab002325c 100644 --- a/docs/tutorials/custom_segmentation_trainer.ipynb +++ b/docs/tutorials/custom_segmentation_trainer.ipynb @@ -38,7 +38,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install torchgeo[datasets]" + "%pip install torchgeo" ] }, { @@ -198,7 +198,7 @@ "source": [ "## Train model\n", "\n", - "The remainder of the turial is straightforward and follows the typical [PyTorch Lightning](https://lightning.ai/) training routine. We instantiate a `DataModule` for the LandCover.AI dataset, instantiate a `CustomSemanticSegmentationTask` with a U-Net and ResNet-50 backbone, then train the model using a Lightning trainer." + "The remainder of the tutorial is straightforward and follows the typical [PyTorch Lightning](https://lightning.ai/) training routine. We instantiate a `DataModule` for the LandCover.AI 100 dataset (a small version of the LandCover.AI dataset for notebook testing), instantiate a `CustomSemanticSegmentationTask` with a U-Net and ResNet-50 backbone, then train the model using a Lightning trainer." 
] }, { @@ -207,7 +207,11 @@ "metadata": {}, "outputs": [], "source": [ - "dm = LandCoverAI100DataModule(root='data/', batch_size=64, num_workers=8, download=True)" + "dm = LandCoverAI100DataModule(root='data/', batch_size=64, num_workers=8, download=True)\n", + "\n", + "# You can use the following for actual training runs\n", + "# from torchgeo.datamodules import LandCoverAIDataModule\n", + "# dm = LandCoverAIDataModule(root='data/', batch_size=64, num_workers=8, download=True)" ] }, { @@ -389,6 +393,8 @@ "outputs": [], "source": [ "# You can load directly from a saved checkpoint with `.load_from_checkpoint(...)`\n", + "# Note that you can also just call `trainer.test(task, dm)` if you've already trained\n", + "# the model in the current notebook session.\n", "task = CustomSemanticSegmentationTask.load_from_checkpoint(\n", " 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt'\n", ")"