Skip to content

Commit

Permalink
Merge pull request #5 from artificial-life-lab/intergrate
Browse files Browse the repository at this point in the history
harmonization of simulation solver method imports
  • Loading branch information
pranjaldhole authored Nov 13, 2024
2 parents c59fd0e + d021afd commit bb9de21
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 66 deletions.
45 changes: 10 additions & 35 deletions causal_inference/base/lv_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,12 @@
import logging
import datetime

import h5py
import matplotlib.pyplot as plt

from causal_inference.config import RESULTS_DIR
from causal_inference.utils.log_config import log_LV_params
from causal_inference.base.ode_solver import ODE_solver
from causal_inference.base.runge_kutta_solver import RungeKuttaSolver

def _save_population(prey_list, predator_list):
filename = os.path.join(RESULTS_DIR, 'populations.h5')
hf = h5py.File(filename, 'w')
hf.create_dataset('prey_pop', data=prey_list)
hf.create_dataset('pred_pop', data=predator_list)
hf.close()

def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'):
fig = plt.figure(figsize=(15, 5))
ax = fig.add_subplot(2, 1, 1)
PreyLine, = plt.plot(prey_list , color='g')
PredatorsLine, = plt.plot(predator_list, color='r')
ax.set_xscale('log')

plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
plt.ylabel('Population')
plt.xlabel('Time')
if save:
plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"),
format='svg', transparent=False, bbox_inches='tight')
else:
plt.show()
plt.close()
from causal_inference.utils.writer import _save_population
from causal_inference.utils.visualisations import plot_population_over_time

def get_solver(method):
'''
Expand All @@ -50,15 +25,15 @@ def get_solver(method):
raise AssertionError(f'{method} is not implemented!')
return solver

def main(method):
def main(method, results_dir):
'''
Main function that solves LV system.
'''
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
_save_population(prey_list, predator_list)
plot_population_over_time(prey_list, predator_list)
_save_population(prey_list, predator_list, solver.time_stamps, results_dir)
plot_population_over_time(prey_list, predator_list, solver.time_stamps, results_dir)

if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
Expand All @@ -68,12 +43,12 @@ def main(method):
choices=['RK4', 'ODE'], default='RK4')
ARGS = PARSER.parse_args()

RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))
results_dir = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))

if not os.path.exists(RESULTS_DIR):
os.makedirs(RESULTS_DIR)
if not os.path.exists(results_dir):
os.makedirs(results_dir)

LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file
LOG_FILE = os.path.join(results_dir, f"{ARGS.logfile}.txt") # write logg to this file
logging.basicConfig(
level=logging.INFO,
handlers=[
Expand All @@ -82,4 +57,4 @@ def main(method):
]
)
solver = ARGS.solver
main(solver)
main(solver, results_dir)
33 changes: 22 additions & 11 deletions causal_inference/base/lv_system.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
from causal_inference.config import LV_PARAMS

class LotkaVolterra():
'''
Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method.
Base Lotka-Volterra Class that defines a predator-prey system.
'''
def __init__(self):
self.A = LV_PARAMS['A']
self.B = LV_PARAMS['B']
self.C = LV_PARAMS['C']
self.D = LV_PARAMS['D']
self.time = LV_PARAMS['INITIAL_TIME']
self.step_size = LV_PARAMS['STEP_SIZE']
self.max_iterations = LV_PARAMS['MAX_ITERATIONS']
def __init__(self,
A=LV_PARAMS['A'], B=LV_PARAMS['B'], C=LV_PARAMS['C'], D=LV_PARAMS['D'],
prey_population=LV_PARAMS['INITIAL_PREY_POPULATION'],
pred_population=LV_PARAMS['INITIAL_PREDATOR_POPULATION'],
total_time=LV_PARAMS['TOTAL_TIME'], step_size=LV_PARAMS['STEP_SIZE'],
max_iter=LV_PARAMS['MAX_ITERATIONS']):
# Lotka-Volterra parameters
self.A = A
self.B = B
self.C = C
self.D = D

self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION']
self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION']
self.prey_population = prey_population # Initial prey population
self.predator_population = pred_population # Initial predator population

self.init_time = 0 # initial time
self.total_time = total_time # total time in units
self.step_size = step_size # increment for each time step
self.max_iterations = max_iter # tolerance parameter

self.time_stamps = np.arange(self.init_time, self.total_time, self.step_size)
31 changes: 19 additions & 12 deletions causal_inference/base/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from scipy import integrate

from causal_inference.config import LV_PARAMS
from causal_inference.base.lv_system import LotkaVolterra

class ODE_solver(LotkaVolterra):
Expand All @@ -14,20 +13,28 @@ class ODE_solver(LotkaVolterra):
'''
def __init__(self):
super().__init__()
logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver')
logging.info('Simulating Lotka-Volterra predator-prey dynamics with odeint solver')

@staticmethod
def LV_derivative(X, t, alpha, beta, delta, gamma):
x, y = X
dotx = x * (alpha - beta * y)
doty = y * (-delta + gamma * x)
def LV_derivative(t, Z, A, B, C, D):
'''
Returns the rate of change of predator and prey population
'''
x, y = Z
dotx = x * (A - B * y)
doty = y * (-C + D * x)
return np.array([dotx, doty])

def _solve(self):
logging.info('Computing population over time...')
t = np.arange(0.,self.max_iterations, self.step_size)
X0 = [self.prey_population, self.predator_population]
res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D))
prey_list, predator_list = res.T
'''
ODE solver that returns the predator and prey populations at each time step in time series.
'''
logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')

INIT_POP = [self.prey_population, self.predator_population]
sol = integrate.solve_ivp(self.LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)
prey_list, predator_list = sol.sol(self.time_stamps)

logging.info('done!')
return prey_list, predator_list

return prey_list, predator_list
16 changes: 9 additions & 7 deletions causal_inference/base/runge_kutta_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import logging
from math import ceil
from causal_inference.base.lv_system import LotkaVolterra

class RungeKuttaSolver(LotkaVolterra):
Expand All @@ -11,7 +10,6 @@ class RungeKuttaSolver(LotkaVolterra):
def __init__(self):
super().__init__()
logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method')
self.time_stamp = [self.time]
self.prey_list = [self.prey_population]
self.predator_list = [self.predator_population]

Expand All @@ -22,8 +20,6 @@ def compute_predator_rate(self, current_prey, current_predators):
return - self.C * current_predators + self.D * current_prey * current_predators

def runge_kutta_update(self, current_prey, current_predators):
self.time = self.time + self.step_size
self.time_stamp.append(self.time)

k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators)
k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators)
Expand All @@ -46,11 +42,17 @@ def runge_kutta_update(self, current_prey, current_predators):
return new_prey_population, new_predator_population

def _solve(self):
'''
Runge-Kutta solver that returns the predator and prey populations at each time step in time series.
'''
#initial population
current_prey, current_predators = self.prey_population, self.predator_population
logging.info('Computing population over time...')
for gen_idx in range(ceil(self.max_iterations/self.step_size)):

logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')

for step_idx in self.time_stamps[1:]:
current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators)
msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
msg= f'Step: {step_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
logging.info(msg)
print('Done!')
return self.prey_list, self.predator_list
2 changes: 1 addition & 1 deletion causal_inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
'C' : 3.0,
'D' : 5.0,
'STEP_SIZE' : 0.01,
'INITIAL_TIME' : 0,
'TOTAL_TIME' : 20,
'INITIAL_PREY_POPULATION' : 60,
'INITIAL_PREDATOR_POPULATION' : 25,
'MAX_ITERATIONS' : 200
Expand Down
21 changes: 21 additions & 0 deletions causal_inference/utils/visualisations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from os.path import join
import matplotlib.pyplot as plt

def plot_population_over_time(prey_list, predator_list, time_stamps, results_dir, save=True, filename='predator_prey'):
fig = plt.figure(figsize=(15, 5))
ax = fig.add_subplot(2, 1, 1)
PreyLine, = plt.plot(time_stamps, prey_list, color='g')
PredatorsLine, = plt.plot(time_stamps, predator_list, color='r')
ax.set_xscale('log')

plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
plt.ylabel('Population')
plt.xlabel('Time')
if save:
plt.savefig(join(results_dir, f"{filename}.svg"),
format='svg', transparent=False, bbox_inches='tight')
else:
plt.show()
plt.close()
12 changes: 12 additions & 0 deletions causal_inference/utils/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from os.path import join
import h5py

def _save_population(prey_list, predator_list, time_stamps, results_dir):
filename = join(results_dir, 'populations.h5')
hf = h5py.File(filename, 'w')
hf.create_dataset('time_stamp', data=time_stamps)
hf.create_dataset('prey_pop', data=prey_list)
hf.create_dataset('pred_pop', data=predator_list)
hf.close()

0 comments on commit bb9de21

Please sign in to comment.