diff --git a/data_juicer/config/__init__.py b/data_juicer/config/__init__.py index 4060b6ac7..ba6deb866 100644 --- a/data_juicer/config/__init__.py +++ b/data_juicer/config/__init__.py @@ -1,7 +1,7 @@ -from .config import (export_config, get_init_configs, init_configs, - merge_config, prepare_side_configs) +from .config import (export_config, get_default_cfg, get_init_configs, + init_configs, merge_config, prepare_side_configs) __all__ = [ 'init_configs', 'get_init_configs', 'export_config', 'merge_config', - 'prepare_side_configs' + 'prepare_side_configs', 'get_default_cfg' ] diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 96c8498b5..86f0c411f 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -920,3 +920,36 @@ def get_init_configs(cfg: Union[Namespace, Dict]): json.dump(cfg, f) inited_dj_cfg = init_configs(['--config', temp_file]) return inited_dj_cfg + + +def get_default_cfg(): + """Get default config values from config_all.yaml""" + cfg = Namespace() + + # Get path to config_all.yaml + config_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = os.path.join(config_dir, + '../../configs/config_all.yaml') + + # Load default values from yaml + with open(default_config_path, 'r', encoding='utf-8') as f: + defaults = yaml.safe_load(f) + + # Convert to flat dictionary for namespace + flat_defaults = { + 'executor_type': 'default', + 'ray_address': 'auto', + 'suffixes': None, + 'text_keys': 'text', + 'add_suffix': False, + 'export_path': './outputs', + # Add other top-level keys from config_all.yaml + **defaults + } + + # Update cfg with defaults + for key, value in flat_defaults.items(): + if not hasattr(cfg, key): + setattr(cfg, key, value) + + return cfg diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index 8b24ecd22..6d4fdc636 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -36,7 +36,7 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional jsonargparse Namespace. """ super().__init__(cfg) - self.executor_type = 'local' + self.executor_type = 'default' self.work_dir = self.cfg.work_dir self.tracer = None diff --git a/demos/data_mixture/app.py b/demos/data_mixture/app.py index cc8efefb6..c05358806 100644 --- a/demos/data_mixture/app.py +++ b/demos/data_mixture/app.py @@ -2,8 +2,8 @@ import pandas as pd import streamlit as st - from data_juicer.core.data.dataset_builder import DatasetBuilder +from data_juicer.config import get_default_cfg if st.__version__ >= '1.23.0': data_editor = st.data_editor @@ -96,7 +96,10 @@ def mix_dataset(): ' '.join([str(weight), ds_file]) for ds_file, weight in zip(ds_files, weights) ]) - df = pd.DataFrame(DatasetBuilder(data_path).load_dataset()) + cfg = get_default_cfg() + cfg.dataset_path = data_path + dataset_builder = DatasetBuilder(cfg) + df = pd.DataFrame(dataset_builder.load_dataset()) st.session_state.dataset = df else: diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 9ae5bd55e..87f3ff85c 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -5,7 +5,7 @@ from jsonargparse import Namespace -from data_juicer.config import init_configs +from data_juicer.config import init_configs, get_default_cfg from data_juicer.ops import load_ops from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -276,5 +276,31 @@ def test_op_params_parsing(self): self.assertIn(base_param_key, params) + def test_get_default_cfg(self): + """Test getting default configuration from config_all.yaml""" + # Get default config + cfg = get_default_cfg() + + # Verify basic default values + self.assertIsInstance(cfg, Namespace) + + # Test essential defaults + self.assertEqual(cfg.executor_type, 'default') + self.assertEqual(cfg.ray_address, 'auto') + self.assertEqual(cfg.text_keys, 'text') + self.assertEqual(cfg.add_suffix, False) + self.assertEqual(cfg.export_path, '/path/to/result/dataset.jsonl') + self.assertEqual(cfg.suffixes, []) + + # Test other important defaults from config_all.yaml + self.assertTrue(hasattr(cfg, 'np')) # Number of processes + self.assertTrue(hasattr(cfg, 'use_cache')) # Cache usage flag + self.assertTrue(hasattr(cfg, 'temp_dir')) # Temporary directory + + # Test default values are of correct type + self.assertIsInstance(cfg.executor_type, str) + self.assertIsInstance(cfg.add_suffix, bool) + self.assertIsInstance(cfg.export_path, str) + if __name__ == '__main__': unittest.main()