From c0a185b390a48f5f82e51d207528bd8a5f279247 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 16 Aug 2023 11:25:05 +0800 Subject: [PATCH] [test] fix gemini checkpoint io test --- .../test_gemini_checkpoint_io.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 43fdcb21df2e..6720be58490b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -18,12 +18,45 @@ ) from tests.kit.model_zoo import model_zoo +MODEL_PLACEMENT_CONFIGS = [ + { + 'placement_policy': 'static', + 'shard_param_frac': 0.0 + }, # zero2 + { + 'placement_policy': 'static', + 'shard_param_frac': 1.0 + }, # zero3 + { + 'placement_policy': 'static', + 'shard_param_frac': 0.5 + }, # zero3-half +] + +OPTIM_PLACEMENT_CONFIGS = [ + { + 'placement_policy': 'static', + 'shard_param_frac': 0.0, + 'offload_optim_frac': 0.0 + }, # zero2 + { + 'placement_policy': 'static', + 'shard_param_frac': 0.0, + 'offload_optim_frac': 1.0 + }, # zero2-offload + { + 'placement_policy': 'static', + 'shard_param_frac': 0.0, + 'offload_optim_frac': 0.5 + }, # zero2-offload-half +] + @clear_cache_before_run() -@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS) @parameterize('model_name', ['transformers_bert_for_sequence_classification']) @parameterize('use_safetensors', [False, True]) -def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool): +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() @@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, 'pretrained') bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(placement_policy=placement_policy) + plugin = GeminiPlugin(**placement_config) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -51,14 +84,14 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b @clear_cache_before_run() -@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS) @parameterize('shard', [False, True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) -def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int): +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14)) + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14)) booster = Booster(plugin=plugin) model = model_fn()