Skip to content

Commit

Permalink
lsp-devtools: Visualise message traffic between client and server
Browse files Browse the repository at this point in the history
Aside from just looking cool, this gives the user feedback that the
`lsp-devtools record` command is actually doing something.
  • Loading branch information
alcarney committed Jan 29, 2024
1 parent b968f89 commit 91a6f4c
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 20 deletions.
1 change: 1 addition & 0 deletions lib/lsp-devtools/changes/134.enhancement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
When not printing messages to stdout, the `lsp-devtools record` command now displays a nice visualisation of the traffic between client and server - so that you can see that it's doing something
59 changes: 39 additions & 20 deletions lib/lsp-devtools/lsp_devtools/record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lsp_devtools.handlers.sql import SqlHandler

from .filters import LSPFilter
from .visualize import SpinnerHandler

EXPORTERS = {
".html": ("save_html", {}),
Expand Down Expand Up @@ -78,10 +79,9 @@ def log_rpc_message(ls: AgentServer, message: MessageText):
parse_rpc_message(ls, message, logfn)


def setup_stdout_output(args, logger: logging.Logger) -> Console:
"""Log to stdout."""
def setup_stdout_output(args, logger: logging.Logger, console: Console):
"""Log messages to stdout."""

console = Console(record=args.save_output is not None)
handler = RichLSPHandler(level=logging.INFO, console=console)
handler.addFilter(
LSPFilter(
Expand All @@ -95,10 +95,10 @@ def setup_stdout_output(args, logger: logging.Logger) -> Console:
)

logger.addHandler(handler)
return console


def setup_file_output(args, logger: logging.Logger):
def setup_file_output(args, logger: logging.Logger, console: Optional[Console] = None):
"""Log messages to a file."""
handler = logging.FileHandler(filename=str(args.to_file))
handler.setLevel(logging.INFO)
handler.addFilter(
Expand All @@ -112,10 +112,19 @@ def setup_file_output(args, logger: logging.Logger):
)
)

if console:
spinner = SpinnerHandler(console)
spinner.setLevel(logging.INFO)
logger.addHandler(spinner)

# This must come last!
logger.addHandler(handler)


def setup_sqlite_output(args, logger: logging.Logger):
def setup_sqlite_output(
args, logger: logging.Logger, console: Optional[Console] = None
):
"""Log messages to SQLite."""
handler = SqlHandler(args.to_sqlite)
handler.setLevel(logging.INFO)
handler.addFilter(
Expand All @@ -128,6 +137,12 @@ def setup_sqlite_output(args, logger: logging.Logger):
)
)

if console:
spinner = SpinnerHandler(console)
spinner.setLevel(logging.INFO)
logger.addHandler(spinner)

# This must come last!
logger.addHandler(handler)


Expand All @@ -137,36 +152,40 @@ def start_recording(args, extra: List[str]):
logger.setLevel(logging.INFO)
server.feature(MESSAGE_TEXT_NOTIFICATION)(log_func)

console: Optional[Console] = None
host = args.host
port = args.port
console = Console(record=args.save_output is not None)

if args.to_file:
setup_file_output(args, logger)
setup_file_output(args, logger, console)

elif args.to_sqlite:
setup_sqlite_output(args, logger)
setup_sqlite_output(args, logger, console)

else:
console = setup_stdout_output(args, logger)
setup_stdout_output(args, logger, console)

try:
host = args.host
port = args.port

print(f"Waiting for connection on {host}:{port}...", end="\r", flush=True)
asyncio.run(server.start_tcp(host, port))
except asyncio.CancelledError:
pass
except KeyboardInterrupt:
server.stop()

if console is not None and args.save_output is not None:
destination = args.save_output
exporter_name, kwargs = EXPORTERS.get(destination.suffix, (None, None))
if exporter_name is None:
console.print(f"Unable to save output to '{destination.suffix}' files")
return
if console is not None:
console.show_cursor(True)

if args.save_output is not None:
destination = args.save_output
exporter_name, kwargs = EXPORTERS.get(destination.suffix, (None, None))
if exporter_name is None:
console.print(f"Unable to save output to '{destination.suffix}' files")
return

