diff --git a/lib/lsp-devtools/changes/134.enhancement.md b/lib/lsp-devtools/changes/134.enhancement.md new file mode 100644 index 0000000..35689c8 --- /dev/null +++ b/lib/lsp-devtools/changes/134.enhancement.md @@ -0,0 +1 @@ +When not printing messages to stdout, the `lsp-devtools record` command now displays a nice visualisation of the traffic between client and server - so that you can see that it's doing something diff --git a/lib/lsp-devtools/lsp_devtools/record/__init__.py b/lib/lsp-devtools/lsp_devtools/record/__init__.py index b33c54d..b6e86f3 100644 --- a/lib/lsp-devtools/lsp_devtools/record/__init__.py +++ b/lib/lsp-devtools/lsp_devtools/record/__init__.py @@ -20,6 +20,7 @@ from lsp_devtools.handlers.sql import SqlHandler from .filters import LSPFilter +from .visualize import SpinnerHandler EXPORTERS = { ".html": ("save_html", {}), @@ -78,10 +79,9 @@ def log_rpc_message(ls: AgentServer, message: MessageText): parse_rpc_message(ls, message, logfn) -def setup_stdout_output(args, logger: logging.Logger) -> Console: - """Log to stdout.""" +def setup_stdout_output(args, logger: logging.Logger, console: Console): + """Log messages to stdout.""" - console = Console(record=args.save_output is not None) handler = RichLSPHandler(level=logging.INFO, console=console) handler.addFilter( LSPFilter( @@ -95,10 +95,10 @@ def setup_stdout_output(args, logger: logging.Logger) -> Console: ) logger.addHandler(handler) - return console -def setup_file_output(args, logger: logging.Logger): +def setup_file_output(args, logger: logging.Logger, console: Optional[Console] = None): + """Log messages to a file.""" handler = logging.FileHandler(filename=str(args.to_file)) handler.setLevel(logging.INFO) handler.addFilter( @@ -111,11 +111,18 @@ def setup_file_output(args, logger: logging.Logger): formatter=args.format_message or "{.|json-compact}", ) ) - logger.addHandler(handler) + if console: + spinner = SpinnerHandler(console) + spinner.setLevel(logging.INFO) + logger.addHandler(spinner) -def setup_sqlite_output(args, logger: logging.Logger): + +def setup_sqlite_output( + args, logger: logging.Logger, console: Optional[Console] = None +): + """Log messages to SQLite.""" handler = SqlHandler(args.to_sqlite) handler.setLevel(logging.INFO) handler.addFilter( @@ -127,9 +134,13 @@ def setup_sqlite_output(args, logger: logging.Logger): exclude_methods=args.exclude_methods, ) ) - logger.addHandler(handler) + if console: + spinner = SpinnerHandler(console) + spinner.setLevel(logging.INFO) + logger.addHandler(spinner) + def start_recording(args, extra: List[str]): server = AgentServer() @@ -137,20 +148,21 @@ def start_recording(args, extra: List[str]): logger.setLevel(logging.INFO) server.feature(MESSAGE_TEXT_NOTIFICATION)(log_func) - console: Optional[Console] = None - host = args.host - port = args.port + console = Console(record=args.save_output is not None) if args.to_file: - setup_file_output(args, logger) + setup_file_output(args, logger, console) elif args.to_sqlite: - setup_sqlite_output(args, logger) + setup_sqlite_output(args, logger, console) else: - console = setup_stdout_output(args, logger) + setup_stdout_output(args, logger, console) try: + host = args.host + port = args.port + print(f"Waiting for connection on {host}:{port}...", end="\r", flush=True) asyncio.run(server.start_tcp(host, port)) except asyncio.CancelledError: @@ -158,15 +170,18 @@ def start_recording(args, extra: List[str]): except KeyboardInterrupt: server.stop() - if console is not None and args.save_output is not None: - destination = args.save_output - exporter_name, kwargs = EXPORTERS.get(destination.suffix, (None, None)) - if exporter_name is None: - console.print(f"Unable to save output to '{destination.suffix}' files") - return + if console is not None: + console.show_cursor(True) + + if args.save_output is not None: + destination = args.save_output + exporter_name, kwargs = EXPORTERS.get(destination.suffix, (None, None)) + if exporter_name is None: + console.print(f"Unable to save output to '{destination.suffix}' files") + return - exporter = getattr(console, exporter_name) - exporter(str(destination), **kwargs) + exporter = getattr(console, exporter_name) + exporter(str(destination), **kwargs) def setup_filter_args(cmd: argparse.ArgumentParser): diff --git a/lib/lsp-devtools/lsp_devtools/record/visualize.py b/lib/lsp-devtools/lsp_devtools/record/visualize.py new file mode 100644 index 0000000..1e0bcaa --- /dev/null +++ b/lib/lsp-devtools/lsp_devtools/record/visualize.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import logging +import typing + +from rich import progress +from rich.measure import Measurement +from rich.segment import Segment +from rich.style import Style + +if typing.TYPE_CHECKING: + from typing import List + from typing import Optional + + from rich.console import Console + from rich.console import ConsoleOptions + from rich.console import RenderResult + from rich.table import Column + + +class PacketPipe: + """Rich renderable that generates the visualisation of the in-flight packets between + client and server.""" + + def __init__(self, server_packets, client_packets): + self.server_packets = server_packets + self.client_packets = client_packets + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + width = options.max_width + pipe_length = width - 2 + + client_packets = {int(p * pipe_length) for p in self.client_packets} + server_packets = { + pipe_length - int(p * pipe_length) for p in self.server_packets + } + + yield Segment("[") + + for idx in range(pipe_length): + if idx in server_packets: + yield Segment("●", style=Style(color="blue")) + elif idx in client_packets: + yield Segment("●", style=Style(color="red")) + else: + yield Segment(" ") + + yield Segment("]") + + def __rich_measure__( + self, console: Console, options: ConsoleOptions + ) -> Measurement: + return Measurement(4, options.max_width) + + +class PacketPipeColumn(progress.ProgressColumn): + """Visualizes messages sent between client and server as "packets".""" + + def __init__( + self, duration: float = 1.0, table_column: Optional[Column] = None + ) -> None: + self.client_count = 0 + self.server_count = 0 + self.server_times: List[float] = [] + self.client_times: List[float] = [] + + # How long it should take for a packet to propogate. + self.duration = duration + + super().__init__(table_column) + + def _update_packets(self, task: progress.Task, source: str) -> List[float]: + """Update the packet positions for the given message source. + + Parameters + ---------- + task + The task object + + source + The message source + + Returns + ------- + List[float] + A list of floats in the range [0,1] indicating the number of packets in flight + and their position + """ + count_attr = f"{source}_count" + time_attr = f"{source}_times" + + count = getattr(self, count_attr) + times = getattr(self, time_attr) + current_count = task.fields[count_attr] + current_time = task.get_time() + + if current_count > count: + setattr(self, count_attr, current_count) + times.append(current_time) + + packets = [] + new_times = [] + + for time in times: + if (delta := current_time - time) > self.duration: + continue + + packets.append(delta / self.duration) + new_times.append(time) + + setattr(self, time_attr, new_times) + return packets + + def render(self, task: progress.Task) -> PacketPipe: + """Render the packet pipe.""" + + client_packets = self._update_packets(task, "client") + server_packets = self._update_packets(task, "server") + + return PacketPipe(server_packets=server_packets, client_packets=client_packets) + + +class SpinnerHandler(logging.Handler): + """A logging handler that shows a customised progress bar, used to show feedback for + an active connection.""" + + def __init__(self, console: Console, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.server_count = 0 + self.client_count = 0 + self.progress = progress.Progress( + progress.TextColumn("[red]CLIENT[/red] {task.fields[client_method]}"), + PacketPipeColumn(), + progress.TextColumn("{task.fields[server_method]} [blue]SERVER[/blue]"), + console=console, + auto_refresh=True, + expand=True, + ) + self.task = self.progress.add_task( + "", + total=None, + server_method="", + client_method="", + server_count=self.server_count, + client_count=self.client_count, + ) + + def emit(self, record: logging.LogRecord): + message = record.args + + if not isinstance(message, dict): + return + + self.progress.start() + + method = message.get("method", None) + source = record.__dict__["source"] + args = {} + + if method: + args[f"{source}_method"] = method + count = getattr(self, f"{source}_count") + 1 + + setattr(self, f"{source}_count", count) + args[f"{source}_count"] = count + + self.progress.update(self.task, **args)