Skip to content

Commit

Permalink
Clean up ns-train {method} --help for not-yet-installed external me…
Browse files Browse the repository at this point in the history
…thods (nerfstudio-project#2760)

* Clean up `ns-train {method} --help` for not-yet-installed external methods

* Ruff

* Ruff

* add clearer print statement

* Types? Not sure if this fixes it

---------

Co-authored-by: Justin Kerr <[email protected]>
  • Loading branch information
brentyi and kerrj authored Jan 20, 2024
1 parent 70d83d4 commit ff4002d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
53 changes: 29 additions & 24 deletions nerfstudio/configs/external_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@


"""This file contains the configuration for external methods which are not included in this repository."""
import inspect
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, cast
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import tyro
from rich.prompt import Confirm

from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.utils.rich_utils import CONSOLE


Expand Down Expand Up @@ -177,21 +178,30 @@ class ExternalMethod:


@dataclass
class ExternalMethodTrainerConfig(TrainerConfig):
"""
Trainer config for external methods which does not have an implementation in this repository.
class ExternalMethodDummyTrainerConfig:
"""Dummy trainer config for external methods (a) which do not have an
implementation in this repository, and (b) are not yet installed. When this
config is instantiated, we give the user the option to install the method.
"""

_method: ExternalMethod = field(default=cast(ExternalMethod, None))
# tyro.conf.Suppress will prevent these fields from appearing as CLI arguments.
method_name: tyro.conf.Suppress[str]
method: tyro.conf.Suppress[ExternalMethod]

def __post_init__(self):
"""Offer to install an external method."""

def handle_print_information(self, *_args, **_kwargs):
"""Prints the method information and exits."""
CONSOLE.print(self._method.instructions)
if self._method.pip_package and Confirm.ask(
# Don't trigger install message from get_external_methods() below; only
# if this dummy object is instantiated from the CLI.
if inspect.stack()[2].function == "get_external_methods":
return

CONSOLE.print(self.method.instructions)
if self.method.pip_package and Confirm.ask(
"\nWould you like to run the install it now?", default=False, console=CONSOLE
):
# Install the method
install_command = f"{sys.executable} -m pip install {self._method.pip_package}"
install_command = f"{sys.executable} -m pip install {self.method.pip_package}"
CONSOLE.print(f"Running: [cyan]{install_command}[/cyan]")
result = subprocess.run(install_command, shell=True, check=False)
if result.returncode != 0:
Expand All @@ -200,20 +210,15 @@ def handle_print_information(self, *_args, **_kwargs):

sys.exit(0)

def __getattribute__(self, __name: str) -> Any:
out = object.__getattribute__(self, __name)
if callable(out) and __name not in {"handle_print_information"} and not __name.startswith("__"):
# We exit early, displaying the message
return self.handle_print_information
return out


def get_external_methods() -> Tuple[Dict[str, TrainerConfig], Dict[str, str]]:
def get_external_methods() -> Tuple[Dict[str, ExternalMethodDummyTrainerConfig], Dict[str, str]]:
"""Returns the external methods trainer configs and the descriptions."""
method_configs = {}
descriptions = {}
method_configs: Dict[str, ExternalMethodDummyTrainerConfig] = {}
descriptions: Dict[str, str] = {}
for external_method in external_methods:
for config_slug, config_description in external_method.configurations:
method_configs[config_slug] = ExternalMethodTrainerConfig(method_name=config_slug, _method=external_method)
descriptions[config_slug] = f"""[External] {config_description}"""
method_configs[config_slug] = ExternalMethodDummyTrainerConfig(
method_name=config_slug, method=external_method
)
descriptions[config_slug] = f"""[External, run 'ns-train {config_slug}' to install] {config_description}"""
return method_configs, descriptions
6 changes: 3 additions & 3 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Dict
from typing import Dict, Union

import tyro

from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.configs.base_config import ViewerConfig
from nerfstudio.configs.external_methods import get_external_methods
from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
Expand Down Expand Up @@ -65,7 +65,7 @@
from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig
from nerfstudio.plugins.registry import discover_methods

method_configs: Dict[str, TrainerConfig] = {}
method_configs: Dict[str, Union[TrainerConfig, ExternalMethodDummyTrainerConfig]] = {}
descriptions = {
"nerfacto": "Recommended real-time model tuned for real captures. This model will be continually updated.",
"depth-nerfacto": "Nerfacto with depth supervision.",
Expand Down

0 comments on commit ff4002d

Please sign in to comment.