Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lsp-devtools: Visualise message traffic between client and server #139

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/lsp-devtools/changes/134.enhancement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
When not printing messages to stdout, the `lsp-devtools record` command now displays a nice visualisation of the traffic between client and server - so that you can see that it's doing something
59 changes: 39 additions & 20 deletions lib/lsp-devtools/lsp_devtools/record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lsp_devtools.handlers.sql import SqlHandler

from .filters import LSPFilter
from .visualize import SpinnerHandler

EXPORTERS = {
".html": ("save_html", {}),
Expand Down Expand Up @@ -78,10 +79,9 @@ def log_rpc_message(ls: AgentServer, message: MessageText):
parse_rpc_message(ls, message, logfn)


def setup_stdout_output(args, logger: logging.Logger) -> Console:
"""Log to stdout."""
def setup_stdout_output(args, logger: logging.Logger, console: Console):
"""Log messages to stdout."""

console = Console(record=args.save_output is not None)
handler = RichLSPHandler(level=logging.INFO, console=console)
handler.addFilter(
LSPFilter(
Expand All @@ -95,10 +95,10 @@ def setup_stdout_output(args, logger: logging.Logger) -> Console:
)

logger.addHandler(handler)
return console


def setup_file_output(args, logger: logging.Logger):
def setup_file_output(args, logger: logging.Logger, console: Optional[Console] = None):
"""Log messages to a file."""
handler = logging.FileHandler(filename=str(args.to_file))
handler.setLevel(logging.INFO)
handler.addFilter(
Expand All @@ -112,10 +112,19 @@ def setup_file_output(args, logger: logging.Logger):
)
)

if console:
spinner = SpinnerHandler(console)
spinner.setLevel(logging.INFO)
logger.addHandler(spinner)

# This must come last!
logger.addHandler(handler)


def setup_sqlite_output(args, logger: logging.Logger):
def setup_sqlite_output(
args, logger: logging.Logger, console: Optional[Console] = None
):
"""Log messages to SQLite."""
handler = SqlHandler(args.to_sqlite)
handler.setLevel(logging.INFO)
handler.addFilter(
Expand All @@ -128,6 +137,12 @@ def setup_sqlite_output(args, logger: logging.Logger):
)
)

if console:
spinner = SpinnerHandler(console)
spinner.setLevel(logging.INFO)
logger.addHandler(spinner)

# This must come last!
logger.addHandler(handler)


Expand All @@ -137,36 +152,40 @@ def start_recording(args, extra: List[str]):
logger.setLevel(logging.INFO)
server.feature(MESSAGE_TEXT_NOTIFICATION)(log_func)

console: Optional[Console] = None
host = args.host
port = args.port
console = Console(record=args.save_output is not None)

if args.to_file:
setup_file_output(args, logger)
setup_file_output(args, logger, console)

elif args.to_sqlite:
setup_sqlite_output(args, logger)
setup_sqlite_output(args, logger, console)

else:
console = setup_stdout_output(args, logger)
setup_stdout_output(args, logger, console)

try:
host = args.host
port = args.port

print(f"Waiting for connection on {host}:{port}...", end="\r", flush=True)
asyncio.run(server.start_tcp(host, port))
except asyncio.CancelledError:
pass
except KeyboardInterrupt:
server.stop()

if console is not None and args.save_output is not None:
destination = args.save_output
exporter_name, kwargs = EXPORTERS.get(destination.suffix, (None, None))
if exporter_name is None:
console.print(f"Unable to save output to '{destination.suffix}' files")
return
if console is not None:
console.show_cursor(True)

if args.save_output is not None:
destination = args.save_output
exporter_name, kwargs = EXPORTERS.get(destination.suffix, (None, None))
if exporter_name is None:
console.print(f"Unable to save output to '{destination.suffix}' files")
return

exporter = getattr(console, exporter_name)
exporter(str(destination), **kwargs)
exporter = getattr(console, exporter_name)
exporter(str(destination), **kwargs)


