From e0a87afbf9fa33955f11e148dd09e0876533cbf3 Mon Sep 17 00:00:00 2001
From: v-chen_data
Date: Sun, 1 Dec 2024 02:35:55 -0500
Subject: [PATCH] mock and reuse

---
 tests/test_events.py | 87 +++++++++----------------------------------
 1 file changed, 18 insertions(+), 69 deletions(-)

diff --git a/tests/test_events.py b/tests/test_events.py
index 80e929fa37..52a76c2e21 100644
--- a/tests/test_events.py
+++ b/tests/test_events.py
@@ -63,62 +63,6 @@ def event_counter_callback():
     return EventCounterCallback()
 
 
-@pytest.fixture
-def trainer(
-    model,
-    optimizer,
-    train_dataset,
-    evaluator1,
-    evaluator2,
-    event_counter_callback,
-    request,
-):
-    # extract parameters from the test function
-    params = request.param
-    precision = params.get('precision', 'fp32')
-    max_duration = params.get('max_duration', '1ep')
-    save_interval = params.get('save_interval', '1ep')
-    device = params.get('device', 'cpu')
-    deepspeed_zero_stage = params.get('deepspeed_zero_stage', None)
-    use_fsdp = params.get('use_fsdp', False)
-
-    deepspeed_config = None
-    if deepspeed_zero_stage:
-        deepspeed_config = {'zero_optimization': {'stage': deepspeed_zero_stage}}
-
-    parallelism_config = None
-    if use_fsdp:
-        parallelism_config = {
-            'fsdp': {
-                'sharding_strategy': 'FULL_SHARD',
-                'mixed_precision': 'PURE',
-                'backward_prefetch': 'BACKWARD_PRE',
-            },
-        }
-
-    return Trainer(
-        model=model,
-        train_dataloader=DataLoader(
-            dataset=train_dataset,
-            batch_size=4,
-            sampler=dist.get_sampler(train_dataset),
-            num_workers=0,
-        ),
-        eval_dataloader=(evaluator1, evaluator2),
-        device_train_microbatch_size=2,
-        precision=precision,
-        train_subset_num_batches=1,
-        eval_subset_num_batches=1,
-        max_duration=max_duration,
-        save_interval=save_interval,
-        optimizers=optimizer,
-        callbacks=[event_counter_callback],
-        device=device,
-        deepspeed_config=deepspeed_config,
-        parallelism_config=parallelism_config,
-    )
-
-
 @pytest.mark.parametrize('event', list(Event))
 def test_event_values(event: Event):
     assert event.name.lower() == event.value
@@ -177,9 +121,22 @@ def test_event_calls(
     event_counter_callback,
 ):
     with patch.object(Trainer, 'save_checkpoint', return_value=None):
-        # mock forward and backward to speed up
-        with patch.object(model, 'forward', return_value=torch.tensor(0.0)) as mock_forward, \
-             patch.object(model, 'backward', return_value=None) as mock_backward:
+        # mock the forward pass to speed up the test
+        with patch.object(model, 'forward', return_value=torch.tensor(0.0)):
+            # build the DeepSpeed and parallelism configs from the test parameters
+            deepspeed_config = None
+            if deepspeed_zero_stage:
+                deepspeed_config = {'zero_optimization': {'stage': deepspeed_zero_stage}}
+
+            parallelism_config = None
+            if use_fsdp:
+                parallelism_config = {
+                    'fsdp': {
+                        'sharding_strategy': 'FULL_SHARD',
+                        'mixed_precision': 'PURE',
+                        'backward_prefetch': 'BACKWARD_PRE',
+                    },
+                }
 
             trainer_instance = Trainer(
                 model=model,
@@ -199,16 +156,8 @@ def test_event_calls(
                 optimizers=optimizer,
                 callbacks=[event_counter_callback],
                 device=device,
-                deepspeed_config={'zero_optimization': {
-                    'stage': deepspeed_zero_stage,
-                }} if deepspeed_zero_stage else None,
-                parallelism_config={
-                    'fsdp': {
-                        'sharding_strategy': 'FULL_SHARD',
-                        'mixed_precision': 'PURE',
-                        'backward_prefetch': 'BACKWARD_PRE',
-                    },
-                } if use_fsdp else None,
+                deepspeed_config=deepspeed_config,
+                parallelism_config=parallelism_config,
             )
 
             trainer_instance.fit()
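
Note on the mocking pattern: the rewritten test stubs out the model's forward
pass with unittest.mock.patch.object so Trainer.fit() exercises the full event
loop without doing real computation. A minimal, self-contained sketch of that
pattern follows; the ToyModel class is a hypothetical stand-in, and only
patch.object and the torch.tensor(0.0) return value come from the patch itself:

    from unittest.mock import patch

    import torch

    class ToyModel(torch.nn.Module):

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.sum()

    model = ToyModel()

    # patch.object swaps a MagicMock in for the instance's forward attribute,
    # so calling the model returns the canned tensor instead of computing.
    with patch.object(model, 'forward', return_value=torch.tensor(0.0)) as mock_forward:
        out = model(torch.ones(2))
        assert out.item() == 0.0
        mock_forward.assert_called_once()

    # once the context manager exits, the real forward is restored
    assert model(torch.ones(2)).item() == 2.0

Because patch.object restores the original attribute on exit, the shared model
fixture can be reused across parametrized runs without leaking mock state.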