Skip to content

Commit

Permalink
Improve pytest's test collection time
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#1462
  • Loading branch information
treiher committed Nov 16, 2023
1 parent e7fe635 commit 8ed437c
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 289 deletions.
35 changes: 23 additions & 12 deletions tests/data/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from __future__ import annotations

from collections.abc import Callable
from functools import lru_cache
from typing import Final

from rflx.error import Location
from rflx.expression import (
Expand Down Expand Up @@ -621,10 +623,13 @@ def universal_options() -> Sequence:
return Sequence("Universal::Options", universal_option())


UNIVERSAL_MESSAGE_ID: Final = ID("Universal::Message")


@lru_cache
def universal_message() -> Message:
return Message(
"Universal::Message",
UNIVERSAL_MESSAGE_ID,
[
Link(INITIAL, Field("Message_Type")),
Link(
Expand Down Expand Up @@ -826,16 +831,22 @@ def session() -> Session:
)


def spark_test_models() -> list[Model]:
"""Return models corresponding to generated code in tests/spark/generated."""
def spark_test_models() -> list[Callable[[], Model]]:
"""
Return callables that create models corresponding to generated code in tests/spark/generated.
Using callable functions instead of the models directly enables the caller to postpone the
time-consuming creation of the models to a later time. For instance, when using this function to
parameterize a test function, no model creation is necessary during collection time.
"""
return [
derivation_model(),
enumeration_model(),
ethernet_model(),
expression_model(),
null_message_in_tlv_message_model(),
null_model(),
sequence_model(),
tlv_model(),
Model(fixed_size_simple_message().dependencies),
derivation_model,
enumeration_model,
ethernet_model,
expression_model,
null_message_in_tlv_message_model,
null_model,
sequence_model,
tlv_model,
lambda: Model(fixed_size_simple_message().dependencies),
]
95 changes: 49 additions & 46 deletions tests/unit/generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing as ty
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -154,8 +155,8 @@ def test_generate_partial_update(tmp_path: Path) -> None:


@pytest.mark.parametrize("model", models.spark_test_models())
def test_equality(model: Model, tmp_path: Path) -> None:
assert_equal_code(model, Integration(), GENERATED_DIR, tmp_path, accept_extra_files=True)
def test_equality(model: Callable[[], Model], tmp_path: Path) -> None:
assert_equal_code(model(), Integration(), GENERATED_DIR, tmp_path, accept_extra_files=True)


@pytest.mark.parametrize("embedded", [True, False])
Expand Down Expand Up @@ -295,24 +296,26 @@ def test_prefixed_type_identifier() -> None:
assert common.prefixed_type_identifier(ID(t), "P") == t.name


DUMMY_SESSION = ir.Session(
identifier=ID("P::S"),
states=[
ir.State(
"State",
[ir.Transition("Final", ir.ComplexExpr([], ir.BoolVal(value=True)), None, None)],
None,
[],
None,
None,
),
],
declarations=[],
parameters=[],
types={t.identifier: t for t in models.universal_model().types},
location=None,
variable_id=id_generator(),
)
@lru_cache
def dummy_session() -> ir.Session:
return ir.Session(
identifier=ID("P::S"),
states=[
ir.State(
"State",
[ir.Transition("Final", ir.ComplexExpr([], ir.BoolVal(value=True)), None, None)],
None,
[],
None,
None,
),
],
declarations=[],
parameters=[],
types={t.identifier: t for t in models.universal_model().types},
location=None,
variable_id=id_generator(),
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -376,8 +379,8 @@ def test_session_create_abstract_function(
expected: Sequence[ada.SubprogramDeclaration],
) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -480,8 +483,8 @@ def test_session_create_abstract_functions_error(
error_msg: str,
) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -711,10 +714,10 @@ def test_session_evaluate_declarations(
session_global: bool,
expected: EvaluatedDeclaration,
) -> None:
allocator = AllocatorGenerator(DUMMY_SESSION, Integration())
allocator = AllocatorGenerator(dummy_session(), Integration())

allocator._allocation_slots[Location((1, 1))] = 1 # noqa: SLF001
session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN)
session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN)
assert (
session_generator._evaluate_declarations( # noqa: SLF001
[declaration],
Expand Down Expand Up @@ -978,10 +981,10 @@ def test_session_declare(
expected: EvaluatedDeclarationStr,
) -> None:
loc: Location = Location((1, 1))
allocator = AllocatorGenerator(DUMMY_SESSION, Integration())
allocator = AllocatorGenerator(dummy_session(), Integration())

allocator._allocation_slots[loc] = 1 # noqa: SLF001
session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN)
session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN)

result = session_generator._declare( # noqa: SLF001
ID("X"),
Expand Down Expand Up @@ -1073,8 +1076,8 @@ def test_session_declare_error(
error_msg: str,
) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -1268,8 +1271,8 @@ def _update_str(self) -> None:
],
)
def test_session_state_action(action: ir.Stmt, expected: str) -> None:
allocator = AllocatorGenerator(DUMMY_SESSION, Integration())
session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN)
allocator = AllocatorGenerator(dummy_session(), Integration())
session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN)

allocator._allocation_slots[Location((1, 1))] = 1 # noqa: SLF001
assert (
Expand Down Expand Up @@ -1323,8 +1326,8 @@ def test_session_state_action_error(
error_msg: str,
) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -1696,8 +1699,8 @@ def test_session_assign_error(
error_type: type[BaseError],
error_msg: str,
) -> None:
allocator = AllocatorGenerator(DUMMY_SESSION, Integration())
session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN)
allocator = AllocatorGenerator(dummy_session(), Integration())
session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN)
alloc_id = Location((1, 1))

allocator._allocation_slots[alloc_id] = 1 # noqa: SLF001
Expand Down Expand Up @@ -1773,8 +1776,8 @@ def test_session_append_error(
error_msg: str,
) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -1816,8 +1819,8 @@ def test_session_append_error(
)
def test_session_read_error(read: ir.Read, error_type: type[BaseError], error_msg: str) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -1847,8 +1850,8 @@ def test_session_read_error(read: ir.Read, error_type: type[BaseError], error_ms
)
def test_session_write_error(write: ir.Write, error_type: type[BaseError], error_msg: str) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down Expand Up @@ -1908,8 +1911,8 @@ def test_session_write_error(write: ir.Write, error_type: type[BaseError], error
)
def test_session_to_ada_expr(expression: ir.Expr, expected: ada.Expr) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand All @@ -1932,8 +1935,8 @@ def test_session_to_ada_expr_equality(
expected: ada.Expr,
) -> None:
session_generator = SessionGenerator(
DUMMY_SESSION,
AllocatorGenerator(DUMMY_SESSION, Integration()),
dummy_session(),
AllocatorGenerator(dummy_session(), Integration()),
debug=Debug.BUILTIN,
)

Expand Down
Loading

0 comments on commit 8ed437c

Please sign in to comment.