Skip to content

Commit

Permalink
Merge pull request #8 from artificial-life-lab/inference
Browse files Browse the repository at this point in the history
lokta-volterra inference module
  • Loading branch information
pranjaldhole authored Nov 16, 2024
2 parents bb9de21 + 8a3bfe8 commit c20ee61
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 58 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ python causal_inference/base/lv_simulator.py
This will take all the default arguments and configuration to run a simulation instance of lotka-volterra population dynamics.

- The simulation statistics will be saved in the `repo/results` directory by default.

## Simulation and inference

- The simulation and inference methods are separately implemented in `repo/causal_inference/base/lotka_volterra/lv_system.py`.
- Currently, this inference method is experimental and may not always converge to correct optimal parameters.
- More work is needed to find a good approximation schema to initiate the parameters of the LV-system.
Empty file.
129 changes: 129 additions & 0 deletions causal_inference/base/lotka_volterra/lv_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
from scipy import integrate
from scipy.optimize import minimize

from causal_inference.config import LV_PARAMS

class LotkaVolterra():
'''
Base Lotka-Volterra Class that defines a predator-prey system.
'''
def __init__(self,
A=LV_PARAMS['A'], B=LV_PARAMS['B'], C=LV_PARAMS['C'], D=LV_PARAMS['D'],
prey_population=LV_PARAMS['INITIAL_PREY_POPULATION'],
pred_population=LV_PARAMS['INITIAL_PREDATOR_POPULATION'],
total_time=LV_PARAMS['TOTAL_TIME'], step_size=LV_PARAMS['STEP_SIZE'],
max_iter=LV_PARAMS['MAX_ITERATIONS']):
# Lotka-Volterra parameters
self.A = A
self.B = B
self.C = C
self.D = D

self.prey_population = prey_population # Initial prey population
self.predator_population = pred_population # Initial predator population

self.init_time = 0 # initial time
self.total_time = total_time # total time in units
self.step_size = step_size # increment for each time step
self.max_iterations = max_iter # tolerance parameter

self.time_stamps = np.arange(self.init_time, self.total_time, self.step_size)

@staticmethod
def LV_derivative(t, Z, A, B, C, D):
'''
Returns the rate of change of predator and prey population
Simulates Lotka-Volterra dynamics
Parameters:
t (list): [t0, tf] initial and final time points for simulation (Not used but necessary for integration step)
Z (tuple): (x, y) state of the system
A: prey growth rate (model parameter)
B: predation rate (model parameter)
C: predator death rate (model parameter)
D: predator growth rate from eating prey (model parameter)
Returns:
array: rate of change of prey and predator population
'''
x, y = Z
dotx = x * (A - B * y)
doty = y * (-C + D * x)
return np.array([dotx, doty])

def simulate_lotka_volterra(params, t, initial_conditions):
"""
Simulates Lotka-Volterra dynamics
Parameters:
params (tuple): (A, B, C, D) model parameters
A: prey growth rate
B: predation rate
C: predator death rate
D: predator growth rate from eating prey
t (array): time points for simulation
initial_conditions (tuple): (prey0, predator0) initial populations
Returns:
population: (n, 2) array where each row is [prey_pop, predator_pop]
"""
A, B, C, D = params

solution = integrate.solve_ivp(LV_derivative, [t[0], t[-1]], initial_conditions,
args=(A, B, C, D), dense_output=True)

population = solution.sol(t)
return population

def fit_lotka_volterra(time_points, observed_data, initial_guess):
"""
Fits Lotka-Volterra parameters to observed population data
Parameters:
time_points (array): time points of observations
observed_data (array): observed population data [prey, predator]
initial_guess (tuple): initial parameter guess (A, B, C, D)
Returns:
tuple: Fitted parameters (A, B, C, D) after optimization
"""
def objective_function(params):
# Simulate with current parameters
simulated = simulate_lotka_volterra(params, time_points,
observed_data[:, 0])
# Calculate mean squared error
mse = np.mean((simulated - observed_data) ** 2)
return mse

# Parameter bounds (all parameters must be positive)
bounds = [(0, None) for _ in range(4)]

# Optimize parameters
result = minimize(objective_function, initial_guess,
bounds=bounds, method='L-BFGS-B')

