diff --git a/runtime/tools/python/ttrt/__init__.py b/runtime/tools/python/ttrt/__init__.py index 1166f464a..c53c4424b 100644 --- a/runtime/tools/python/ttrt/__init__.py +++ b/runtime/tools/python/ttrt/__init__.py @@ -2,19 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 -import os -import json -import importlib.machinery -import sys -import signal -import os -import io -import subprocess -import time -import socket -from pkg_resources import get_distribution -import sys -import shutil +# NOTE: it is _VERY_ important that this import & setup call is _BEFORE_ any +# other `ttrt` imports and _AFTER_ all system imports to ensure a well ordered +# setup of the pybound `.so`. Otherwise, undefined behaviour ensues related to +# the timing of when `TTMETAL_HOME` environment variable is set. DO NOT MOVE +# w.r.t. other imports. This is a temporary workaround until `TT_METAL_HOME` is +# not used anymore in TTMetal +import ttrt.library_tweaks + +ttrt.library_tweaks.set_tt_metal_home() import ttrt.binary from ttrt.common.api import API diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 4f7003969..3f1b1adf8 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -11,6 +11,8 @@ import torch +from ttrt.runtime._C import DataType + # environment tweaks if "LOGGER_LEVEL" not in os.environ: @@ -19,8 +21,7 @@ os.environ["TT_METAL_LOGGER_LEVEL"] = "FATAL" -def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype: - from ttrt.runtime._C import DataType +def ttrt_datatype_to_torch_dtype(dtype: DataType) -> torch.dtype: """Converts a PyBound `::tt::target::DataType` into a `torch.dtype`. @@ -67,17 +68,6 @@ def get_ttrt_metal_home_path(): return tt_metal_home -os.environ["TT_METAL_HOME"] = get_ttrt_metal_home_path() - -new_linker_path = f"{get_ttrt_metal_home_path()}/tests" -current_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") -if current_ld_library_path: - updated_ld_library_path = f"{new_linker_path}:{current_ld_library_path}" -else: - updated_ld_library_path = new_linker_path -os.environ["LD_LIBRARY_PATH"] = updated_ld_library_path - - class Logger: def __init__(self, file_name=""): import logging diff --git a/runtime/tools/python/ttrt/library_tweaks.py b/runtime/tools/python/ttrt/library_tweaks.py new file mode 100644 index 000000000..db743a4e1 --- /dev/null +++ b/runtime/tools/python/ttrt/library_tweaks.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +""" +Simple library tweaks module used to move `TT_METAL_HOME` to point to the +mirrored TTMetal tree within the `ttrt` wheel. It is important that +`set_tt_metal_home()` is the _FIRST_ bit of code run in this `ttrt` module. +Thus, this file should only be included in `ttrt/__init__.py` and only run +there. This is a temporary fix, and will need to be cleaned up once TTMetal +drops `TT_METAL_HOME` functionality +""" +import importlib.util +import os + + +def get_ttrt_metal_home_path() -> str: + """Finds the root of the mirrored TTMetal tree within the `ttrt` wheel""" + package_name = "ttrt" + spec = importlib.util.find_spec(package_name) + package_path = os.path.dirname(spec.origin) + tt_metal_home = f"{package_path}/runtime" + return tt_metal_home + + +def set_tt_metal_home(): + """Sets the environment variable `TT_METAL_HOME` to point into the root + mirrored TTMetal tree within the `ttrt` wheel. + """ + os.environ["TT_METAL_HOME"] = get_ttrt_metal_home_path() + + new_linker_path = f"{get_ttrt_metal_home_path()}/tests" + current_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if current_ld_library_path: + updated_ld_library_path = f"{new_linker_path}:{current_ld_library_path}" + else: + updated_ld_library_path = new_linker_path + os.environ["LD_LIBRARY_PATH"] = updated_ld_library_path