From b051681418e7b80d0fdbd04a9f6a04000c119181 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 26 Sep 2023 12:27:34 -0400
Subject: [PATCH] update per PR feedback to handle deprecated sharegpt types

---
 requirements.txt            |  2 +-
 src/axolotl/utils/config.py | 19 +++++++++++++++++++
 tests/test_validation.py    | 23 +++++++++++++++++++++++
 3 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 7dddca6651..a38c87f433 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,4 +32,4 @@ scikit-learn==1.2.2
 pynvml
 art
 wandb
-fschat
+fschat==0.2.29
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 1c0487ff8e..dddb0930cf 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -262,6 +262,25 @@ def validate_config(cfg):
                     "`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
                 )
 
+    if cfg.datasets:
+        for idx, ds_cfg in enumerate(cfg.datasets):
+            if ds_cfg.type == "sharegpt:chat":
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
+                    )
+                )
+                cfg.datasets[idx].type = "sharegpt"
+            if "sharegpt_simple" in ds_cfg.type:
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
+                    )
+                )
+                cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
+                    "sharegpt_simple", "sharegpt"
+                )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/tests/test_validation.py b/tests/test_validation.py
index f250e5cb47..536e7e2fd0 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -351,3 +351,26 @@ def test_packing(self):
         regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_sharegpt_deprecation(self):
+        cfg = DictDefault(
+            {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "`type: sharegpt:chat` will soon be deprecated." in record.message
+                for record in self._caplog.records
+            )
+        assert cfg.datasets[0].type == "sharegpt"
+
+        cfg = DictDefault(
+            {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "`type: sharegpt_simple` will soon be deprecated." in record.message
+                for record in self._caplog.records
+            )
+        assert cfg.datasets[0].type == "sharegpt:load_role"