From ff4002d343ce8358be6eb4cb05b127782f20b77b Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 19 Jan 2024 18:02:56 -0800 Subject: [PATCH] Clean up `ns-train {method} --help` for not-yet-installed external methods (#2760) * Clean up `ns-train {method} --help` for not-yet-installed external methods * Ruff * Ruff * add clearer print statement * Types? Not sure if this fixes it --------- Co-authored-by: Justin Kerr --- nerfstudio/configs/external_methods.py | 53 ++++++++++++++------------ nerfstudio/configs/method_configs.py | 6 +-- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/nerfstudio/configs/external_methods.py b/nerfstudio/configs/external_methods.py index d530cf5bc6..86f237d5a4 100644 --- a/nerfstudio/configs/external_methods.py +++ b/nerfstudio/configs/external_methods.py @@ -14,14 +14,15 @@ """This file contains the configuration for external methods which are not included in this repository.""" +import inspect import subprocess import sys -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, cast +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple +import tyro from rich.prompt import Confirm -from nerfstudio.engine.trainer import TrainerConfig from nerfstudio.utils.rich_utils import CONSOLE @@ -177,21 +178,30 @@ class ExternalMethod: @dataclass -class ExternalMethodTrainerConfig(TrainerConfig): - """ - Trainer config for external methods which does not have an implementation in this repository. +class ExternalMethodDummyTrainerConfig: + """Dummy trainer config for external methods (a) which do not have an + implementation in this repository, and (b) are not yet installed. When this + config is instantiated, we give the user the option to install the method. """ - _method: ExternalMethod = field(default=cast(ExternalMethod, None)) + # tyro.conf.Suppress will prevent these fields from appearing as CLI arguments. + method_name: tyro.conf.Suppress[str] + method: tyro.conf.Suppress[ExternalMethod] + + def __post_init__(self): + """Offer to install an external method.""" - def handle_print_information(self, *_args, **_kwargs): - """Prints the method information and exits.""" - CONSOLE.print(self._method.instructions) - if self._method.pip_package and Confirm.ask( + # Don't trigger install message from get_external_methods() below; only + # if this dummy object is instantiated from the CLI. + if inspect.stack()[2].function == "get_external_methods": + return + + CONSOLE.print(self.method.instructions) + if self.method.pip_package and Confirm.ask( "\nWould you like to run the install it now?", default=False, console=CONSOLE ): # Install the method - install_command = f"{sys.executable} -m pip install {self._method.pip_package}" + install_command = f"{sys.executable} -m pip install {self.method.pip_package}" CONSOLE.print(f"Running: [cyan]{install_command}[/cyan]") result = subprocess.run(install_command, shell=True, check=False) if result.returncode != 0: @@ -200,20 +210,15 @@ def handle_print_information(self, *_args, **_kwargs): sys.exit(0) - def __getattribute__(self, __name: str) -> Any: - out = object.__getattribute__(self, __name) - if callable(out) and __name not in {"handle_print_information"} and not __name.startswith("__"): - # We exit early, displaying the message - return self.handle_print_information - return out - -def get_external_methods() -> Tuple[Dict[str, TrainerConfig], Dict[str, str]]: +def get_external_methods() -> Tuple[Dict[str, ExternalMethodDummyTrainerConfig], Dict[str, str]]: """Returns the external methods trainer configs and the descriptions.""" - method_configs = {} - descriptions = {} + method_configs: Dict[str, ExternalMethodDummyTrainerConfig] = {} + descriptions: Dict[str, str] = {} for external_method in external_methods: for config_slug, config_description in external_method.configurations: - method_configs[config_slug] = ExternalMethodTrainerConfig(method_name=config_slug, _method=external_method) - descriptions[config_slug] = f"""[External] {config_description}""" + method_configs[config_slug] = ExternalMethodDummyTrainerConfig( + method_name=config_slug, method=external_method + ) + descriptions[config_slug] = f"""[External, run 'ns-train {config_slug}' to install] {config_description}""" return method_configs, descriptions diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index dd842e8ab6..9a7d31c10c 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -19,13 +19,13 @@ from __future__ import annotations from collections import OrderedDict -from typing import Dict +from typing import Dict, Union import tyro from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig from nerfstudio.configs.base_config import ViewerConfig -from nerfstudio.configs.external_methods import get_external_methods +from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig @@ -65,7 +65,7 @@ from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig from nerfstudio.plugins.registry import discover_methods -method_configs: Dict[str, TrainerConfig] = {} +method_configs: Dict[str, Union[TrainerConfig, ExternalMethodDummyTrainerConfig]] = {} descriptions = { "nerfacto": "Recommended real-time model tuned for real captures. This model will be continually updated.", "depth-nerfacto": "Nerfacto with depth supervision.",