From 8ffea7f11bc2586dfac918030fdc8aede8b361b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Fri, 16 Aug 2024 20:15:23 -0400 Subject: [PATCH] BF: Set the superclass `fit_method` param value to the one provided Set the superclass initialization method `fit_method` parameter value to the one provided to the child's initialization method. Add a test to check that the fit method matches the model instance fit method. --- dipy/reconst/dki_micro.py | 4 +++- dipy/reconst/tests/test_dki_micro.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/dipy/reconst/dki_micro.py b/dipy/reconst/dki_micro.py index d66b8bad29..08b4312772 100644 --- a/dipy/reconst/dki_micro.py +++ b/dipy/reconst/dki_micro.py @@ -356,7 +356,9 @@ def __init__(self, gtab, fit_method="WLS", *args, **kwargs): ---------- .. footbibliography:: """ - DiffusionKurtosisModel.__init__(self, gtab, fit_method="WLS", *args, **kwargs) + DiffusionKurtosisModel.__init__( + self, gtab, fit_method=fit_method, *args, **kwargs + ) def fit(self, data, mask=None, sphere="repulsion100", gtol=1e-2, awf_only=False): """Fit method of the Diffusion Kurtosis Microstructural Model diff --git a/dipy/reconst/tests/test_dki_micro.py b/dipy/reconst/tests/test_dki_micro.py index 33e185e65a..064d490988 100644 --- a/dipy/reconst/tests/test_dki_micro.py +++ b/dipy/reconst/tests/test_dki_micro.py @@ -10,13 +10,19 @@ assert_array_almost_equal, assert_raises, ) +import pytest from dipy.core.gradients import gradient_table from dipy.data import default_sphere, get_fnames, get_sphere from dipy.io.gradients import read_bvals_bvecs +from dipy.reconst.dki import common_fit_methods import dipy.reconst.dki_micro as dki_micro from dipy.reconst.dti import eig_from_lo_tri from dipy.sims.voxel import _check_directions, multi_tensor, multi_tensor_dki +from dipy.utils.optpkg import optional_package + +cvxpy, have_cvxpy, _ = optional_package("cvxpy", min_version="1.4.1") +needs_cvxpy = pytest.mark.skipif(not have_cvxpy, reason="Requires CVXPY") gtab_2s, DWIsim, DWIsim_all_taylor = None, None, None FIE, RDI, ADI, ADE, Tor, RDE = None, None, None, None, None, None @@ -91,6 +97,18 @@ def teardown_module(): FIE, RDI, ADI, ADE, Tor, RDE = None, None, None, None, None, None +@needs_cvxpy +def test_fit_selection(): + for model in [ + dki_micro.KurtosisMicrostructureModel, + dki_micro.DiffusionKurtosisModel, + ]: + for name, method in common_fit_methods.items(): + model_instance = model(gtab_2s, fit_method=name) + + assert model_instance.fit_method == method + + def test_single_fiber_model(): # single fiber simulate (which is the assumption of our model) fie = 0.49