Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feat/snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark committed Dec 11, 2024
2 parents c5e6323 + 2408639 commit 0129db7
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 20 deletions.
52 changes: 33 additions & 19 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,17 +642,21 @@ def run(
stuck = []

thread_pool = ThreadPoolExecutor(max_workers=args.solver_threads)
result_exs = []
future_models = []
counterexamples = []
unsat_cores = []
traces = {}
traces: dict[int, str] = {}
exec_cache: dict[int, Exec] = {}

def future_callback(future_model):
m = future_model.result()
models.append(m)

model, index, result = m.model, m.index, m.result

# retrieve cached exec and clear the cache entry
exec = exec_cache.pop(index, None)

if result == unsat:
if m.unsat_core:
unsat_cores.append(m.unsat_core)
Expand All @@ -672,20 +676,25 @@ def future_callback(future_model):
else:
warn_code(COUNTEREXAMPLE_UNKNOWN, f"Counterexample: {result}")

if args.print_failed_states:
print(f"# {idx+1}")
print(result_exs[index])

if args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE:
print(
f"Trace #{idx+1}:"
f"Trace #{index + 1}:"
if args.verbose == VERBOSITY_TRACE_PATHS
else "Trace:"
)
print(traces[index], end="")

if args.print_failed_states:
print(f"# {index + 1}")
print(exec)

# initialize with default value in case we don't enter the loop body
idx = -1

for idx, ex in enumerate(exs):
result_exs.append(ex)
# cache exec in case we need to print it later
if args.print_failed_states:
exec_cache[idx] = ex

if args.verbose >= VERBOSITY_TRACE_PATHS:
print(f"Path #{idx+1}:")
Expand Down Expand Up @@ -725,15 +734,21 @@ def future_callback(future_model):
print(ex)
normal += 1

# print post-states
if args.print_states:
print(f"# {idx+1}")
print(ex)

# 0 width is unlimited
if args.width and len(result_exs) >= args.width:
if args.width and idx >= args.width:
break

num_execs = idx + 1
timer.create_subtimer("models")

if len(future_models) > 0 and args.verbose >= 1:
if future_models and args.verbose >= 1:
print(
f"# of potential paths involving assertion violations: {len(future_models)} / {len(result_exs)} (--solver-threads {args.solver_threads})"
f"# of potential paths involving assertion violations: {len(future_models)} / {num_execs} (--solver-threads {args.solver_threads})"
)

# display assertion solving progress
Expand Down Expand Up @@ -781,7 +796,7 @@ def future_callback(future_model):

# print result
print(
f"{passfail} {funsig} (paths: {len(result_exs)}, {time_info}, bounds: [{', '.join([str(x) for x in dyn_params])}])"
f"{passfail} {funsig} (paths: {num_execs}, {time_info}, bounds: [{', '.join([str(x) for x in dyn_params])}])"
)

for idx, _, err in stuck:
Expand All @@ -797,12 +812,6 @@ def future_callback(future_model):
)
debug("\n".join(jumpid_str(x) for x in logs.bounded_loops))

# print post-states
if args.print_states:
for idx, ex in enumerate(result_exs):
print(f"# {idx+1} / {len(result_exs)}")
print(ex)

# log steps
if args.log:
with open(args.log, "w") as json_file:
Expand All @@ -817,7 +826,7 @@ def future_callback(future_model):
exitcode,
len(counterexamples),
counterexamples,
(len(result_exs), normal, len(stuck)),
(num_execs, normal, len(stuck)),
(timer.elapsed(), timer["paths"].elapsed(), timer["models"].elapsed()),
len(logs.bounded_loops),
)
Expand Down Expand Up @@ -1352,6 +1361,11 @@ def _main(_args=None) -> MainResult:
logger.setLevel(logging.DEBUG)
logger_unique.setLevel(logging.DEBUG)

if args.trace_memory:
import halmos.memtrace as memtrace

memtrace.MemTracer.get().start()

#
# compile
#
Expand Down
8 changes: 7 additions & 1 deletion src/halmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ class Config:
group=debugging,
)

trace_memory: bool = arg(
help="trace memory allocations and deallocations",
global_default=False,
group=debugging,
)

### Build options

