From ca887221465c3bd6bb752d58d71d539866fd6d66 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 17 May 2023 18:31:51 +0200 Subject: [PATCH] Warn if ASA is used in combination with events. --- python/sdist/amici/swig_wrappers.py | 30 ++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/swig_wrappers.py b/python/sdist/amici/swig_wrappers.py index 7146e818ee..2f694938a4 100644 --- a/python/sdist/amici/swig_wrappers.py +++ b/python/sdist/amici/swig_wrappers.py @@ -1,6 +1,8 @@ """Convenience wrappers for the swig interface""" import logging import sys + +import warnings from contextlib import contextmanager, suppress from typing import Any, Dict, List, Optional, Sequence, Union @@ -102,6 +104,17 @@ def runAmiciSimulation( :returns: ReturnData object with simulation results """ + if ( + model.ne > 0 + and solver.getSensitivityMethod() + == amici_swig.SensitivityMethod.adjoint + and solver.getSensitivityOrder() == amici_swig.SensitivityOrder.first + ): + warnings.warn( + "Adjoint sensitivity analysis for models with events with parameter-dependent trigger functions has not been thoroughly tested. " + "Sensitivities might be wrong. Tracked at https://github.com/AMICI-dev/AMICI/issues/18." + ) + with _capture_cstdout(): rdata = amici_swig.runAmiciSimulation( _get_ptr(solver), _get_ptr(edata), _get_ptr(model) @@ -152,10 +165,25 @@ def runAmiciSimulations( :returns: list of simulation results """ + if ( + model.ne > 0 + and solver.getSensitivityMethod() + == amici_swig.SensitivityMethod.adjoint + and solver.getSensitivityOrder() == amici_swig.SensitivityOrder.first + ): + warnings.warn( + "Adjoint sensitivity analysis for models with events with parameter-dependent trigger functions has not been thoroughly tested. " + "Sensitivities might be wrong. Tracked at https://github.com/AMICI-dev/AMICI/issues/18." + ) + with _capture_cstdout(): edata_ptr_vector = amici_swig.ExpDataPtrVector(edata_list) rdata_ptr_list = amici_swig.runAmiciSimulations( - _get_ptr(solver), edata_ptr_vector, _get_ptr(model), failfast, num_threads + _get_ptr(solver), + edata_ptr_vector, + _get_ptr(model), + failfast, + num_threads, ) for rdata in rdata_ptr_list: _log_simulation(rdata)