diff --git a/setup.py b/setup.py index 3109d4af1a..1ceda5c044 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,14 @@ def parse_requirements(): ): # Handle standard packages _install_requires.append(line) + + # TODO(wing) remove once xformers release supports torch 2.1.0 + if "torch==2.1.0" in _install_requires: + _install_requires.pop(_install_requires.index("xformers>=0.0.22")) + _install_requires.append( + "git+https://github.com/facebookresearch/xformers.git@main#egg=xformers" + ) + return _install_requires, _dependency_links