diff --git a/guardrails/cli/start.py b/guardrails/cli/start.py index d29bbf03d..dcf5df470 100644 --- a/guardrails/cli/start.py +++ b/guardrails/cli/start.py @@ -7,6 +7,7 @@ from guardrails.cli.telemetry import trace_if_enabled from guardrails.cli.version import version_warnings_if_applicable from guardrails.cli.hub.console import console +from guardrails.settings import settings def api_is_installed() -> bool: @@ -32,15 +33,22 @@ def start( default=8000, help="The port to run the server on.", ), + watch: bool = typer.Option( + default=False, is_flag=True, help="Enable watch mode for logs." + ), ): logger.debug("Checking for prerequisites...") if not api_is_installed(): package_name = 'guardrails-api>="^0.0.0a0"' pip_process("install", package_name) - from guardrails_api.cli.start import start # type: ignore + from guardrails_api.cli.start import start as start_api # type: ignore logger.info("Starting Guardrails server") + + if watch: + settings._watch_mode_enabled = True + version_warnings_if_applicable(console) trace_if_enabled("start") - start(env, config, port) + start_api(env, config, port) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index fffb551ac..d555c3494 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -7,6 +7,7 @@ import rich import typer +from guardrails.settings import settings from guardrails.cli.guardrails import guardrails as gr_cli from guardrails.call_tracing import GuardTraceEntry, TraceHandler from guardrails.cli.telemetry import trace_if_enabled @@ -31,6 +32,7 @@ def watch_command( default=False, is_flag=True, help="Clear all log outputs and exit." ), ): + settings._watch_mode_enabled = True trace_if_enabled("watch") if clear: _clear_and_quit() diff --git a/guardrails/settings.py b/guardrails/settings.py index e3c2bcfab..df17f4ff1 100644 --- a/guardrails/settings.py +++ b/guardrails/settings.py @@ -8,6 +8,7 @@ class Settings: _instance = None _lock = threading.Lock() _rc: RC + _watch_mode_enabled: bool """Whether to use a local server for running Guardrails.""" use_server: Optional[bool] """Whether to disable tracing. @@ -29,6 +30,7 @@ def _initialize(self): self.use_server = None self.disable_tracing = None self._rc = RC.load() + self._watch_mode_enabled = False @property def rc(self) -> RC: @@ -40,5 +42,9 @@ def rc(self) -> RC: def rc(self, value: RC): self._rc = value + @property + def watch_mode_enabled(self) -> bool: + return self._watch_mode_enabled + settings = Settings() diff --git a/guardrails/telemetry/legacy_validator_tracing.py b/guardrails/telemetry/legacy_validator_tracing.py index 621f9ddb2..ee15a3200 100644 --- a/guardrails/telemetry/legacy_validator_tracing.py +++ b/guardrails/telemetry/legacy_validator_tracing.py @@ -5,6 +5,7 @@ from guardrails.actions.refrain import Refrain from guardrails.call_tracing.trace_handler import TraceHandler from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.settings import settings from guardrails.telemetry.common import get_span from guardrails.utils.casting_utils import to_string @@ -68,7 +69,8 @@ def trace_validator_result( **kwargs, } - TraceHandler().log_validator(validator_log) + if settings.watch_mode_enabled: + TraceHandler().log_validator(validator_log) current_span.add_event( f"{validator_name}_result", @@ -85,6 +87,6 @@ def trace_validation_result( current_span=None, ): _current_span = get_span(current_span) - if _current_span is not None: + if _current_span is not None and not settings.disable_tracing: for log in validation_logs: trace_validator_result(_current_span, log, attempt_number)