diff --git a/requirements.txt b/requirements.txt index ec571570bb..981a625580 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ transformers==4.43.1 tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 -deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b +deepspeed==0.14.4 pydantic==2.6.3 addict fire diff --git a/setup.py b/setup.py index ceba636690..1d164e0a18 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ def parse_requirements(): "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib", ], "deepspeed": [ - "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", + "deepspeed==0.14.4", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/tests/e2e/test_imports.py b/tests/e2e/test_imports.py new file mode 100644 index 0000000000..f186eaac46 --- /dev/null +++ b/tests/e2e/test_imports.py @@ -0,0 +1,20 @@ +""" +test module to import various submodules that have historically broken due to dependency issues +""" +import unittest + + +class TestImports(unittest.TestCase): + """ + Test class to import various submodules that have historically broken due to dependency issues + """ + + def test_import_causal_trainer(self): + from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + HFCausalTrainerBuilder, + ) + + def test_import_rl_trainer(self): + from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + HFRLTrainerBuilder, + )