From e0a691200693ceb262e80f67776bdaacf1f0f4f1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 19 Dec 2024 22:37:19 +0100 Subject: [PATCH] Add TorchGeo CLI tutorial (#2479) * Add TorchGeo CLI tutorial * Pass checkpoint path --- docs/tutorials/basic_usage.rst | 2 + docs/tutorials/cli.ipynb | 292 +++++++++++++++++++++++++++++++++ 2 files changed, 294 insertions(+) create mode 100644 docs/tutorials/cli.ipynb diff --git a/docs/tutorials/basic_usage.rst b/docs/tutorials/basic_usage.rst index 57848bede19..51eb77a85ca 100644 --- a/docs/tutorials/basic_usage.rst +++ b/docs/tutorials/basic_usage.rst @@ -7,6 +7,7 @@ The following tutorials introduce the basic concepts and components of TorchGeo: * `Spectral Indices `_: Visualizing and appending spectral indices * `Pretrained Weights `_: Models and pretrained weights * `Lightning Trainers `_: PyTorch Lightning data modules and trainers +* `Command-Line Interface `_: TorchGeo's command-line interface .. toctree:: :hidden: @@ -16,3 +17,4 @@ The following tutorials introduce the basic concepts and components of TorchGeo: indices pretrained_weights trainers + cli diff --git a/docs/tutorials/cli.ipynb b/docs/tutorials/cli.ipynb new file mode 100644 index 00000000000..1424b13291a --- /dev/null +++ b/docs/tutorials/cli.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "16421d50-8d7a-4972-b06f-160fd890cc86", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "id": "e563313d", + "metadata": {}, + "source": [ + "# Command-Line Interface\n", + "\n", + "_Written by: Adam J. Stewart_\n", + "\n", + "TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface." + ] + }, + { + "cell_type": "markdown", + "id": "8c1f4156", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f0d31a8", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b", + "metadata": {}, + "source": [ + "## Subcommands\n", + "\n", + "The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo --help" + ] + }, + { + "cell_type": "markdown", + "id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6", + "metadata": {}, + "source": [ + "## Trainer\n", + "\n", + "Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --help" + ] + }, + { + "cell_type": "markdown", + "id": "b437860c-b406-4150-b30b-8aa895eebfcd", + "metadata": {}, + "source": [ + "## Model\n", + "\n", + "We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's builtin trainers, or trainers written by the user, can be used in this way." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --model.help ClassificationTask" + ] + }, + { + "cell_type": "markdown", + "id": "3daacd8d-64f4-4357-bdf3-759295a14224", + "metadata": {}, + "source": [ + "## Data\n", + "\n", + "We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--help` flag to see what options are available for the `EuroSAT100` dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "136eb59f-6662-44af-82e9-c55bdb3f17ac", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --data.help EuroSAT100DataModule" + ] + }, + { + "cell_type": "markdown", + "id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb", + "metadata": {}, + "source": [ + "## Config\n", + "\n", + "Now that we have seen all important configuration options, we can put them together in a YAML file. LightingCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", + "config = f\"\"\"\n", + "trainer:\n", + " max_epochs: 1\n", + " default_root_dir: '{root}'\n", + "model:\n", + " class_path: ClassificationTask\n", + " init_args:\n", + " model: 'resnet18'\n", + " in_channels: 13\n", + " num_classes: 10\n", + "data:\n", + " class_path: EuroSAT100DataModule\n", + " init_args:\n", + " batch_size: 8\n", + " dict_kwargs:\n", + " root: '{root}'\n", + " download: true\n", + "\"\"\"\n", + "os.makedirs(root, exist_ok=True)\n", + "with open(os.path.join(root, 'config.yaml'), 'w') as f:\n", + " f.write(config)" + ] + }, + { + "cell_type": "markdown", + "id": "a661b8d7-2dc9-4a30-8842-bd52d130e080", + "metadata": {}, + "source": [ + "This YAML file has three sections:\n", + "\n", + "* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n", + "* model: Arguments to pass to the task\n", + "* data: Arguments to pass to the data module\n", + "\n", + "The `class_path` gives the class to instantiate, `init_args` lists standard arguments, and `dict_kwargs` lists keyword arguments." + ] + }, + { + "cell_type": "markdown", + "id": "e132f933-4edf-42bb-b585-e0d8ceb65eab", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "We can now train our model like so." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f84b0739-c9e7-4057-8864-98ab69a11f64", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --config {root}/config.yaml" + ] + }, + { + "cell_type": "markdown", + "id": "cb1557f1-6cc0-46da-909c-836911acb248", + "metadata": {}, + "source": [ + "## Validation\n", + "\n", + "Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61", + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "\n", + "checkpoint = glob.glob(\n", + " os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n", + ")[0]\n", + "\n", + "!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}" + ] + }, + { + "cell_type": "markdown", + "id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61", + "metadata": {}, + "source": [ + "## Testing\n", + "\n", + "After finishing our hyperparameter tuning, we can calculate and report the final test performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1faa997-9f81-4847-94fc-5a8bb7687369", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}" + ] + }, + { + "cell_type": "markdown", + "id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042", + "metadata": {}, + "source": [ + "## Additional Reading\n", + "\n", + "Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n", + "\n", + "* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "execution": { + "timeout": 1200 + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}