diff --git a/tests/test_cadet_adapter.py b/tests/test_cadet_adapter.py index c9f013119..78df4c9e1 100644 --- a/tests/test_cadet_adapter.py +++ b/tests/test_cadet_adapter.py @@ -7,6 +7,7 @@ import pytest import numpy as np import numpy.testing as npt +from itertools import product from tests.create_LWE import create_lwe @@ -68,6 +69,31 @@ def tearDown(self): 'LumpedRateModelWithoutPores', 'LumpedRateModelWithPores', 'MCT' ] +parameter_combinations = [ + {}, # Default parameters + {"n_par": 1}, + {"n_col": 1}, +] + +# Parameters to skip for specific unit types +exclude_rules = { + 'Cstr': [{"n_col": 1}, {"n_par": 1}], + 'TubularReactor': [{"n_par": 1}], + 'LumpedRateModelWithoutPores': [{"n_par": 1}], + 'LumpedRateModelWithPores': [{"n_par": 1}], + 'MCT': [{"n_par": 1}], + +} + +test_cases = [ + pytest.param( + (unit_type, params), + id=f"{unit_type}-{'default' if not params else '-'.join(f'{k}={v}' for k, v in params.items())}" + ) + for unit_type in unit_types + for params in parameter_combinations + if not (unit_type in exclude_rules and params in exclude_rules[unit_type]) +] def run_simulation( process: Process, @@ -108,28 +134,27 @@ def run_simulation( raise CADETProcessError(f"CADET simulation failed: {e}.") from e -@pytest.fixture(scope="class", params=unit_types) +@pytest.fixture() def process(request: pytest.FixtureRequest): """ Fixture to set up the process for each unit type without running the simulation. """ - unit_type = request.param - process = create_lwe(unit_type) + unit_type, kwargs = request.param + process = create_lwe(unit_type, **kwargs) return process -@pytest.fixture(scope="class", params=unit_types) +@pytest.fixture def simulation_results(request: pytest.FixtureRequest): """ Fixture to set up the simulation for each unit type. """ - unit_type = request.param - process = create_lwe(unit_type) + unit_type, kwargs = request.param + process = create_lwe(unit_type, **kwargs) simulation_results = run_simulation(process, install_path) return simulation_results - -@pytest.mark.parametrize("process", unit_types, indirect=True) +@pytest.mark.parametrize("process", test_cases, indirect=True) class TestProcessWithLWE: def return_process_config(self, process: Process) -> dict: @@ -540,7 +565,7 @@ def test_sensitivity_config(self, process: Process): npt.assert_equal(sensitivity_config, expected_sensitivity_config) -@pytest.mark.parametrize("simulation_results", unit_types, indirect=True) +@pytest.mark.parametrize("simulation_results", test_cases, indirect=True) class TestResultsWithLWE: def test_trigger_simulation(self, simulation_results): """ @@ -621,9 +646,9 @@ def test_compare_solution_shape(self, simulation_results): unit.discretization.npar, process.component_system.n_comp ) - # for units with particle mobiles phase and particle discretization + # for units with particle mobiles phase and without particle discretization else: - # assert solution particle has shape (t, n_col, n_par, n_comp) + # assert solution particle has shape (t, n_col, n_comp) assert simulation_results.solution[unit.name].particle.solution_shape == ( int(process.cycle_time+1), unit.discretization.ncol,