-
Notifications
You must be signed in to change notification settings - Fork 386
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add TorchGeo CLI tutorial * Pass checkpoint path
- Loading branch information
1 parent
f87d686
commit 7bcb6d4
Showing
2 changed files
with
294 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "16421d50-8d7a-4972-b06f-160fd890cc86", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Copyright (c) Microsoft Corporation. All rights reserved.\n", | ||
"# Licensed under the MIT License." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e563313d", | ||
"metadata": {}, | ||
"source": [ | ||
"# Command-Line Interface\n", | ||
"\n", | ||
"_Written by: Adam J. Stewart_\n", | ||
"\n", | ||
"TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8c1f4156", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup\n", | ||
"\n", | ||
"First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3f0d31a8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install torchgeo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Subcommands\n", | ||
"\n", | ||
"The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!torchgeo --help" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6", | ||
"metadata": {}, | ||
"source": [ | ||
"## Trainer\n", | ||
"\n", | ||
"Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!torchgeo fit --help" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b437860c-b406-4150-b30b-8aa895eebfcd", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model\n", | ||
"\n", | ||
"We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's builtin trainers, or trainers written by the user, can be used in this way." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!torchgeo fit --model.help ClassificationTask" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3daacd8d-64f4-4357-bdf3-759295a14224", | ||
"metadata": {}, | ||
"source": [ | ||
"## Data\n", | ||
"\n", | ||
"We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--help` flag to see what options are available for the `EuroSAT100` dataset." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "136eb59f-6662-44af-82e9-c55bdb3f17ac", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!torchgeo fit --data.help EuroSAT100DataModule" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb", | ||
"metadata": {}, | ||
"source": [ | ||
"## Config\n", | ||
"\n", | ||
"Now that we have seen all important configuration options, we can put them together in a YAML file. LightingCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import tempfile\n", | ||
"\n", | ||
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", | ||
"config = f\"\"\"\n", | ||
"trainer:\n", | ||
" max_epochs: 1\n", | ||
" default_root_dir: '{root}'\n", | ||
"model:\n", | ||
" class_path: ClassificationTask\n", | ||
" init_args:\n", | ||
" model: 'resnet18'\n", | ||
" in_channels: 13\n", | ||
" num_classes: 10\n", | ||
"data:\n", | ||
" class_path: EuroSAT100DataModule\n", | ||
" init_args:\n", | ||
" batch_size: 8\n", | ||
" dict_kwargs:\n", | ||
" root: '{root}'\n", | ||
" download: true\n", | ||
"\"\"\"\n", | ||
"os.makedirs(root, exist_ok=True)\n", | ||
"with open(os.path.join(root, 'config.yaml'), 'w') as f:\n", | ||
" f.write(config)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a661b8d7-2dc9-4a30-8842-bd52d130e080", | ||
"metadata": {}, | ||
"source": [ | ||
"This YAML file has three sections:\n", | ||
"\n", | ||
"* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n", | ||
"* model: Arguments to pass to the task\n", | ||
"* data: Arguments to pass to the data module\n", | ||
"\n", | ||
"The `class_path` gives the class to instantiate, `init_args` lists standard arguments, and `dict_kwargs` lists keyword arguments." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e132f933-4edf-42bb-b585-e0d8ceb65eab", | ||
"metadata": {}, | ||
"source": [ | ||
"## Training\n", | ||
"\n", | ||
"We can now train our model like so." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f84b0739-c9e7-4057-8864-98ab69a11f64", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!torchgeo fit --config {root}/config.yaml" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cb1557f1-6cc0-46da-909c-836911acb248", | ||
"metadata": {}, | ||
"source": [ | ||
"## Validation\n", | ||
"\n", | ||
"Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import glob\n", | ||
"\n", | ||
"checkpoint = glob.glob(\n", | ||
" os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n", | ||
")[0]\n", | ||
"\n", | ||
"!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61", | ||
"metadata": {}, | ||
"source": [ | ||
"## Testing\n", | ||
"\n", | ||
"After finishing our hyperparameter tuning, we can calculate and report the final test performance." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f1faa997-9f81-4847-94fc-5a8bb7687369", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042", | ||
"metadata": {}, | ||
"source": [ | ||
"## Additional Reading\n", | ||
"\n", | ||
"Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n", | ||
"\n", | ||
"* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"execution": { | ||
"timeout": 1200 | ||
}, | ||
"gpuClass": "standard", | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.13.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |