Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Nov 9, 2023
1 parent 547ec21 commit 9819456
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions tests/callbacks/test_async_eval_callback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch
import datetime
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -74,10 +75,13 @@ def test_get_eval_parameters():
with pytest.raises(
Exception,
match='Missing the following required parameters for async eval:'):
AsyncEval.get_eval_parameters(None, {}, RUN_NAME)
AsyncEval.get_eval_parameters(None, {}, RUN_NAME) # type: ignore

# minimal example
params = AsyncEval.get_eval_parameters(None, BASIC_PARAMS, RUN_NAME)
params = AsyncEval.get_eval_parameters(
None, # type: ignore
BASIC_PARAMS,
RUN_NAME)
assert params == {
'device_eval_batch_size':
2,
Expand Down Expand Up @@ -105,7 +109,7 @@ def test_get_eval_parameters():

# maximal example
params2 = AsyncEval.get_eval_parameters(
None,
None, # type: ignore
{
# required
**BASIC_PARAMS,
Expand Down Expand Up @@ -164,8 +168,8 @@ def test_get_eval_parameters():
name=RUN_NAME,
image='fake-image',
status=RunStatus.RUNNING,
created_at='2021-01-01',
updated_at='2021-01-01',
created_at=datetime.datetime(2021, 1, 1),
updated_at=datetime.datetime(2021, 1, 1),
created_by='me',
priority='low',
preemptible=False,
Expand All @@ -175,7 +179,7 @@ def test_get_eval_parameters():
gpus=16,
cpus=0,
node_count=2,
latest_resumption=None,
latest_resumption=None, # type: ignore
submitted_config=RunConfig(
name=RUN_NAME,
image='fake-image',
Expand All @@ -189,7 +193,8 @@ def test_get_eval_parameters():
return_value=FAKE_RUN)
@patch('llmfoundry.callbacks.async_eval_callback.create_run',
return_value=FAKE_RUN)
def test_async_eval_callback_minimal(mock_create_run, mock_get_run):
def test_async_eval_callback_minimal(mock_create_run: MagicMock,
mock_get_run: MagicMock):
callback = AsyncEval(interval='2ba',
compute={
'cluster': 'c2z3',
Expand Down

0 comments on commit 9819456

Please sign in to comment.