diff --git a/requirements.txt b/requirements.txt index 2dd3517a7a..37ee1e42cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,5 @@ lm_eval==0.4.4 langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 + +torchao==0.5.0 diff --git a/setup.py b/setup.py index e939bc37ee..7d9568dbff 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ def parse_requirements(): try: xformers_version = [req for req in _install_requires if "xformers" in req][0] + torchao_version = [req for req in _install_requires if "torchao" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS _install_requires.pop(_install_requires.index(xformers_version)) @@ -53,7 +54,8 @@ def parse_requirements(): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") - if (major, minor) >= (2, 3): + elif (major, minor) >= (2, 3): + _install_requires.pop(_install_requires.index(torchao_version)) if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") @@ -61,9 +63,11 @@ def parse_requirements(): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): + _install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") else: + _install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.23.post1")