exporter = getattr(console, exporter_name)
exporter(str(destination), **kwargs)
exporter = getattr(console, exporter_name)
exporter(str(destination), **kwargs)


def setup_filter_args(cmd: argparse.ArgumentParser):
Expand Down
169 changes: 169 additions & 0 deletions lib/lsp-devtools/lsp_devtools/record/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from __future__ import annotations

import logging
import typing

from rich import progress
from rich.measure import Measurement
from rich.segment import Segment
from rich.style import Style

if typing.TYPE_CHECKING:
from typing import List
from typing import Optional

from rich.console import Console
from rich.console import ConsoleOptions
from rich.console import RenderResult
from rich.table import Column


class PacketPipe:
"""Rich renderable that generates the visualisation of the in-flight packets between
client and server."""

def __init__(self, server_packets, client_packets):
self.server_packets = server_packets
self.client_packets = client_packets

def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
width = options.max_width
pipe_length = width - 2

client_packets = {int(p * pipe_length) for p in self.client_packets}
server_packets = {
pipe_length - int(p * pipe_length) for p in self.server_packets
}

yield Segment("[")

for idx in range(pipe_length):
if idx in server_packets:
yield Segment("●", style=Style(color="blue"))
elif idx in client_packets:
yield Segment("●", style=Style(color="red"))
else:
yield Segment(" ")

yield Segment("]")

def __rich_measure__(
self, console: Console, options: ConsoleOptions
) -> Measurement:
return Measurement(4, options.max_width)


class PacketPipeColumn(progress.ProgressColumn):
"""Visualizes messages sent between client and server as "packets"."""

def __init__(
self, duration: float = 1.0, table_column: Optional[Column] = None
) -> None:
self.client_count = 0
self.server_count = 0
self.server_times: List[float] = []
self.client_times: List[float] = []

# How long it should take for a packet to propogate.
self.duration = duration

super().__init__(table_column)

def _update_packets(self, task: progress.Task, source: str) -> List[float]:
"""Update the packet positions for the given message source.
Parameters
----------
task
The task object
source
The message source
Returns
-------
List[float]
A list of floats in the range [0,1] indicating the number of packets in flight
and their position
"""
count_attr = f"{source}_count"
time_attr = f"{source}_times"

count = getattr(self, count_attr)
times = getattr(self, time_attr)
current_count = task.fields[count_attr]
current_time = task.get_time()

if current_count > count:
setattr(self, count_attr, current_count)
times.append(current_time)

packets = []
new_times = []

for time in times:
if (delta := current_time - time) > self.duration:
continue

packets.append(delta / self.duration)
new_times.append(time)

setattr(self, time_attr, new_times)
return packets

def render(self, task: progress.Task) -> PacketPipe:
"""Render the packet pipe."""

client_packets = self._update_packets(task, "client")
server_packets = self._update_packets(task, "server")

return PacketPipe(server_packets=server_packets, client_packets=client_packets)


class SpinnerHandler(logging.Handler):
"""A logging handler that shows a customised progress bar, used to show feedback for
an active connection."""

def __init__(self, console: Console, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.server_count = 0
self.client_count = 0
self.progress = progress.Progress(
progress.TextColumn("[red]CLIENT[/red] {task.fields[client_method]}"),
PacketPipeColumn(),
progress.TextColumn("{task.fields[server_method]} [blue]SERVER[/blue]"),
console=console,
auto_refresh=True,
expand=True,
)
self.task = self.progress.add_task(
"",
total=None,
server_method="",
client_method="",
server_count=self.server_count,
client_count=self.client_count,
)

def emit(self, record: logging.LogRecord):
message = record.args

if not isinstance(message, dict):
return

self.progress.start()

method = message.get("method", None)
source = record.__dict__["source"]
args = {}

if method:
args[f"{source}_method"] = method
count = getattr(self, f"{source}_count") + 1

setattr(self, f"{source}_count", count)
args[f"{source}_count"] = count

self.progress.update(self.task, **args)

0 comments on commit 91a6f4c

Please sign in to comment.