diff --git a/setup.py b/setup.py
index b0c9ecbc32..6f816ce4a6 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@ def parse_requirements():
     try:
         torch_version = version("torch")
+        _install_requires.append(f"torch=={torch_version}")
         if torch_version.startswith("2.1."):
             _install_requires.pop(_install_requires.index("xformers==0.0.22"))
             _install_requires.append("xformers>=0.0.23")
@@ -50,7 +51,7 @@ def parse_requirements():
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.3.3",
+            "flash-attn==2.5.0",
         ],
         "fused-dense-lib": [
             "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",