Add TorchGeo CLI tutorial #2479

Merged: 2 commits, Dec 19, 2024
2 changes: 2 additions & 0 deletions docs/tutorials/basic_usage.rst
@@ -7,6 +7,7 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
* `Indices <indices.ipynb>`_: Spectral indices
* `Pretrained Weights <pretrained_weights.ipynb>`_: Models and pretrained weights
* `Lightning Trainers <trainers.ipynb>`_: PyTorch Lightning data modules and trainers
* `Command-Line Interface <cli.ipynb>`_: TorchGeo's command-line interface

.. toctree::
:hidden:
@@ -16,3 +17,4 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
indices
pretrained_weights
trainers
cli
292 changes: 292 additions & 0 deletions docs/tutorials/cli.ipynb
@@ -0,0 +1,292 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "16421d50-8d7a-4972-b06f-160fd890cc86",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Microsoft Corporation. All rights reserved.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"id": "e563313d",
"metadata": {},
"source": [
"# Command-Line Interface\n",
"\n",
"_Written by: Adam J. Stewart_\n",
"\n",
"TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface."
]
},
{
"cell_type": "markdown",
"id": "8c1f4156",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f0d31a8",
"metadata": {},
"outputs": [],
"source": [
"%pip install torchgeo"
]
},
{
"cell_type": "markdown",
"id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b",
"metadata": {},
"source": [
"## Subcommands\n",
"\n",
"The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo --help"
]
},
{
"cell_type": "markdown",
"id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6",
"metadata": {},
"source": [
"## Trainer\n",
"\n",
"Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --help"
]
},
{
"cell_type": "markdown",
"id": "b437860c-b406-4150-b30b-8aa895eebfcd",
"metadata": {},
"source": [
"## Model\n",
"\n",
"We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's built-in trainers, or trainers written by the user, can be used in this way."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --model.help ClassificationTask"
]
},
{
"cell_type": "markdown",
"id": "3daacd8d-64f4-4357-bdf3-759295a14224",
"metadata": {},
"source": [
"## Data\n",
"\n",
"We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--data.help` flag to see what options are available for `EuroSAT100DataModule`, the data module for the `EuroSAT100` dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "136eb59f-6662-44af-82e9-c55bdb3f17ac",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --data.help EuroSAT100DataModule"
]
},
{
"cell_type": "markdown",
"id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb",
"metadata": {},
"source": [
"## Config\n",
"\n",
"Now that we have seen all the important configuration options, we can put them together in a YAML file. LightningCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n",
"config = f\"\"\"\n",
"trainer:\n",
"  max_epochs: 1\n",
"  default_root_dir: '{root}'\n",
"model:\n",
"  class_path: ClassificationTask\n",
"  init_args:\n",
"    model: 'resnet18'\n",
"    in_channels: 13\n",
"    num_classes: 10\n",
"data:\n",
"  class_path: EuroSAT100DataModule\n",
"  init_args:\n",
"    batch_size: 8\n",
"  dict_kwargs:\n",
"    root: '{root}'\n",
"    download: true\n",
"\"\"\"\n",
"os.makedirs(root, exist_ok=True)\n",
"with open(os.path.join(root, 'config.yaml'), 'w') as f:\n",
"    f.write(config)"
]
},
{
"cell_type": "markdown",
"id": "a661b8d7-2dc9-4a30-8842-bd52d130e080",
"metadata": {},
"source": [
"This YAML file has three sections:\n",
"\n",
"* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n",
"* model: Arguments to pass to the task\n",
"* data: Arguments to pass to the data module\n",
"\n",
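"The `class_path` gives the class to instantiate, `init_args` lists arguments that appear in that class's constructor signature, and `dict_kwargs` lists additional keyword arguments that the constructor accepts through `**kwargs` (here, `root` and `download`)."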
]
},
{
"cell_type": "markdown",
"id": "e132f933-4edf-42bb-b585-e0d8ceb65eab",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We can now train our model like so."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f84b0739-c9e7-4057-8864-98ab69a11f64",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --config {root}/config.yaml"
]
},
{
"cell_type": "markdown",
"id": "cb1557f1-6cc0-46da-909c-836911acb248",
"metadata": {},
"source": [
"## Validation\n",
"\n",
"Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"\n",
"checkpoint = glob.glob(\n",
"    os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n",
")[0]\n",
"\n",
"!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}"
]
},
{
"cell_type": "markdown",
"id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61",
"metadata": {},
"source": [
"## Testing\n",
"\n",
"After finishing our hyperparameter tuning, we can calculate and report the final test performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1faa997-9f81-4847-94fc-5a8bb7687369",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}"
]
},
{
"cell_type": "markdown",
"id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042",
"metadata": {},
"source": [
"## Additional Reading\n",
"\n",
"Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n",
"\n",
"* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"execution": {
"timeout": 1200
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
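The Config section of the notebook states that LightningCLI supports JSON as well as YAML. As a minimal sketch of that alternative, the same configuration could be built as a Python dict and serialized with the standard `json` module, which avoids hand-maintaining indentation inside an f-string. The helper name `build_config` and the file name `config.json` are illustrative assumptions and do not appear in the tutorial, and passing the resulting file to `torchgeo fit --config` is assumed to behave the same as the YAML version.

```python
import json
import os
import tempfile


def build_config(root: str) -> dict:
    """Return the tutorial's configuration as a plain dict (hypothetical helper)."""
    return {
        'trainer': {'max_epochs': 1, 'default_root_dir': root},
        'model': {
            'class_path': 'ClassificationTask',
            'init_args': {'model': 'resnet18', 'in_channels': 13, 'num_classes': 10},
        },
        'data': {
            'class_path': 'EuroSAT100DataModule',
            'init_args': {'batch_size': 8},
            # Extra keyword arguments, mirroring the dict_kwargs section of the YAML
            'dict_kwargs': {'root': root, 'download': True},
        },
    }


# Same temporary directory the tutorial uses for its YAML file
root = os.path.join(tempfile.gettempdir(), 'eurosat100')
os.makedirs(root, exist_ok=True)

# 'config.json' is an assumed file name, chosen to sit next to config.yaml
config_path = os.path.join(root, 'config.json')
with open(config_path, 'w') as f:
    json.dump(build_config(root), f, indent=2)
```

Keeping the three top-level keys (`trainer`, `model`, `data`) identical to the YAML version means the two files should be interchangeable under the assumption above; only the serialization differs.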