forge_build_out: str = arg(
Expand Down Expand Up @@ -787,7 +793,7 @@ def _to_toml_str(value: Any, type) -> str:
continue

name = field_info.name.replace("_", "-")
if name in ["config", "root", "version"]:
if name in ["config", "root", "version", "trace_memory"]:
# skip fields that don't make sense in a config file
continue

Expand Down
181 changes: 181 additions & 0 deletions src/halmos/memtrace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import io
import linecache
import threading
import time
import tracemalloc

from rich.console import Console

from halmos.logs import debug

console = Console()


def readable_size(num: int | float) -> str:
if num < 1024:
return f"{num}B"

if num < 1024 * 1024:
return f"{num/1024:.1f}KiB"

return f"{num/(1024*1024):.1f}MiB"


def pretty_size(num: int | float) -> str:
return f"[magenta]{readable_size(num)}[/magenta]"


def pretty_count_diff(num: int | float) -> str:
if num > 0:
return f"[red]+{num}[/red]"
elif num < 0:
return f"[green]{num}[/green]"
else:
return "[gray]0[/gray]"


def pretty_line(line: str):
return f"[white] {line}[/white]" if line else ""


def pretty_frame_info(
frame: tracemalloc.Frame, result_number: int | None = None
) -> str:
result_number_str = (
f"[grey37]# {result_number+1}:[/grey37] " if result_number is not None else ""
)
filename_str = f"[grey37]{frame.filename}:[/grey37]"
lineno_str = f"[grey37]{frame.lineno}:[/grey37]"
return f"{result_number_str}{filename_str}{lineno_str}"


class MemTracer:
curr_snapshot: tracemalloc.Snapshot | None = None
prev_snapshot: tracemalloc.Snapshot | None = None
running: bool = False

_instance = None
_lock = threading.Lock()

def __init__(self):
if MemTracer._instance is not None:
raise RuntimeError("Use MemTracer.get() to access the singleton instance.")
self.curr_snapshot = None
self.prev_snapshot = None
self.running = False

@classmethod
def get(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance

def take_snapshot(self):
debug("memtracer: taking snapshot")
self.prev_snapshot = self.curr_snapshot
self.curr_snapshot = tracemalloc.take_snapshot()
self.display_stats()

def display_stats(self):
"""Display statistics about the current memory snapshot."""
if not self.running:
return

if self.curr_snapshot is None:
debug("memtracer: no current snapshot")
return

out = io.StringIO()

# Show top memory consumers by line
out.write("[cyan][ Top memory consumers ][/cyan]\n")
stats = self.curr_snapshot.statistics("lineno")
for i, stat in enumerate(stats[:10]):
frame = stat.traceback[0]
line = linecache.getline(frame.filename, frame.lineno).strip()
out.write(f"{pretty_frame_info(frame, i)} " f"{pretty_size(stat.size)}\n")
out.write(f"{pretty_line(line)}\n")
out.write("\n")

# Get total memory usage
total = sum(stat.size for stat in self.curr_snapshot.statistics("filename"))
out.write(f"Total memory used in snapshot: {pretty_size(total)}\n\n")

console.print(out.getvalue())

def start(self, interval_seconds=60):
"""Start tracking memory usage at the specified interval."""
if not tracemalloc.is_tracing():
nframes = 1
tracemalloc.start(nframes)
self.running = True

self.take_snapshot()
threading.Thread(
target=self._run, args=(interval_seconds,), daemon=True
).start()

def stop(self):
"""Stop the memory tracer."""
self.running = False

def _run(self, interval_seconds):
"""Run the tracer periodically."""
while self.running:
time.sleep(interval_seconds)
self.take_snapshot()
self._display_differences()

def _display_differences(self):
"""Display top memory differences between snapshots."""

if not self.running:
return

if self.prev_snapshot is None or self.curr_snapshot is None:
debug("memtracer: no snapshots to compare")
return

out = io.StringIO()

top_stats = self.curr_snapshot.compare_to(
self.prev_snapshot, "lineno", cumulative=True
)
out.write("[cyan][ Top differences ][/cyan]\n")
for i, stat in enumerate(top_stats[:10]):
frame = stat.traceback[0]
line = linecache.getline(frame.filename, frame.lineno).strip()
out.write(
f"{pretty_frame_info(frame, i)} "
f"{pretty_size(stat.size_diff)} "
f"[{pretty_count_diff(stat.count_diff)}]\n"
)
out.write(f"{pretty_line(line)}\n")

total_diff = sum(stat.size_diff for stat in top_stats)
out.write(f"Total size difference: {pretty_size(total_diff)}\n")

console.print(out.getvalue())


def main():
tracer = MemTracer.get()
tracer.start(interval_seconds=2)

# Simulate some workload
import random

memory_hog = []
try:
while True:
memory_hog.append([random.random() for _ in range(1000)])
time.sleep(0.1)
except KeyboardInterrupt:
# Stop the tracer on exit
tracer.stop()


if __name__ == "__main__":
main()

0 comments on commit 0129db7

Please sign in to comment.