return result.x

# Example usage:
if __name__ == "__main__":
# Generate synthetic data
m = LotkaVolterra()
true_params = (m.A, m.B, m.C, m.D)
initial_conditions = (m.prey_population, m.predator_population) # (prey0, predator0)

# Generate synthetic data with some noise
data = simulate_lotka_volterra(true_params, m.time_stamps, initial_conditions)
noisy_data = data + np.random.normal(0, 0.1, data.shape)

# Fit parameters
initial_guess = (0.5, 0.05, 0.05, 0.05) #FIXME Use more educated schema for initial guess

fitted_params = fit_lotka_volterra(m.time_stamps, noisy_data, initial_guess)

print("True parameters:", true_params)
print("Fitted parameters:", fitted_params)
14 changes: 9 additions & 5 deletions causal_inference/base/lv_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@ def get_solver(method):
raise AssertionError(f'{method} is not implemented!')
return solver

def simulate_lotka_volterra(method):
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
return prey_list, predator_list, solver.time_stamps

def main(method, results_dir):
'''
Main function that solves LV system.
'''
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
_save_population(prey_list, predator_list, solver.time_stamps, results_dir)
plot_population_over_time(prey_list, predator_list, solver.time_stamps, results_dir)
prey_list, predator_list, t = simulate_lotka_volterra(method)
_save_population(prey_list, predator_list, t, results_dir)
plot_population_over_time(prey_list, predator_list, t, results_dir)

if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
Expand Down
31 changes: 0 additions & 31 deletions causal_inference/base/lv_system.py

This file was deleted.

15 changes: 2 additions & 13 deletions causal_inference/base/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# -*- coding: utf-8 -*-

import logging
import numpy as np
from scipy import integrate

from causal_inference.base.lv_system import LotkaVolterra
from causal_inference.base.lotka_volterra.lv_system import LotkaVolterra, LV_derivative

class ODE_solver(LotkaVolterra):
'''
Expand All @@ -15,24 +14,14 @@ def __init__(self):
super().__init__()
logging.info('Simulating Lotka-Volterra predator-prey dynamics with odeint solver')

@staticmethod
def LV_derivative(t, Z, A, B, C, D):
'''
Returns the rate of change of predator and prey population
'''
x, y = Z
dotx = x * (A - B * y)
doty = y * (-C + D * x)
return np.array([dotx, doty])

def _solve(self):
'''
ODE solver that returns the predator and prey populations at each time step in time series.
'''
logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')

INIT_POP = [self.prey_population, self.predator_population]
sol = integrate.solve_ivp(self.LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)
sol = integrate.solve_ivp(LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)
prey_list, predator_list = sol.sol(self.time_stamps)

logging.info('done!')
Expand Down
2 changes: 1 addition & 1 deletion causal_inference/base/runge_kutta_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import logging
from causal_inference.base.lv_system import LotkaVolterra
from causal_inference.base.lotka_volterra.lv_system import LotkaVolterra

class RungeKuttaSolver(LotkaVolterra):
'''
Expand Down
14 changes: 7 additions & 7 deletions causal_inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

# Lotka-Volterra Parameters
LV_PARAMS = {
'A' : 10.0,
'B' : 7.0,
'C' : 3.0,
'D' : 5.0,
'A' : 1.0,
'B' : 0.1,
'C' : 0.3,
'D' : 0.4,
'STEP_SIZE' : 0.01,
'TOTAL_TIME' : 20,
'INITIAL_PREY_POPULATION' : 60,
'TOTAL_TIME' : 10,
'INITIAL_PREY_POPULATION' : 40,
'INITIAL_PREDATOR_POPULATION' : 25,
'MAX_ITERATIONS' : 200
'MAX_ITERATIONS' : 100
}

# PATHS
Expand Down
2 changes: 1 addition & 1 deletion causal_inference/tests/test_lv_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# flake8: noqa
from causal_inference.base.lv_system import LotkaVolterra
from causal_inference.base.lotka_volterra.lv_system import LotkaVolterra
from causal_inference.config import LV_PARAMS

def test_lotka_volterra():
Expand Down

0 comments on commit c20ee61

Please sign in to comment.