From 9819456bf7443b86682932dee5da9e5ac6dd4878 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 14:42:42 -0800 Subject: [PATCH] fix typing --- tests/callbacks/test_async_eval_callback.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index caf5e72868..274d80b9d8 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -1,7 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import patch +import datetime +from unittest.mock import MagicMock, patch import pytest @@ -74,10 +75,13 @@ def test_get_eval_parameters(): with pytest.raises( Exception, match='Missing the following required parameters for async eval:'): - AsyncEval.get_eval_parameters(None, {}, RUN_NAME) + AsyncEval.get_eval_parameters(None, {}, RUN_NAME) # type: ignore # minimal example - params = AsyncEval.get_eval_parameters(None, BASIC_PARAMS, RUN_NAME) + params = AsyncEval.get_eval_parameters( + None, # type: ignore + BASIC_PARAMS, + RUN_NAME) assert params == { 'device_eval_batch_size': 2, @@ -105,7 +109,7 @@ def test_get_eval_parameters(): # maximal example params2 = AsyncEval.get_eval_parameters( - None, + None, # type: ignore { # required **BASIC_PARAMS, @@ -164,8 +168,8 @@ def test_get_eval_parameters(): name=RUN_NAME, image='fake-image', status=RunStatus.RUNNING, - created_at='2021-01-01', - updated_at='2021-01-01', + created_at=datetime.datetime(2021, 1, 1), + updated_at=datetime.datetime(2021, 1, 1), created_by='me', priority='low', preemptible=False, @@ -175,7 +179,7 @@ def test_get_eval_parameters(): gpus=16, cpus=0, node_count=2, - latest_resumption=None, + latest_resumption=None, # type: ignore submitted_config=RunConfig( name=RUN_NAME, image='fake-image', @@ -189,7 +193,8 @@ def test_get_eval_parameters(): return_value=FAKE_RUN) @patch('llmfoundry.callbacks.async_eval_callback.create_run', return_value=FAKE_RUN) -def test_async_eval_callback_minimal(mock_create_run, mock_get_run): +def test_async_eval_callback_minimal(mock_create_run: MagicMock, + mock_get_run: MagicMock): callback = AsyncEval(interval='2ba', compute={ 'cluster': 'c2z3',