Skip to content

Commit

Permalink
Use build constraints to limit the package version (#132)
Browse files Browse the repository at this point in the history
Summary:
Promote numpy version file to `requirements_numpy.txt` and used by install.py

Optionally user can install it with `python install.py --numpy`.

Use build constraints to make sure numpy and torch version are consistent throughout the install.

Pull Request resolved: #132

Reviewed By: adamomainz

Differential Revision: D68414892

Pulled By: xuzhao9

fbshipit-source-id: 0bd1ddd41a987e78d4d4ed5a085dd0f2bdf39ab2
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jan 24, 2025
1 parent e1e0db0 commit f23ad77
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 13 deletions.
19 changes: 18 additions & 1 deletion install.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

from tools.cuda_utils import CUDA_VERSION_MAP, DEFAULT_CUDA_VERSION
from tools.git_utils import checkout_submodules
from tools.python_utils import pip_install_requirements
from tools.python_utils import (
generate_build_constraints,
get_pkg_versions,
has_pkg,
pip_install_requirements,
)

from tritonbench.utils.env_utils import is_hip

Expand All @@ -18,6 +23,10 @@
REPO_PATH = Path(os.path.abspath(__file__)).parent
FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")

# Packages we assume to have installed before running this script
# We will use build constraints to assume the version is not changed across the install
TRITONBENCH_DEPS = ["torch", "numpy"]


def install_jax(cuda_version=DEFAULT_CUDA_VERSION):
jax_package_name = CUDA_VERSION_MAP[cuda_version]["jax"]
Expand Down Expand Up @@ -82,6 +91,7 @@ def setup_hip(args: argparse.Namespace):

if __name__ == "__main__":
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--numpy", action="store_true", help="Install suggested numpy")
parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
parser.add_argument(
"--colfax", action="store_true", help="Install optional Colfax CUTLASS kernels"
Expand Down Expand Up @@ -110,6 +120,13 @@ def setup_hip(args: argparse.Namespace):
if args.all and is_hip():
setup_hip(args)

if args.numpy or not has_pkg("numpy"):
pip_install_requirements("requirements_numpy.txt", add_build_constraints=False)

# generate build constraints before installing anything
deps = get_pkg_versions(TRITONBENCH_DEPS)
generate_build_constraints(deps)

# install framework dependencies
pip_install_requirements("requirements.txt")
# checkout submodules
Expand Down
2 changes: 0 additions & 2 deletions tools/build_requirements.txt → requirements_numpy.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,3 @@
# which still supports python 3.8
numpy==1.21.2; python_version < '3.11'
numpy==1.26.0; python_version >= '3.11'
psutil
pyyaml
63 changes: 53 additions & 10 deletions tools/python_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import subprocess
import sys
from pathlib import Path

from typing import List, Optional
from typing import Dict, List, Optional

DEFAULT_PYTHON_VERSION = "3.11"

Expand All @@ -19,30 +21,71 @@
REPO_DIR = Path(__file__).parent.parent


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_conda_env(pyver: str, name: str):
command = ["conda", "create", "-n", name, "-y", f"python={pyver}"]
subprocess.check_call(command)


def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
versions = {}
for module in packages:
cmd = [sys.executable, "-c", f"import {module}; print({module}.__version__)"]
version = subprocess.check_output(cmd).decode().strip()
versions[module] = version
return versions


def has_pkg(pkg: str):
"""
Check if a package is installed
"""
try:
cmd = [sys.executable, "-c", f"import {pkg}; {pkg}.__version__"]
subprocess.check_call(cmd)
return True
except subprocess.CalledProcessError:
return False


def generate_build_constraints(package_versions: Dict[str, str]):
"""
Generate package versions dict and save them to REPO_DIR/build/constraints.txt
"""
output_dir = REPO_DIR.joinpath("build")
output_dir.mkdir(exist_ok=True)
with open(output_dir.joinpath("constraints.txt"), "w") as fp:
for k, v in package_versions.items():
fp.write(f"{k}=={v}\n")


def pip_install_requirements(
requirements_txt="requirements.txt",
continue_on_fail=False,
no_build_isolation=False,
add_build_constraints=True,
extra_args: Optional[List[str]] = None,
):
import sys

constraints_file = REPO_DIR.joinpath("build", "constraints.txt")
if not constraints_file.exists():
print(
"The build/constrants.txt file is not found. "
"Please consider rerunning the install.py script to generate it."
"It is recommended to install with the build/constrants.txt file "
"to prevent unexpected version change of numpy or torch."
)
constraints_parameters = []
if add_build_constraints:
if not constraints_file.exists():
logger.warn(
"The build/constrants.txt file is not found. "
"Please consider rerunning the install.py script to generate it."
"It is recommended to install with the build/constrants.txt file "
"to prevent unexpected version change of numpy or torch."
)
constraints_parameters = []
else:
constraints_parameters = ["-c", str(constraints_file.resolve())]
else:
constraints_parameters = ["-c", str(constraints_file.resolve())]
constraints_parameters = []

if no_build_isolation:
constraints_parameters.append("--no-build-isolation")
if extra_args and isinstance(extra_args, list):
Expand Down

0 comments on commit f23ad77

Please sign in to comment.