From a5895dc5a6cfd3d1ab1a18b0c7a3f29da5fdf5ec Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 22 Jan 2024 12:07:14 +0100 Subject: [PATCH] Set default sample size for linear models in noise-free and serial case to 2 * n + 2 (#14) --- src/tranquilo/options.py | 10 +++++++-- src/tranquilo/process_arguments.py | 8 +++++-- tests/test_process_arguments.py | 35 ++++++++++++++++++++++++++---- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/tranquilo/options.py b/src/tranquilo/options.py index 17b9478..c83eced 100644 --- a/src/tranquilo/options.py +++ b/src/tranquilo/options.py @@ -39,11 +39,17 @@ def get_default_acceptance_decider(noisy): return "noisy" if noisy else "classic" -def get_default_sample_size(model_type, x): +def get_default_sample_size(model_type, x, noisy, batch_size): if model_type == "quadratic": out = 2 * len(x) + 1 else: - out = len(x) + 1 + # Use one point more for the standard least-squares case. Benchmarks have not + # shown an improved performance for the noisy or parallel case with one + # additional point. + if noisy or batch_size > 1: + out = len(x) + 1 + else: + out = len(x) + 2 return out diff --git a/src/tranquilo/process_arguments.py b/src/tranquilo/process_arguments.py index 4bb9821..579fd65 100644 --- a/src/tranquilo/process_arguments.py +++ b/src/tranquilo/process_arguments.py @@ -165,6 +165,8 @@ def process_arguments( sample_size=sample_size, model_type=model_type, x=x, + noisy=noisy, + batch_size=batch_size, ) model_fitter = _process_model_fitter( model_fitter, model_type=model_type, sample_size=target_sample_size, x=x @@ -285,9 +287,11 @@ def _process_sample_filter(sample_filter, batch_size): return out -def _process_sample_size(sample_size, model_type, x): +def _process_sample_size(sample_size, model_type, x, noisy, batch_size): if sample_size is None: - out = get_default_sample_size(model_type=model_type, x=x) + out = get_default_sample_size( + model_type=model_type, x=x, noisy=noisy, batch_size=batch_size + ) elif callable(sample_size): out = sample_size(x=x, model_type=model_type) else: diff --git a/tests/test_process_arguments.py b/tests/test_process_arguments.py index 4e4af10..7841593 100644 --- a/tests/test_process_arguments.py +++ b/tests/test_process_arguments.py @@ -41,15 +41,42 @@ def test_process_batch_size_invalid(): def test_process_sample_size(): x = np.arange(3) - assert _process_sample_size(sample_size=None, model_type="linear", x=x) == 4 - assert _process_sample_size(sample_size=None, model_type="quadratic", x=x) == 7 - assert _process_sample_size(10, None, None) == 10 + assert ( + _process_sample_size( + sample_size=None, model_type="linear", x=x, noisy=True, batch_size=1 + ) + == 4 + ) + assert ( + _process_sample_size( + sample_size=None, model_type="linear", x=x, noisy=False, batch_size=2 + ) + == 4 + ) + assert ( + _process_sample_size( + sample_size=None, model_type="linear", x=x, noisy=False, batch_size=1 + ) + == 5 + ) + assert ( + _process_sample_size( + sample_size=None, model_type="quadratic", x=x, noisy=False, batch_size=1 + ) + == 7 + ) + assert _process_sample_size(10, None, None, False, 1) == 10 def test_process_sample_size_callable(): x = np.arange(3) sample_size = lambda x, model_type: len(x) ** 2 - assert _process_sample_size(sample_size=sample_size, model_type="linear", x=x) == 9 + assert ( + _process_sample_size( + sample_size=sample_size, model_type="linear", x=x, noisy=False, batch_size=1 + ) + == 9 + ) def test_process_model_type():