def setup_filter_args(cmd: argparse.ArgumentParser):
Expand Down
169 changes: 169 additions & 0 deletions lib/lsp-devtools/lsp_devtools/record/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from __future__ import annotations

import logging
import typing

from rich import progress
from rich.measure import Measurement
from rich.segment import Segment
from rich.style import Style

if typing.TYPE_CHECKING:
from typing import List
from typing import Optional

from rich.console import Console
from rich.console import ConsoleOptions
from rich.console import RenderResult
from rich.table import Column


class PacketPipe:
"""Rich renderable that generates the visualisation of the in-flight packets between
client and server."""

def __init__(self, server_packets, client_packets):
self.server_packets = server_packets
self.client_packets = client_packets

def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
width = options.max_width
pipe_length = width - 2

client_packets = {int(p * pipe_length) for p in self.client_packets}
server_packets = {
pipe_length - int(p * pipe_length) for p in self.server_packets
}

yield Segment("[")

for idx in range(pipe_length):
if idx in server_packets:
yield Segment("●", style=Style(color="blue"))
elif idx in client_packets:
yield Segment("●", style=Style(color="red"))
else:
yield Segment(" ")

yield Segment("]")

def __rich_measure__(
self, console: Console, options: ConsoleOptions
) -> Measurement:
return Measurement(4, options.max_width)


class PacketPipeColumn(progress.ProgressColumn):
"""Visualizes messages sent between client and server as "packets"."""

def __init__(
self, duration: float = 1.0, table_column: Optional[Column] = None
) -> None:
self.client_count = 0
self.server_count = 0
self.server_times: List[float] = []
self.client_times: List[float] = []

# How long it should take for a packet to propogate.
self.duration = duration

super().__init__(table_column)

def _update_packets(self, task: progress.Task, source: str) -> List[float]:
"""Update the packet positions for the given message source.

Parameters
----------
task
The task object

source
The message source

Returns
-------
List[float]
A list of floats in the range [0,1] indicating the number of packets in flight
and their position
"""
count_attr = f"{source}_count"
time_attr = f"{source}_times"

count = getattr(self, count_attr)
times = getattr(self, time_attr)
current_count = task.fields[count_attr]
current_time = task.get_time()

if current_count > count:
setattr(self, count_attr, current_count)
times.append(current_time)

packets = []
new_times = []

for time in times:
if (delta := current_time - time) > self.duration:
continue

packets.append(delta / self.duration)
new_times.append(time)

setattr(self, time_attr, new_times)
return packets

def render(self, task: progress.Task) -> PacketPipe:
"""Render the packet pipe."""

client_packets = self._update_packets(task, "client")
server_packets = self._update_packets(task, "server")

return PacketPipe(server_packets=server_packets, client_packets=client_packets)


class SpinnerHandler(logging.Handler):
"""A logging handler that shows a customised progress bar, used to show feedback for
an active connection."""

def __init__(self, console: Console, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.server_count = 0
self.client_count = 0
self.progress = progress.Progress(
progress.TextColumn("[red]CLIENT[/red] {task.fields[client_method]}"),
PacketPipeColumn(),
progress.TextColumn("{task.fields[server_method]} [blue]SERVER[/blue]"),
console=console,
auto_refresh=True,
expand=True,
)
self.task = self.progress.add_task(
"",
total=None,
server_method="",
client_method="",
server_count=self.server_count,
client_count=self.client_count,
)

def emit(self, record: logging.LogRecord):
message = record.args

if not isinstance(message, dict):
return

self.progress.start()

method = message.get("method", None)
source = record.__dict__["source"]
args = {}

if method:
args[f"{source}_method"] = method
count = getattr(self, f"{source}_count") + 1

setattr(self, f"{source}_count", count)
args[f"{source}_count"] = count

self.progress.update(self.task, **args)
Loading