Skip to content

Commit

Permalink
Merge pull request #20 from fenfisdi/issue/11/update_unit_tests
Browse files Browse the repository at this point in the history
Issue/11/update unit tests
  • Loading branch information
jearistiz authored May 22, 2021
2 parents f2f3e9d + 4f93dc6 commit f6be724
Show file tree
Hide file tree
Showing 18 changed files with 608 additions and 630 deletions.
5 changes: 3 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ name = "pypi"
[packages]
scipy = "*"
numpy = "*"
matplotlib = "*"
pillow = ">=8.1.1"

[dev-packages]
pytest = "*"
Expand All @@ -16,6 +14,9 @@ flake8 = "*"
pytest-cov = "*"
sphinx = "*"
sphinx-rtd-theme = "*"
matplotlib = "*"
pillow = ">=8.1.1"
dinjo = {editable = true, path = "."}

[scripts]
tests = "python -m pytest tests"
Expand Down
366 changes: 164 additions & 202 deletions Pipfile.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions dinjo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Parameter(Variable):
"""
def __init__(
self, name: str, representation: str, initial_value: float = 0,
*args, bounds: Optional[List[float]] = None, **kwargs
bounds: Optional[List[float]] = None, *args, **kwargs
) -> None:
super().__init__(name, representation, initial_value)
self.bounds = bounds if bounds else [initial_value, initial_value]
Expand All @@ -69,18 +69,18 @@ def bounds(self, bounds_input):
)

if not type_check:
raise AttributeError(attr_err_message)
raise ValueError(attr_err_message)

order_check: bool = bounds_input[0] <= bounds_input[1]

if not order_check:
raise AttributeError(attr_err_message)
raise ValueError(attr_err_message)

if not (
bounds_input[0] <= self.initial_value
and bounds_input[1] >= self.initial_value
):
raise AttributeError(init_val_not_in_bounds_range)
raise ValueError(init_val_not_in_bounds_range)

self._bounds = bounds_input

Expand Down Expand Up @@ -151,7 +151,7 @@ def build_model(self, t, y, *args):
time at which the differential equation must be evaluated.
y : list[float]
state vector at which the differential must be evaluated.
\*args : any
*args : any
other parameters of the differential equation
Returns
Expand Down Expand Up @@ -240,7 +240,7 @@ def run_model(
)

parameters_permitted_types = (list, tuple, np.ndarray)
parameters_type_is_permitted = True
parameters_type_is_permitted = False

for permitted_type in parameters_permitted_types:
parameters_type_is_permitted += isinstance(parameters, permitted_type)
Expand Down
8 changes: 4 additions & 4 deletions dinjo/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
reference_state_variable
)
except ValueError:
raise AttributeError(
raise ValueError(
"self.reference_state_variable must be in model.state_variables"
)

Expand All @@ -63,13 +63,13 @@ def reference_t_values(self):
@reference_t_values.setter
def reference_t_values(self, reference_t_values_input: List[float]):
if len(self.reference_values) != len(reference_t_values_input):
raise AttributeError(
raise ValueError(
"self.reference_values and self.reference_t_values must have the same length"
)

for i, t in enumerate(reference_t_values_input[:-1]):
if not t < reference_t_values_input[i + 1]:
raise AttributeError(
raise ValueError(
"self.reference_t_values must be a list of floats in increasing order"
)

Expand All @@ -83,7 +83,7 @@ def reference_t_values(self, reference_t_values_input: List[float]):
)

if not t_span_condition:
raise AttributeError(
raise ValueError(
"self.model.t_span and self.reference_t_values initial and "
"final entries must coincide."
)
Expand Down
11 changes: 5 additions & 6 deletions examples/optimizer_oscillator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from datetime import datetime, timedelta
import os
import sys
from typing import Any, Dict, List, Union
import pickle
from datetime import datetime, timedelta


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Add project root directory to path
from dinjo import model, optimizer


this_file_dir = os.path.dirname(__file__)
project_root_dir = os.path.join(this_file_dir, '..')
sys.path.append(project_root_dir)

from dinjo import model, optimizer


class ModelOscillator(model.ModelIVP):
Expand Down
17 changes: 6 additions & 11 deletions examples/optimizer_seirv_model_colombia.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Optimize the seirv model for colombia
# This script may take some hours to execute
import sys
# This script may take several minutes to execute
import os
import pickle
from time import time
Expand All @@ -11,19 +10,15 @@
import pandas as pd
from pandas.errors import EmptyDataError

from cmodel_examples_utilities import int_to_str_date, setup_csv

# Add project root and this file's directories to path in order to find cmodel
# package
this_file_dir = os.path.dirname(__file__)
project_root_dir = os.path.join(this_file_dir, '..')
sys.path.append(this_file_dir)
sys.path.append(project_root_dir)

import dinjo.optimizer as optimizer
from seirv_model_colombia import (
seirv_state_variables_colombia, seirv_model_example, infected_reference_col
)
from cmodel_examples_utilities import int_to_str_date, setup_csv


this_file_dir = os.path.dirname(__file__)
project_root_dir = os.path.join(this_file_dir, '..')


def optimizer_seirv_model_colombia_example(
Expand Down
15 changes: 5 additions & 10 deletions examples/seirv_model_colombia.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import sys
import os
from time import time
from typing import Any, Dict, List, Union
from time import time

import matplotlib.pyplot as plt
import pandas as pd

# Add project root and this file's directories to path in order to find cmodel
# package
this_file_dir = os.path.dirname(__file__)
project_root_dir = os.path.join(this_file_dir, '..')
sys.path.append(this_file_dir)
sys.path.append(project_root_dir)

from dinjo import model, predefined
from dinjo import model
from dinjo.predefined.epidemiology import ModelSEIRV
# State variables and parameters are setup in col_vars_params.py module
from examples.col_vars_params import (
Expand All @@ -26,6 +18,9 @@
t_span_col = [0, 171]
t_steps_col = 172

this_file_dir = os.path.dirname(__file__)
project_root_dir = os.path.join(this_file_dir, '..')

infected_reference_col_path = os.path.join(
this_file_dir, '..', 'example_data', 'infected_reference_col.csv'
)
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta:__legacy__"
11 changes: 6 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# requirements. To emit only development requirements, pass "--dev-only".

-i https://pypi.org/simple
-e .
alabaster==0.7.12
attrs==21.2.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
babel==2.9.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
Expand All @@ -21,9 +22,9 @@ flake8==3.9.2
idna==2.10; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
imagesize==1.2.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
iniconfig==1.1.1
jinja2==2.11.3; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
jinja2==3.0.1; python_version >= '3.6'
kiwisolver==1.3.1; python_version >= '3.6'
markupsafe==1.1.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
markupsafe==2.0.1; python_version >= '3.6'
matplotlib==3.4.2
mccabe==0.6.1
numpy==1.20.3
Expand All @@ -45,12 +46,12 @@ scipy==1.6.3
six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
snowballstemmer==2.1.0
sphinx-rtd-theme==0.5.2
sphinx==4.0.1
sphinx==4.0.2
sphinxcontrib-applehelp==1.0.2; python_version >= '3.5'
sphinxcontrib-devhelp==1.0.2; python_version >= '3.5'
sphinxcontrib-htmlhelp==1.0.3; python_version >= '3.5'
sphinxcontrib-htmlhelp==2.0.0; python_version >= '3.6'
sphinxcontrib-jsmath==1.0.1; python_version >= '3.5'
sphinxcontrib-qthelp==1.0.3; python_version >= '3.5'
sphinxcontrib-serializinghtml==1.1.4; python_version >= '3.5'
sphinxcontrib-serializinghtml==1.1.5; python_version >= '3.5'
toml==0.10.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'
urllib3==1.26.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' and python_version < '4'
7 changes: 0 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,5 @@
#

-i https://pypi.org/simple
cycler==0.10.0
kiwisolver==1.3.1; python_version >= '3.6'
matplotlib==3.4.2
numpy==1.20.3
pillow==8.2.0
pyparsing==2.4.7; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'
python-dateutil==2.8.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
scipy==1.6.3
six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name='DINJO',
version='0.0.0',
version='0.0.dev1',
description='DINJO lets you find optimal values of initial value problems\' parameters',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
80 changes: 80 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, List
import pytest

from numpy import pi
from numpy.random import random

from dinjo.model import StateVariable, Parameter
from dinjo.optimizer import Optimizer
from dinjo.predefined.physics import ModelOscillator


@pytest.fixture(scope='session')
def ho_state_vars():
# Harmonic Oscillator Initial Value Problem
q = StateVariable(
name='position', representation='q', initial_value=1.0
)
p = StateVariable(
name='momentum', representation='p', initial_value=0.0
)
return [q, p]


@pytest.fixture(scope='session')
def ho_params():
# Define Paramters
omega = Parameter(
name='frequency', representation='w', initial_value=2 * pi, bounds=[4, 8]
)
return [omega]


@pytest.fixture(scope='session')
def t_span():
return [0, 1]


@pytest.fixture(scope='session')
def t_steps():
return 50


@pytest.fixture(scope='session')
def model_oscillator(ho_state_vars, ho_params, t_span, t_steps):
# Instantiate the IVP class with appropiate State Variables and Parameters
return ModelOscillator(
state_variables=ho_state_vars,
parameters=ho_params,
t_span=t_span,
t_steps=t_steps
)


@pytest.fixture(scope='session')
def oscillator_solution(model_oscillator: ModelOscillator):
return model_oscillator.run_model()


@pytest.fixture(scope='session')
def ho_mock_values(oscillator_solution: Any, t_steps: List[float]):
noise_factor = 0.3
return (
oscillator_solution.y[0]
+ (2 * random(t_steps) - 1) * noise_factor
)


@pytest.fixture(scope='session')
def ho_optimizer(
model_oscillator: ModelOscillator,
oscillator_solution: Any,
ho_mock_values: List[float],
ho_state_vars: List[StateVariable]
):
return Optimizer(
model_oscillator,
ho_state_vars[0],
ho_mock_values,
oscillator_solution.t
)
Empty file added tests/predefined/__init__.py
Empty file.
Loading

0 comments on commit f6be724

Please sign in to comment.