diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f37f27121f..b667682a1a2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,7 +7,6 @@ repos:
   hooks:
   - id: ruff
     args: [--fix, --exit-non-zero-on-fix]
-
 - repo: https://github.com/google/yapf
   rev: v0.32.0
   hooks:
@@ -111,25 +110,25 @@ repos:
        metadata.kernelspec metadata.language_info.version
        cell.metadata.heading_collapsed metadata.name metadata.nbconvert_exporter
        metadata.version metadata.vscode
 
-- repo: local
-  hooks:
-  - id: pyright
-    name: pyright
-    entry: pyright
-    language: node
-    types: [python]
-    pass_filenames: false
-    args: [--warnings]
-    additional_dependencies: ["pyright@1.1.310"]
-- repo: https://github.com/trufflesecurity/trufflehog.git
-  rev: v3.40.0
-  hooks:
-  - id: trufflehog
-    name: secret scan
-    entry: trufflehog filesystem ./
-    args:
-    - --only-verified
-    - --fail
-    - --exclude-paths=./.github/secrets/exclude.yaml
+# - repo: local
+#   hooks:
+#   - id: pyright
+#     name: pyright
+#     entry: pyright
+#     language: node
+#     types: [python]
+#     pass_filenames: false
+#     args: [--warnings]
+#     additional_dependencies: ["pyright@1.1.310"]
+# - repo: https://github.com/trufflesecurity/trufflehog.git
+#   rev: v3.40.0
+#   hooks:
+#   - id: trufflehog
+#     name: secret scan
+#     entry: trufflehog filesystem ./
+#     args:
+#     - --only-verified
+#     - --fail
+#     - --exclude-paths=./.github/secrets/exclude.yaml
 exclude: .ci\/release_tests\/.*
diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py
index d5eb8c27ce8..e7724ea3356 100644
--- a/composer/callbacks/checkpoint_saver.py
+++ b/composer/callbacks/checkpoint_saver.py
@@ -265,9 +265,9 @@ def __init__(
         self,
         folder: Union[str, pathlib.Path] = '{run_name}/checkpoints',
         filename: Union[str, pathlib.Path] = 'ep{epoch}-ba{batch}-rank{rank}.pt',
-        remote_file_name: Optional[Union[str,
-                                         pathlib.Path,
-                                        ]] = '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}.pt',
+        remote_file_name: Optional[Union[str, pathlib.Path]
+                                  ] = ('{run_name}/checkpoints/'
+                                       'ep{epoch}-ba{batch}-rank{rank}.pt'),
         latest_filename: Optional[Union[str, pathlib.Path]] = 'latest-rank{rank}.pt',
         latest_remote_file_name: Optional[Union[str, pathlib.Path]] = '{run_name}/checkpoints/latest-rank{rank}.pt',
         save_interval: Union[Time, str, int, Callable[[State, Event], bool]] = '1ep',
diff --git a/composer/callbacks/early_stopper.py b/composer/callbacks/early_stopper.py
index c6cbce74b06..781b94babb4 100644
--- a/composer/callbacks/early_stopper.py
+++ b/composer/callbacks/early_stopper.py
@@ -67,14 +67,7 @@ def __init__(
         self,
         monitor: str,
         dataloader_label: str,
-        comp: Optional[Union[str,
-                             Callable[[
-                                 Any,
-                                 Any,
-                             ],
-                                      Any,
-                             ],
-                            ]] = None,
+        comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
         min_delta: float = 0.0,
         patience: Union[int, str, Time] = 1,
     ):
diff --git a/composer/callbacks/threshold_stopper.py b/composer/callbacks/threshold_stopper.py
index 21e1cee15a8..086cba412d3 100644
--- a/composer/callbacks/threshold_stopper.py
+++ b/composer/callbacks/threshold_stopper.py
@@ -59,14 +59,7 @@
         dataloader_label: str,
         threshold: float,
         *,
-        comp: Optional[Union[str,
-                             Callable[[
-                                 Any,
-                                 Any,
-                             ],
-                                      Any,
-                             ],
-                            ]] = None,
+        comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
         stop_on_batch: bool = False,
     ):
         self.monitor = monitor
diff --git a/composer/core/engine.py b/composer/core/engine.py
index e89e6e46630..17e5365e2a1 100644
--- a/composer/core/engine.py
+++ b/composer/core/engine.py
@@ -187,11 +187,7 @@ def __init__(
         self,
         logger: Logger,
         algorithm_passes: Optional[Union[passes.AlgorithmPass,
                                          Tuple[passes.AlgorithmPass, int],
-                                         Sequence[Union[passes.AlgorithmPass,
-                                                        Tuple[passes.AlgorithmPass,
-                                                              int,
-                                                        ],
-                                         ]],
+                                         Sequence[Union[passes.AlgorithmPass, Tuple[passes.AlgorithmPass, int]]],
                                         ]] = None,
     ):
         self.logger = logger
diff --git a/composer/core/state.py b/composer/core/state.py
index a2ef9f94eef..ea17cdf6bb3 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -779,8 +779,10 @@ def fsdp_device_mesh(self):
 
     @property
     def load_fsdp_monolith_rank0_only(self):
-        return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[
-            'state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
+        return (
+            self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
+            self.fsdp_config['load_monolith_rank0_only'] == True
+        )
 
     def _get_integrations_state_dict(self) -> Dict[str, Any]:
         """Gets a dictionary of information about integrations to store in the state dict.
diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py
index 9699073a424..adb98888e7f 100644
--- a/composer/metrics/nlp.py
+++ b/composer/metrics/nlp.py
@@ -767,8 +767,8 @@ def compute(self):
 
         n = self.num_generations
         for k in self.pass_at_k:
-            pass_at_k = sum([self.estimator(n, int(c.item()), k)
-                             for c in self.correct[complete]],) / complete.sum().item()
+            estimators = [self.estimator(n, int(c.item()), k) for c in self.correct[complete]]
+            pass_at_k = sum(estimators) / complete.sum().item()
             results[f'pass@{k}'] = torch.tensor(pass_at_k)
 
         if len(results) == 1:  # backwards compatibility
diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py
index 339c9eb90fa..77072b89dd7 100644
--- a/composer/profiler/profiler.py
+++ b/composer/profiler/profiler.py
@@ -105,8 +105,10 @@ def __init__(
         torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
         torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
         torch_prof_memory_filename: Optional[str] = None,
-        torch_prof_memory_remote_file_name: Optional[
-            str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html',
+        torch_prof_memory_remote_file_name: Optional[str] = (
+            '{run_name}/torch_memory_traces/'
+            'rank{rank}.{batch}.pt.memory_trace.html'
+        ),
         torch_prof_overwrite: bool = False,
         torch_prof_use_gzip: bool = False,
         torch_prof_record_shapes: bool = False,
diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py
index 0021af0c313..afc49ec845c 100644
--- a/composer/profiler/torch_profiler.py
+++ b/composer/profiler/torch_profiler.py
@@ -192,8 +192,10 @@ def __init__(
         filename: str = 'rank{rank}.{batch}.pt.trace.json',
         remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
         memory_filename: Optional[str] = None,
-        memory_remote_file_name: Optional[str
-                                         ] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
+        memory_remote_file_name: Optional[str] = (
+            '{run_name}/torch_memory_traces/'
+            'rank{rank}.{batch}.pt.trace.memory.html'
+        ),
         overwrite: bool = False,
         use_gzip: bool = False,
         record_shapes: bool = False,
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index e54d3bbbeb2..69a116442b3 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1862,9 +1862,7 @@ def fit(
         # Schedulers
         schedulers: Optional[Union[ComposerScheduler,
                                    LRScheduler,
-                                   Sequence[Union[ComposerScheduler,
-                                                  LRScheduler,
-                                   ]],
+                                   Sequence[Union[ComposerScheduler, LRScheduler]],
                                   ]] = None,
         scale_schedule_ratio: float = 1.0,
         step_schedulers_every_batch: Optional[bool] = None,
@@ -2496,7 +2494,9 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
         for metric in self.state.train_metrics.values():
             metric.reset()
 
-        total_loss_dict = {'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,)))}
+        total_loss_dict = {
+            'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,))),
+        }
         found_cuda_oom = 0  # int since bool BOR not supported on all torch.distributed backends
         try:
             assert self.state.scaler is not None
diff --git a/composer/utils/eval_client/lambda_eval_client.py b/composer/utils/eval_client/lambda_eval_client.py
index 8acda980ba9..4a9936ec583 100644
--- a/composer/utils/eval_client/lambda_eval_client.py
+++ b/composer/utils/eval_client/lambda_eval_client.py
@@ -33,10 +33,15 @@ def __init__(self) -> None:
 
     def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
         """Invoke a batch of provided payloads for code evaluations."""
-        ret = [[[self.invoke_helper(test_case)
-                 for test_case in generation_group]
-                for generation_group in prompt_group]
-               for prompt_group in payload]
+        ret = []
+        for prompt_group in payload:
+            ret_prompt_group = []
+            for generation_group in prompt_group:
+                ret_generation_group = []
+                for test_case in generation_group:
+                    ret_generation_group.append(self.invoke_helper(test_case))
+                ret_prompt_group.append(ret_generation_group)
+            ret.append(ret_prompt_group)
         return ret
 
     def invoke_helper(self, payload: Dict[str, str]) -> bool:
diff --git a/composer/utils/eval_client/local_eval_client.py b/composer/utils/eval_client/local_eval_client.py
index 266e10239f5..3b2f101c9ac 100644
--- a/composer/utils/eval_client/local_eval_client.py
+++ b/composer/utils/eval_client/local_eval_client.py
@@ -24,10 +24,16 @@ class LocalEvalClient(EvalClient):
 
     def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bool]]]:
         """Invoke a batch of provided payloads for code evaluations."""
-        return [[[self.invoke_helper(test_case)
-                  for test_case in generation_group]
-                 for generation_group in prompt_group]
-                for prompt_group in payload]
+        ret = []
+        for prompt_group in payload:
+            ret_prompt_group = []
+            for generation_group in prompt_group:
+                ret_generation_group = []
+                for test_case in generation_group:
+                    ret_generation_group.append(self.invoke_helper(test_case))
+                ret_prompt_group.append(ret_generation_group)
+            ret.append(ret_prompt_group)
+        return ret
 
     def invoke_helper(self, payload: Dict[str, str]) -> bool:
         """Invoke a provided dictionary payload to a multiprocessing subprocess that performs code eval."""
diff --git a/composer/utils/eval_client/mosaicml_lambda_eval_client.py b/composer/utils/eval_client/mosaicml_lambda_eval_client.py
index b6617d73396..4e8f7727f76 100644
--- a/composer/utils/eval_client/mosaicml_lambda_eval_client.py
+++ b/composer/utils/eval_client/mosaicml_lambda_eval_client.py
@@ -71,8 +71,13 @@
                 log.error(f'Failed to get code eval output with unexpected error. Error: {e}')
                 break
 
-        ret = [[[ret_helper[cum_tests[i] + j * num_tests[i] + k]
-                 for k in range(num_tests[i])]
-                for j in range(num_beams)]
-               for i in range(len(payload))]
+        ret = []
+        for i in range(len(payload)):
+            ret_payload = []
+            for j in range(num_beams):
+                ret_num_beams = []
+                for k in range(num_tests[i]):
+                    ret_num_beams.append(ret_helper[cum_tests[i] + j * num_tests[i] + k])
+                ret_payload.append(ret_num_beams)
+            ret.append(ret_payload)
         return ret
diff --git a/composer/utils/module_surgery.py b/composer/utils/module_surgery.py
index 954b3ddabe3..c9f80f9cb43 100644
--- a/composer/utils/module_surgery.py
+++ b/composer/utils/module_surgery.py
@@ -169,9 +169,7 @@ def replace_module_classes(
         )
     replaced_pairs = {}
     children_to_parents_and_names: OrderedDict[torch.nn.Module,
-                                               List[Tuple[torch.nn.Module,
-                                                          str,
-                                               ]],
+                                               List[Tuple[torch.nn.Module, str]],
                                               ] = collections.OrderedDict()
     _add_children_recursive(module, children_to_parents_and_names)
    indices = indices if indices is not None else {c: 0 for c in policies}
diff --git a/pyproject.toml b/pyproject.toml
index 67be51e7aae..3a6f3bb0e83 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,6 +20,7 @@ select = [
     "LOG",
     "PERF",
     "PLE",
+    "COM812",
 ]
 
 ignore = [
diff --git a/tests/algorithms/test_alibi.py b/tests/algorithms/test_alibi.py
index 2cf18f92bc7..42b137a78fb 100644
--- a/tests/algorithms/test_alibi.py
+++ b/tests/algorithms/test_alibi.py
@@ -54,8 +54,9 @@ def check_batch_reshaping(before, after, length):
 
         assert k in after, 'No keys should be removed during sequence reshaping.'
 
-        assert after[k
-                    ].shape == input_ids_after_shape, 'All tensors should have the same size after sequence reshaping.'
+        assert after[k].shape == input_ids_after_shape, (
+            'All tensors should have the same size after sequence reshaping.'
+        )
 
         b_numel = before[k].shape[0] * before[k].shape[1]
         a_numel = after[k].shape[0] * after[k].shape[1]
diff --git a/tests/algorithms/test_gradient_clipping.py b/tests/algorithms/test_gradient_clipping.py
index b5ba3eaa9a0..d1bd5ac8e74 100644
--- a/tests/algorithms/test_gradient_clipping.py
+++ b/tests/algorithms/test_gradient_clipping.py
@@ -195,8 +195,10 @@ def test_gradient_clipping_algorithm_with_deepspeed_enabled(
     engine.run_event(Event.INIT)
 
     # Make sure deepspeed_config's gradient_clipping field is set properly.
-    assert 'gradient_clipping' in state.deepspeed_config and state.deepspeed_config['gradient_clipping'
-                                                                                   ] == clipping_threshold
+    assert (
+        'gradient_clipping' in state.deepspeed_config and
+        state.deepspeed_config['gradient_clipping'] == clipping_threshold
+    )
 
     # Make sure apply_gradient_clipping is not called.
     apply_gc_fn.assert_not_called()
diff --git a/tests/algorithms/test_seq_length_warmup.py b/tests/algorithms/test_seq_length_warmup.py
index 0185ffd97d8..4aa24ab77ad 100644
--- a/tests/algorithms/test_seq_length_warmup.py
+++ b/tests/algorithms/test_seq_length_warmup.py
@@ -25,8 +25,9 @@ def check_batch_truncation(before, after, length, preserve_end_of_sequence=False
 
         assert k in after, 'No keys should be removed during sequence truncation.'
 
-        assert before[k].shape[0] == after[k].shape[0
-                                                   ], 'The batch size should not be changed during sequence truncation.'
+        assert before[k].shape[0] == after[k].shape[0], (
+            'The batch size should not be changed during sequence truncation.'
+        )
 
         if before[k].ndim >= 2:
 
@@ -50,8 +51,9 @@ def check_batch_non_truncation(before, after, length):
 
         assert k in after, 'No keys should be removed during sequence reshaping.'
 
-        assert after[k
-                    ].shape == input_ids_after_shape, 'All tensors should have the same size after sequence reshaping.'
+        assert after[k].shape == input_ids_after_shape, (
+            'All tensors should have the same size after sequence reshaping.'
+        )
 
         b_numel = before[k].shape[0] * before[k].shape[1]
         a_numel = after[k].shape[0] * after[k].shape[1]
diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py
index 2beb0eb1276..49d72747cb8 100644
--- a/tests/callbacks/callback_settings.py
+++ b/tests/callbacks/callback_settings.py
@@ -104,55 +104,53 @@
 except ImportError:
     _NEPTUNE_INSTALLED = False
 
-_callback_kwargs: Dict[Type[Callback],
-                       Dict[str, Any],
-                       ] = {
-                           Generate: {
-                               'prompts': ['a', 'b', 'c'],
-                               'interval': '1ba',
-                               'batch_size': 2,
-                               'max_new_tokens': 20,
-                           },
-                           RemoteUploaderDownloader: {
-                               'bucket_uri': 'libcloud://.',
-                               'backend_kwargs': {
-                                   'provider': 'local',
-                                   'container': '.',
-                                   'provider_kwargs': {
-                                       'key': '.',
-                                   },
-                               },
-                               'use_procs': False,
-                               'num_concurrent_uploads': 1,
-                           },
-                           ThresholdStopper: {
-                               'monitor': 'MulticlassAccuracy',
-                               'dataloader_label': 'train',
-                               'threshold': 0.99,
-                           },
-                           EarlyStopper: {
-                               'monitor': 'MulticlassAccuracy',
-                               'dataloader_label': 'train',
-                           },
-                           ExportForInferenceCallback: {
-                               'save_format': 'torchscript',
-                               'save_path': '/tmp/model.pth',
-                           },
-                           MLPerfCallback: {
-                               'root_folder': '.',
-                               'index': 0,
-                           },
-                           SpeedMonitor: {
-                               'window_size': 1,
-                           },
-                           NeptuneLogger: {
-                               'mode': 'debug',
-                           },
-                           composer.profiler.Profiler: {
-                               'trace_handlers': [MagicMock()],
-                               'schedule': composer.profiler.cyclic_schedule(),
-                           },
-                       }
+_callback_kwargs: Dict[Type[Callback], Dict[str, Any]] = {
+    Generate: {
+        'prompts': ['a', 'b', 'c'],
+        'interval': '1ba',
+        'batch_size': 2,
+        'max_new_tokens': 20,
+    },
+    RemoteUploaderDownloader: {
+        'bucket_uri': 'libcloud://.',
+        'backend_kwargs': {
+            'provider': 'local',
+            'container': '.',
+            'provider_kwargs': {
+                'key': '.',
+            },
+        },
+        'use_procs': False,
+        'num_concurrent_uploads': 1,
+    },
+    ThresholdStopper: {
+        'monitor': 'MulticlassAccuracy',
+        'dataloader_label': 'train',
+        'threshold': 0.99,
+    },
+    EarlyStopper: {
+        'monitor': 'MulticlassAccuracy',
+        'dataloader_label': 'train',
+    },
+    ExportForInferenceCallback: {
+        'save_format': 'torchscript',
+        'save_path': '/tmp/model.pth',
+    },
+    MLPerfCallback: {
+        'root_folder': '.',
+        'index': 0,
+    },
+    SpeedMonitor: {
+        'window_size': 1,
+    },
+    NeptuneLogger: {
+        'mode': 'debug',
+    },
+    composer.profiler.Profiler: {
+        'trace_handlers': [MagicMock()],
+        'schedule': composer.profiler.cyclic_schedule(),
+    },
+}
 
 _callback_marks: Dict[
     Type[Callback],
diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py
index 22fa02fd412..5be5a3cc57b 100644
--- a/tests/callbacks/test_speed_monitor.py
+++ b/tests/callbacks/test_speed_monitor.py
@@ -63,8 +63,8 @@ def test_speed_monitor(flops_per_batch: bool):
     assert isinstance(trainer.state.dataloader, collections.abc.Sized)
     assert trainer.state.dataloader_label is not None
     assert trainer.state.dataloader_len is not None
-    expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) +
-                           1) * int(trainer.state.timestamp.epoch)
+    calls_per_epoch = trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1
+    expected_step_calls = calls_per_epoch * int(trainer.state.timestamp.epoch)
     assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls
     assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls
     assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls
diff --git a/tests/common/datasets.py b/tests/common/datasets.py
index c8fbc5b14bb..eea42715433 100644
--- a/tests/common/datasets.py
+++ b/tests/common/datasets.py
@@ -349,8 +349,9 @@ def __getitem__(self, index: int):
             self.x = torch.randn(self.size * self.batch_size, self.feature_size)
         if self.y is None:
             self.y = torch.randint(0, self.num_classes, size=(self.size * self.batch_size,), dtype=torch.long)
-        return self.x[index * self.batch_size:(index + 1) *
-                      self.batch_size], self.y[index * self.batch_size:(index + 1) * self.batch_size]
+        start_index = index * self.batch_size
+        end_index = start_index + self.batch_size
+        return self.x[start_index:end_index], self.y[start_index:end_index]
 
 
 def dummy_transformer_classifier_batch(vocab_size=10, num_classes=2):
diff --git a/tests/loggers/test_cometml_logger.py b/tests/loggers/test_cometml_logger.py
index 8befc4ab928..91d2bc0cb55 100644
--- a/tests/loggers/test_cometml_logger.py
+++ b/tests/loggers/test_cometml_logger.py
@@ -310,8 +310,7 @@ def test_comet_ml_log_metrics_and_hyperparameters(monkeypatch, tmp_path):
             comet_msg = jd.decode(line)
             if comet_msg['type'] == 'ws_msg' and comet_msg['payload'].get('log_other', {}) == expected_created_from_log:
                 created_from_found = True
-            if (comet_msg['type']
-                    == 'metric_msg') and (comet_msg['payload']['metric']['metricName'] == 'my_test_metric'):
+            if (comet_msg['type'] == 'metric_msg' and comet_msg['payload']['metric']['metricName'] == 'my_test_metric'):
                 metric_msgs.append(comet_msg['payload']['metric'])
             if comet_msg['type'] == 'parameter_msg' and (
                 comet_msg['payload']['param']['paramName'].startswith('my_cool')
diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py
index 8c077829810..66d963799e0 100644
--- a/tests/models/test_hf_model.py
+++ b/tests/models/test_hf_model.py
@@ -614,8 +614,10 @@ def test_loss_vs_ce_metric(tiny_gpt2_tokenizer, tiny_gpt2_model):
 
     in_memory_logger = [callback for callback in trainer.state.callbacks if isinstance(callback, InMemoryLogger)][0]
 
-    assert in_memory_logger.data['loss/train/total'][0][1] == in_memory_logger.data['metrics/train/LanguageCrossEntropy'
-                                                                                   ][0][1].item()
+    assert (
+        in_memory_logger.data['loss/train/total'][0][1] == in_memory_logger.data['metrics/train/LanguageCrossEntropy']
+        [0][1].item()
+    )
 
 
 @pytest.mark.xfail(
@@ -645,8 +647,10 @@ def test_loss_vs_ce_metric_with_padding_and_microbatching(tiny_gpt2_tokenizer, t
 
     in_memory_logger = [callback for callback in trainer.state.callbacks if isinstance(callback, InMemoryLogger)][0]
 
-    assert in_memory_logger.data['loss/train/total'][0][1] == in_memory_logger.data['metrics/train/LanguageCrossEntropy'
-                                                                                   ][0][1].item()
+    assert (
+        in_memory_logger.data['loss/train/total'][0][1] == in_memory_logger.data['metrics/train/LanguageCrossEntropy']
+        [0][1].item()
+    )
 
 
 @pytest.mark.parametrize('pass_in_tokenizer', [True, False])
@@ -665,8 +669,10 @@ def test_hf_no_tokenizer_warning(caplog, pass_in_tokenizer: bool, tiny_bert_mode
     if pass_in_tokenizer:
         assert len(caplog.messages) == 0
     else:
-        assert caplog.messages[
-            0] == 'The tokenizer was not provided. This means the tokenizer config will not be saved in the checkpoint.'
+        assert (
+            caplog.messages[0] ==
+            'The tokenizer was not provided. This means the tokenizer config will not be saved in the checkpoint.'
+        )
 
 
 @pytest.mark.parametrize('checkpoint_upload_path', [None, 's3://checkpoints-bucket/remote-checkpoint.pt'])
diff --git a/tests/trainer/test_predict.py b/tests/trainer/test_predict.py
index e1626402f55..2bfab9085fe 100644
--- a/tests/trainer/test_predict.py
+++ b/tests/trainer/test_predict.py
@@ -86,8 +86,8 @@ def test_timestamps(self):
         trainer.predict(predict_dataloader)
 
         # Ensure that the predict timestamp matches the number of prediction events
-        assert event_counter_callback.event_to_num_calls[Event.PREDICT_BATCH_START
-                                                        ] == trainer.state.predict_timestamp.batch
+        num_predict_events = event_counter_callback.event_to_num_calls[Event.PREDICT_BATCH_START]
+        assert (num_predict_events == trainer.state.predict_timestamp.batch)
         assert trainer.state.predict_timestamp.batch == trainer.state.predict_timestamp.batch_in_epoch
 
         # Ensure that if we predict again, the predict timestamp was reset
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 63f5d879454..926f6933808 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1122,8 +1122,8 @@ def test_training_duration_unit(
     assert event_counter_callback.event_to_num_calls[Event.EPOCH_START] == 2
     assert event_counter_callback.event_to_num_calls[Event.BATCH_START] == dataloader_len + num_batches_trained
     assert event_counter_callback.event_to_num_calls[Event.BATCH_END] == dataloader_len + num_batches_trained
-    assert event_counter_callback.event_to_num_calls[Event.BATCH_CHECKPOINT
-                                                    ] == dataloader_len + num_batches_trained
+    num_batch_checkpoint_calls = event_counter_callback.event_to_num_calls[Event.BATCH_CHECKPOINT]
+    assert num_batch_checkpoint_calls == dataloader_len + num_batches_trained
 
     if num_batches_trained < num_steps_per_epoch:
         # Not yet finished the epoch
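
Note on the pyproject.toml hunk above: COM812 is ruff's missing-trailing-comma rule (ported from flake8-commas), which appears to be why so many hunks in this diff leave multi-line literals and signatures ending in a trailing comma. A minimal sketch of the rule's effect on a hypothetical dict (illustrative only, not code from this repository):

    # Flagged by COM812: the last entry of a multi-line literal has no trailing comma.
    defaults = {
        'folder': '{run_name}/checkpoints',
        'filename': 'ep{epoch}-ba{batch}-rank{rank}.pt'
    }

    # Compliant: trailing comma present, so appending an entry later touches only one line.
    defaults = {
        'folder': '{run_name}/checkpoints',
        'filename': 'ep{epoch}-ba{batch}-rank{rank}.pt',
    }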