diff --git a/metricflow/dag/dag_to_text.py b/metricflow/dag/dag_to_text.py index bcf76d8c3b..931e104593 100644 --- a/metricflow/dag/dag_to_text.py +++ b/metricflow/dag/dag_to_text.py @@ -68,7 +68,6 @@ def __init__( # In case this gets used in a multi-threaded context, use a thread-local variable since it has mutable state. self._thread_local_data = threading.local() - self._thread_local_data.max_width_tracker = MaxWidthTracker(max_width) @property def _max_width_tracker(self) -> MaxWidthTracker: # noqa: D @@ -200,7 +199,10 @@ def dag_to_text(self, dag: MetricFlowDag[DagNodeT]) -> str: inner_contents="\n".join(component_from_sink_nodes_as_text), ) except Exception: - logger.exception(f"Got an exception while converting {dag} to text") + logger.exception( + f"Got an exception while converting {dag} to text. This exception will be swallowed, and the built-in " + f"string representation will be returned instead." + ) return str(dag) def dag_component_to_text(self, dag_component_leaf_node: DagNode) -> str: diff --git a/metricflow/test/mf_logging/__init__.py b/metricflow/test/mf_logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metricflow/test/mf_logging/test_dag_to_text.py b/metricflow/test/mf_logging/test_dag_to_text.py new file mode 100644 index 0000000000..f006d362ee --- /dev/null +++ b/metricflow/test/mf_logging/test_dag_to_text.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import logging +import textwrap +import threading +import time +from typing import List + +from metricflow.dag.dag_to_text import MetricFlowDagTextFormatter +from metricflow.dataflow.sql_table import SqlTable +from metricflow.mf_logging.formatting import indent +from metricflow.sql.sql_exprs import ( + SqlStringExpression, +) +from metricflow.sql.sql_plan import SqlQueryPlan, SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode + +logger = logging.getLogger(__name__) + + +def test_multithread_dag_to_text() -> None: + """Test that dag_to_text() works correctly in a multithreading context.""" + num_threads = 4 + thread_outputs: List[str] = [] + + # Using a nested structure w/ small max_line_length to force recursion / cover recursive width tracking. + dag_to_text_formatter = MetricFlowDagTextFormatter(max_width=1) + dag = SqlQueryPlan( + plan_id="plan", + render_node=SqlSelectStatementNode( + description="test", + select_columns=( + SqlSelectColumn( + expr=SqlStringExpression("'foo'"), + column_alias="bar", + ), + ), + from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="schema", table_name="table")), + from_source_alias="src", + joins_descs=(), + group_bys=(), + order_bys=(), + ), + ) + + def _run_mf_pformat() -> None: # noqa: D + current_thread = threading.current_thread() + logger.debug(f"In {current_thread} - Starting .dag_to_text()") + # Sleep a little bit so that all threads are likely to be running simultaneously. + time.sleep(0.5) + try: + output = dag_to_text_formatter.dag_to_text(dag) + logger.debug(f"in {current_thread} - Output is:\n{indent(output)}") + thread_outputs.append(output) + logger.debug(f"In {current_thread} - Successfully finished .dag_to_text()") + except Exception: + logger.exception(f"In {current_thread} - Exiting due to an exception") + + threads = tuple(threading.Thread(target=_run_mf_pformat) for _ in range(num_threads)) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + expected_thread_output = textwrap.dedent( + """\ + + + + + + + + + + + + + + + + + + + + + + + + + """ + ).rstrip() + assert thread_outputs == [expected_thread_output for _ in range(num_threads)]