From e2b99c18f88c1e88d881b1870cd300239b4f082e Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Tue, 12 Nov 2024 16:35:17 -0800 Subject: [PATCH] Allow custom repr for function args and retval logging (#505) --- docs/source/extra_log.rst | 19 +++++++++++++ docs/source/viztracer.rst | 20 ++++++++++++++ src/viztracer/main.py | 3 +++ src/viztracer/modules/snaptrace.c | 44 ++++++++++++++++++++++++------- src/viztracer/modules/snaptrace.h | 1 + src/viztracer/tracer.py | 17 +++++++++++- src/viztracer/viztracer.py | 9 +++++++ tests/test_cmdline.py | 3 +++ tests/test_invalid.py | 5 ++++ tests/test_tracer.py | 12 +++++++++ 10 files changed, 123 insertions(+), 10 deletions(-) diff --git a/docs/source/extra_log.rst b/docs/source/extra_log.rst index 5d8adc90..4f1ed754 100644 --- a/docs/source/extra_log.rst +++ b/docs/source/extra_log.rst @@ -140,6 +140,25 @@ You can enable this feature in command line or using inline. tracer = VizTracer(log_func_retval=True) +Log Function Argument And Return Value With Custom Function +----------------------------------------------------------- + +You can log every function's arguments and return value with a custom function. You can feed your own function to ``VizTracer`` + +.. code-block:: python + + def myrepr(obj): + if isinstance(obj, MyClass): + return f"MyClass({obj.value})" + return repr(obj) + + tracer = VizTracer(log_func_args=True, log_func_repr=myrepr) + +From the CLI, you can use the ``--log_func_with_objprint`` option to log with objprint + +.. code-block:: + + viztracer --log_func_args --log_func_with_objprint my_script.py Log Print --------- diff --git a/docs/source/viztracer.rst b/docs/source/viztracer.rst index 22ba0fb3..258b32f3 100644 --- a/docs/source/viztracer.rst +++ b/docs/source/viztracer.rst @@ -153,6 +153,26 @@ VizTracer .. code-block:: viztracer --log_func_args + + .. py:attribute:: log_func_repr + :type: Optional[Callable] + :value: None + + A custom repr function to log the function arguments and return value. 
The function should take + a single argument and return a string. + + .. py:attribute:: log_func_with_objprint + :type: boolean + :value: False + + Whether to log the arguments and return value of the function with ``objprint``. + This attribute can't be ``True`` if ``log_func_repr`` is given. + + Setting it to ``True`` is equivalent to + + .. code-block:: + + viztracer --log_func_with_objprint .. py:attribute:: log_print :type: boolean diff --git a/src/viztracer/main.py b/src/viztracer/main.py index 4d7d584a..9d0cc4b3 100644 --- a/src/viztracer/main.py +++ b/src/viztracer/main.py @@ -92,6 +92,8 @@ def create_parser(self) -> argparse.ArgumentParser: help="log functions in exit functions like atexit") parser.add_argument("--log_func_retval", action="store_true", default=False, help="log return value of the function in the report") + parser.add_argument("--log_func_with_objprint", action="store_true", default=False, + help="use objprint for function argument and return value") parser.add_argument("--log_print", action="store_true", default=False, help="replace all print() function to adding an event to the result") parser.add_argument("--log_sparse", action="store_true", default=False, @@ -272,6 +274,7 @@ def parse(self, argv: List[str]) -> VizProcedureResult: "ignore_frozen": options.ignore_frozen, "log_func_retval": options.log_func_retval, "log_func_args": options.log_func_args, + "log_func_with_objprint": options.log_func_with_objprint, "log_print": options.log_print, "log_gc": options.log_gc, "log_sparse": options.log_sparse, diff --git a/src/viztracer/modules/snaptrace.c b/src/viztracer/modules/snaptrace.c index 11195669..45478770 100644 --- a/src/viztracer/modules/snaptrace.c +++ b/src/viztracer/modules/snaptrace.c @@ -47,7 +47,7 @@ static PyObject* snaptrace_setignorestackcounter(TracerObject* self, PyObject* a static void snaptrace_flush_unfinished(TracerObject* self, int flush_as_finish); static void snaptrace_threaddestructor(void* key); static struct 
ThreadInfo* snaptrace_createthreadinfo(TracerObject* self); -static void log_func_args(struct FunctionNode* node, PyFrameObject* frame); +static void log_func_args(struct FunctionNode* node, PyFrameObject* frame, PyObject* log_func_repr); static double get_ts(struct ThreadInfo*); TracerObject* curr_tracer = NULL; @@ -133,7 +133,7 @@ static inline struct EventNode* get_next_node(TracerObject* self) return node; } -static void log_func_args(struct FunctionNode* node, PyFrameObject* frame) +static void log_func_args(struct FunctionNode* node, PyFrameObject* frame, PyObject* log_func_repr) { PyObject* func_arg_dict = PyDict_New(); PyCodeObject* code = PyFrame_GetCode(frame); @@ -162,8 +162,13 @@ static void log_func_args(struct FunctionNode* node, PyFrameObject* frame) while (idx < name_length) { // Borrowed PyObject* name = PyTuple_GET_ITEM(names, idx); + PyObject* repr = NULL; // New - PyObject* repr = PyObject_Repr(PyDict_GetItem(locals, name)); + if (log_func_repr) { + repr = PyObject_CallOneArg(log_func_repr, PyDict_GetItem(locals, name)); + } else { + repr = PyObject_Repr(PyDict_GetItem(locals, name)); + } if (!repr) { repr = PyUnicode_FromString("Not Displayable"); PyErr_Clear(); @@ -335,7 +340,7 @@ snaptrace_pycall_callback(TracerObject* self, PyFrameObject* frame, struct Threa info->stack_top->func = (PyObject*) code; Py_INCREF(code); if (CHECK_FLAG(self->check_flags, SNAPTRACE_LOG_FUNCTION_ARGS)) { - log_func_args(info->stack_top, frame); + log_func_args(info->stack_top, frame, self->log_func_repr); } cleanup: @@ -358,7 +363,7 @@ snaptrace_ccall_callback(TracerObject* self, PyFrameObject* frame, struct Thread info->stack_top->func = arg; Py_INCREF(arg); if (CHECK_FLAG(self->check_flags, SNAPTRACE_LOG_FUNCTION_ARGS)) { - log_func_args(info->stack_top, frame); + log_func_args(info->stack_top, frame, self->log_func_repr); } return 0; @@ -397,7 +402,17 @@ snaptrace_pyreturn_callback(TracerObject* self, PyFrameObject* frame, struct Thr Py_INCREF(stack_top->args); 
} if (CHECK_FLAG(self->check_flags, SNAPTRACE_LOG_RETURN_VALUE)) { - node->data.fee.retval = PyObject_Repr(arg); + PyObject* repr = NULL; + if (self->log_func_repr) { + repr = PyObject_CallOneArg(self->log_func_repr, arg); + } else { + repr = PyObject_Repr(arg); + } + if (!repr) { + repr = PyUnicode_FromString("Not Displayable"); + PyErr_Clear(); + } + node->data.fee.retval = repr; } if (CHECK_FLAG(self->check_flags, SNAPTRACE_LOG_ASYNC)) { @@ -1308,7 +1323,7 @@ snaptrace_config(TracerObject* self, PyObject* args, PyObject* kw) static char* kwlist[] = {"verbose", "lib_file_path", "max_stack_depth", "include_files", "exclude_files", "ignore_c_function", "ignore_frozen", "log_func_retval", "log_func_args", "log_async", "trace_self", - "min_duration", "process_name", + "min_duration", "process_name", "log_func_repr", NULL}; int kw_verbose = -1; int kw_max_stack_depth = 0; @@ -1316,6 +1331,7 @@ snaptrace_config(TracerObject* self, PyObject* args, PyObject* kw) PyObject* kw_process_name = NULL; PyObject* kw_include_files = NULL; PyObject* kw_exclude_files = NULL; + PyObject* kw_log_func_repr = NULL; int kw_ignore_c_function = -1; int kw_ignore_frozen = -1; int kw_log_func_retval = -1; @@ -1323,7 +1339,7 @@ snaptrace_config(TracerObject* self, PyObject* args, PyObject* kw) int kw_log_async = -1; int kw_trace_self = -1; double kw_min_duration = 0; - if (!PyArg_ParseTupleAndKeywords(args, kw, "|isiOOppppppdO", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kw, "|isiOOppppppdOO", kwlist, &kw_verbose, &kw_lib_file_path, &kw_max_stack_depth, @@ -1336,7 +1352,8 @@ snaptrace_config(TracerObject* self, PyObject* args, PyObject* kw) &kw_log_async, &kw_trace_self, &kw_min_duration, - &kw_process_name)) { + &kw_process_name, + &kw_log_func_repr)) { return NULL; } @@ -1442,6 +1459,15 @@ snaptrace_config(TracerObject* self, PyObject* args, PyObject* kw) UNSET_FLAG(self->check_flags, SNAPTRACE_EXCLUDE_FILES); } + if (kw_log_func_repr && kw_log_func_repr != Py_None) { + 
Py_XDECREF(self->log_func_repr); + self->log_func_repr = kw_log_func_repr; + Py_INCREF(self->log_func_repr); + } else { + Py_XDECREF(self->log_func_repr); + self->log_func_repr = NULL; + } + Py_RETURN_NONE; } diff --git a/src/viztracer/modules/snaptrace.h b/src/viztracer/modules/snaptrace.h index 6f49d3c7..9e98209b 100644 --- a/src/viztracer/modules/snaptrace.h +++ b/src/viztracer/modules/snaptrace.h @@ -76,6 +76,7 @@ typedef struct { PyObject* process_name; PyObject* include_files; PyObject* exclude_files; + PyObject* log_func_repr; double min_duration; struct EventNode* buffer; long buffer_size; diff --git a/src/viztracer/tracer.py b/src/viztracer/tracer.py index 4cd07653..5e868b75 100644 --- a/src/viztracer/tracer.py +++ b/src/viztracer/tracer.py @@ -6,7 +6,7 @@ import os import sys from io import StringIO -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Union import viztracer.snaptrace as snaptrace # type: ignore @@ -24,6 +24,7 @@ def __init__( ignore_frozen: bool = False, log_func_retval: bool = False, log_func_args: bool = False, + log_func_repr: Optional[Callable] = None, log_print: bool = False, log_gc: bool = False, log_async: bool = False, @@ -45,6 +46,7 @@ def __init__( self.ignore_frozen = ignore_frozen self.log_func_retval = log_func_retval self.log_func_args = log_func_args + self.log_func_repr = log_func_repr self.log_async = log_async self.min_duration = min_duration self.log_print = log_print @@ -143,6 +145,18 @@ def log_func_retval(self, log_func_retval: bool) -> None: raise ValueError(f"log_func_retval needs to be True or False, not {log_func_retval}") self.config() + @property + def log_func_repr(self) -> Optional[Callable]: + return self.__log_func_repr + + @log_func_repr.setter + def log_func_repr(self, log_func_repr: Optional[Callable]) -> None: + if log_func_repr is None or callable(log_func_repr): + self.__log_func_repr = log_func_repr + else: + raise 
ValueError("log_func_repr needs to be a callable") + self.config() + @property def log_async(self) -> bool: return self.__log_async @@ -238,6 +252,7 @@ def config(self) -> None: "ignore_frozen": self.ignore_frozen, "log_func_retval": self.log_func_retval, "log_func_args": self.log_func_args, + "log_func_repr": self.log_func_repr, "log_async": self.log_async, "trace_self": self.trace_self, "min_duration": self.min_duration, diff --git a/src/viztracer/viztracer.py b/src/viztracer/viztracer.py index 2dcd0433..11b56bef 100644 --- a/src/viztracer/viztracer.py +++ b/src/viztracer/viztracer.py @@ -30,6 +30,8 @@ def __init__(self, ignore_frozen: bool = False, log_func_retval: bool = False, log_func_args: bool = False, + log_func_repr: Optional[Callable] = None, + log_func_with_objprint: bool = False, log_print: bool = False, log_gc: bool = False, log_sparse: bool = False, @@ -47,6 +49,12 @@ def __init__(self, process_name: Optional[str] = None, output_file: str = "result.json", plugins: Sequence[Union[VizPluginBase, str]] = []) -> None: + + if log_func_with_objprint: + if log_func_repr: + raise ValueError("log_func_repr and log_func_with_objprint can't be both set") + log_func_repr = objprint.objstr + super().__init__( tracer_entries=tracer_entries, max_stack_depth=max_stack_depth, @@ -58,6 +66,7 @@ def __init__(self, log_print=log_print, log_gc=log_gc, log_func_args=log_func_args, + log_func_repr=log_func_repr, log_async=log_async, trace_self=trace_self, min_duration=min_duration, diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 6e4bf8e3..177241a4 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -319,6 +319,9 @@ def test_log_func_retval(self): def test_log_func_args(self): self.template(["python", "-m", "viztracer", "--log_func_args", "cmdline_test.py"]) + def test_log_func_with_objprint(self): + self.template(["python", "-m", "viztracer", "--log_func_args", "--log_func_with_objprint", "cmdline_test.py"]) + def test_minimize_memory(self): 
self.template(["python", "-m", "viztracer", "--minimize_memory", "cmdline_test.py"]) diff --git a/tests/test_invalid.py b/tests/test_invalid.py index 86ef881b..943d762d 100644 --- a/tests/test_invalid.py +++ b/tests/test_invalid.py @@ -23,6 +23,7 @@ def test_invalid_args(self): "min_duration": ["0.1.0", "12", "3us"], "ignore_frozen": ["hello", 1, "True"], "log_async": ["hello", 1, "True"], + "log_func_repr": ["hello", 1, True], } tracer = VizTracer(verbose=0) for args, vals in invalid_args.items(): @@ -44,6 +45,10 @@ def test_save_invalid_format(self): with self.assertRaises(Exception): tracer.save("test.invalid") + def test_log_func_conflict(self): + with self.assertRaises(ValueError): + _ = VizTracer(log_func_repr=repr, log_func_with_objprint=True, verbose=0) + def test_add_invalid_variable(self): tracer = VizTracer(verbose=0) tracer.start() diff --git a/tests/test_tracer.py b/tests/test_tracer.py index 97e03dce..55676c5a 100644 --- a/tests/test_tracer.py +++ b/tests/test_tracer.py @@ -205,6 +205,18 @@ def test_log_func_args(self): events = [e for e in tracer.data["traceEvents"] if e["ph"] != "M"] self.assertTrue("args" in events[0] and "func_args" in events[0]["args"]) + def test_log_func_repr(self): + def myrepr(obj): + return "deadbeef" + tracer = _VizTracer(log_func_args=True, log_func_repr=myrepr) + tracer.start() + fib(5) + tracer.stop() + tracer.parse() + events = [e for e in tracer.data["traceEvents"] if e["ph"] != "M"] + self.assertTrue("args" in events[0] and "func_args" in events[0]["args"] + and events[0]["args"]["func_args"]["n"] == "deadbeef") + def test_log_gc(self): import gc tracer = _VizTracer(log_gc=True)