Skip to content

Commit

Permalink
Add TorchGeo CLI tutorial (#2479)
Browse files Browse the repository at this point in the history
* Add TorchGeo CLI tutorial

* Pass checkpoint path
  • Loading branch information
adamjstewart authored Dec 19, 2024
1 parent f87d686 commit 7bcb6d4
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/tutorials/basic_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
* `Spectral Indices <indices.ipynb>`_: Visualizing and appending spectral indices
* `Pretrained Weights <pretrained_weights.ipynb>`_: Models and pretrained weights
* `Lightning Trainers <trainers.ipynb>`_: PyTorch Lightning data modules and trainers
* `Command-Line Interface <cli.ipynb>`_: TorchGeo's command-line interface

.. toctree::
:hidden:
Expand All @@ -16,3 +17,4 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
indices
pretrained_weights
trainers
cli
292 changes: 292 additions & 0 deletions docs/tutorials/cli.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "16421d50-8d7a-4972-b06f-160fd890cc86",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Microsoft Corporation. All rights reserved.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"id": "e563313d",
"metadata": {},
"source": [
"# Command-Line Interface\n",
"\n",
"_Written by: Adam J. Stewart_\n",
"\n",
"TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface."
]
},
{
"cell_type": "markdown",
"id": "8c1f4156",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f0d31a8",
"metadata": {},
"outputs": [],
"source": [
"%pip install torchgeo"
]
},
{
"cell_type": "markdown",
"id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b",
"metadata": {},
"source": [
"## Subcommands\n",
"\n",
"The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo --help"
]
},
{
"cell_type": "markdown",
"id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6",
"metadata": {},
"source": [
"## Trainer\n",
"\n",
"Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --help"
]
},
{
"cell_type": "markdown",
"id": "b437860c-b406-4150-b30b-8aa895eebfcd",
"metadata": {},
"source": [
"## Model\n",
"\n",
"We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's builtin trainers, or trainers written by the user, can be used in this way."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --model.help ClassificationTask"
]
},
{
"cell_type": "markdown",
"id": "3daacd8d-64f4-4357-bdf3-759295a14224",
"metadata": {},
"source": [
"## Data\n",
"\n",
"We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--help` flag to see what options are available for the `EuroSAT100` dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "136eb59f-6662-44af-82e9-c55bdb3f17ac",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --data.help EuroSAT100DataModule"
]
},
{
"cell_type": "markdown",
"id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb",
"metadata": {},
"source": [
"## Config\n",
"\n",
"Now that we have seen all important configuration options, we can put them together in a YAML file. LightingCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n",
"config = f\"\"\"\n",
"trainer:\n",
" max_epochs: 1\n",
" default_root_dir: '{root}'\n",
"model:\n",
" class_path: ClassificationTask\n",
" init_args:\n",
" model: 'resnet18'\n",
" in_channels: 13\n",
" num_classes: 10\n",
"data:\n",
" class_path: EuroSAT100DataModule\n",
" init_args:\n",
" batch_size: 8\n",
" dict_kwargs:\n",
" root: '{root}'\n",
" download: true\n",
"\"\"\"\n",
"os.makedirs(root, exist_ok=True)\n",
"with open(os.path.join(root, 'config.yaml'), 'w') as f:\n",
" f.write(config)"
]
},
{
"cell_type": "markdown",
"id": "a661b8d7-2dc9-4a30-8842-bd52d130e080",
"metadata": {},
"source": [
"This YAML file has three sections:\n",
"\n",
"* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n",
"* model: Arguments to pass to the task\n",
"* data: Arguments to pass to the data module\n",
"\n",
"The `class_path` gives the class to instantiate, `init_args` lists standard arguments, and `dict_kwargs` lists keyword arguments."
]
},
{
"cell_type": "markdown",
"id": "e132f933-4edf-42bb-b585-e0d8ceb65eab",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We can now train our model like so."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f84b0739-c9e7-4057-8864-98ab69a11f64",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --config {root}/config.yaml"
]
},
{
"cell_type": "markdown",
"id": "cb1557f1-6cc0-46da-909c-836911acb248",
"metadata": {},
"source": [
"## Validation\n",
"\n",
"Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"\n",
"checkpoint = glob.glob(\n",
" os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n",
")[0]\n",
"\n",
"!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}"
]
},
{
"cell_type": "markdown",
"id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61",
"metadata": {},
"source": [
"## Testing\n",
"\n",
"After finishing our hyperparameter tuning, we can calculate and report the final test performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1faa997-9f81-4847-94fc-5a8bb7687369",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}"
]
},
{
"cell_type": "markdown",
"id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042",
"metadata": {},
"source": [
"## Additional Reading\n",
"\n",
"Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n",
"\n",
"* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"execution": {
"timeout": 1200
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 7bcb6d4

Please sign in to comment.