diff --git a/causal_inference/base/lv_simulator.py b/causal_inference/base/lv_simulator.py index d9bce29..d333f02 100644 --- a/causal_inference/base/lv_simulator.py +++ b/causal_inference/base/lv_simulator.py @@ -6,37 +6,12 @@ import logging import datetime -import h5py -import matplotlib.pyplot as plt - from causal_inference.config import RESULTS_DIR from causal_inference.utils.log_config import log_LV_params from causal_inference.base.ode_solver import ODE_solver from causal_inference.base.runge_kutta_solver import RungeKuttaSolver - -def _save_population(prey_list, predator_list): - filename = os.path.join(RESULTS_DIR, 'populations.h5') - hf = h5py.File(filename, 'w') - hf.create_dataset('prey_pop', data=prey_list) - hf.create_dataset('pred_pop', data=predator_list) - hf.close() - -def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'): - fig = plt.figure(figsize=(15, 5)) - ax = fig.add_subplot(2, 1, 1) - PreyLine, = plt.plot(prey_list , color='g') - PredatorsLine, = plt.plot(predator_list, color='r') - ax.set_xscale('log') - - plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators']) - plt.ylabel('Population') - plt.xlabel('Time') - if save: - plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"), - format='svg', transparent=False, bbox_inches='tight') - else: - plt.show() - plt.close() +from causal_inference.utils.writer import _save_population +from causal_inference.utils.visualisations import plot_population_over_time def get_solver(method): ''' @@ -50,15 +25,15 @@ def get_solver(method): raise AssertionError(f'{method} is not implemented!') return solver -def main(method): +def main(method, results_dir): ''' Main function that solves LV system. ''' log_LV_params() solver = get_solver(method) prey_list, predator_list = solver._solve() - _save_population(prey_list, predator_list) - plot_population_over_time(prey_list, predator_list) + _save_population(prey_list, predator_list, solver.time_stamps, results_dir) + plot_population_over_time(prey_list, predator_list, solver.time_stamps, results_dir) if __name__ == '__main__': PARSER = argparse.ArgumentParser() @@ -68,12 +43,12 @@ def main(method): choices=['RK4', 'ODE'], default='RK4') ARGS = PARSER.parse_args() - RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir))) + results_dir = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir))) - if not os.path.exists(RESULTS_DIR): - os.makedirs(RESULTS_DIR) + if not os.path.exists(results_dir): + os.makedirs(results_dir) - LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file + LOG_FILE = os.path.join(results_dir, f"{ARGS.logfile}.txt") # write logg to this file logging.basicConfig( level=logging.INFO, handlers=[ @@ -82,4 +57,4 @@ def main(method): ] ) solver = ARGS.solver - main(solver) + main(solver, results_dir) diff --git a/causal_inference/base/lv_system.py b/causal_inference/base/lv_system.py index b580e78..af3291c 100644 --- a/causal_inference/base/lv_system.py +++ b/causal_inference/base/lv_system.py @@ -1,20 +1,31 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +import numpy as np from causal_inference.config import LV_PARAMS class LotkaVolterra(): ''' - Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method. + Base Lotka-Volterra Class that defines a predator-prey system. ''' - def __init__(self): - self.A = LV_PARAMS['A'] - self.B = LV_PARAMS['B'] - self.C = LV_PARAMS['C'] - self.D = LV_PARAMS['D'] - self.time = LV_PARAMS['INITIAL_TIME'] - self.step_size = LV_PARAMS['STEP_SIZE'] - self.max_iterations = LV_PARAMS['MAX_ITERATIONS'] + def __init__(self, + A=LV_PARAMS['A'], B=LV_PARAMS['B'], C=LV_PARAMS['C'], D=LV_PARAMS['D'], + prey_population=LV_PARAMS['INITIAL_PREY_POPULATION'], + pred_population=LV_PARAMS['INITIAL_PREDATOR_POPULATION'], + total_time=LV_PARAMS['TOTAL_TIME'], step_size=LV_PARAMS['STEP_SIZE'], + max_iter=LV_PARAMS['MAX_ITERATIONS']): + # Lotka-Volterra parameters + self.A = A + self.B = B + self.C = C + self.D = D - self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION'] - self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION'] \ No newline at end of file + self.prey_population = prey_population # Initial prey population + self.predator_population = pred_population # Initial predator population + + self.init_time = 0 # initial time + self.total_time = total_time # total time in units + self.step_size = step_size # increment for each time step + self.max_iterations = max_iter # tolerance parameter + + self.time_stamps = np.arange(self.init_time, self.total_time, self.step_size) diff --git a/causal_inference/base/ode_solver.py b/causal_inference/base/ode_solver.py index 7ddb0e1..7f47151 100644 --- a/causal_inference/base/ode_solver.py +++ b/causal_inference/base/ode_solver.py @@ -5,7 +5,6 @@ import numpy as np from scipy import integrate -from causal_inference.config import LV_PARAMS from causal_inference.base.lv_system import LotkaVolterra class ODE_solver(LotkaVolterra): @@ -14,20 +13,28 @@ class ODE_solver(LotkaVolterra): ''' def __init__(self): super().__init__() - logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver') + logging.info('Simulating Lotka-Volterra predator-prey dynamics with odeint solver') @staticmethod - def LV_derivative(X, t, alpha, beta, delta, gamma): - x, y = X - dotx = x * (alpha - beta * y) - doty = y * (-delta + gamma * x) + def LV_derivative(t, Z, A, B, C, D): + ''' + Returns the rate of change of predator and prey population + ''' + x, y = Z + dotx = x * (A - B * y) + doty = y * (-C + D * x) return np.array([dotx, doty]) def _solve(self): - logging.info('Computing population over time...') - t = np.arange(0.,self.max_iterations, self.step_size) - X0 = [self.prey_population, self.predator_population] - res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D)) - prey_list, predator_list = res.T + ''' + ODE solver that returns the predator and prey populations at each time step in time series. + ''' + logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...') + + INIT_POP = [self.prey_population, self.predator_population] + sol = integrate.solve_ivp(self.LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True) + prey_list, predator_list = sol.sol(self.time_stamps) + logging.info('done!') - return prey_list, predator_list \ No newline at end of file + + return prey_list, predator_list diff --git a/causal_inference/base/runge_kutta_solver.py b/causal_inference/base/runge_kutta_solver.py index 8a54c89..440d733 100644 --- a/causal_inference/base/runge_kutta_solver.py +++ b/causal_inference/base/runge_kutta_solver.py @@ -1,7 +1,6 @@ #!/usr/bin/python # -*- coding: utf-8 -*- import logging -from math import ceil from causal_inference.base.lv_system import LotkaVolterra class RungeKuttaSolver(LotkaVolterra): @@ -11,7 +10,6 @@ class RungeKuttaSolver(LotkaVolterra): def __init__(self): super().__init__() logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method') - self.time_stamp = [self.time] self.prey_list = [self.prey_population] self.predator_list = [self.predator_population] @@ -22,8 +20,6 @@ def compute_predator_rate(self, current_prey, current_predators): return - self.C * current_predators + self.D * current_prey * current_predators def runge_kutta_update(self, current_prey, current_predators): - self.time = self.time + self.step_size - self.time_stamp.append(self.time) k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators) k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators) @@ -46,11 +42,17 @@ def runge_kutta_update(self, current_prey, current_predators): return new_prey_population, new_predator_population def _solve(self): + ''' + Runge-Kutta solver that returns the predator and prey populations at each time step in time series. + ''' + #initial population current_prey, current_predators = self.prey_population, self.predator_population - logging.info('Computing population over time...') - for gen_idx in range(ceil(self.max_iterations/self.step_size)): + + logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...') + + for step_idx in self.time_stamps[1:]: current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators) - msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}' + msg= f'Step: {step_idx} | Prey population: {current_prey} | Predator population: {current_predators}' logging.info(msg) print('Done!') return self.prey_list, self.predator_list \ No newline at end of file diff --git a/causal_inference/config.py b/causal_inference/config.py index 4acd3de..bf940b2 100644 --- a/causal_inference/config.py +++ b/causal_inference/config.py @@ -12,7 +12,7 @@ 'C' : 3.0, 'D' : 5.0, 'STEP_SIZE' : 0.01, -'INITIAL_TIME' : 0, +'TOTAL_TIME' : 20, 'INITIAL_PREY_POPULATION' : 60, 'INITIAL_PREDATOR_POPULATION' : 25, 'MAX_ITERATIONS' : 200 diff --git a/causal_inference/utils/visualisations.py b/causal_inference/utils/visualisations.py new file mode 100644 index 0000000..220ba98 --- /dev/null +++ b/causal_inference/utils/visualisations.py @@ -0,0 +1,21 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +from os.path import join +import matplotlib.pyplot as plt + +def plot_population_over_time(prey_list, predator_list, time_stamps, results_dir, save=True, filename='predator_prey'): + fig = plt.figure(figsize=(15, 5)) + ax = fig.add_subplot(2, 1, 1) + PreyLine, = plt.plot(time_stamps, prey_list, color='g') + PredatorsLine, = plt.plot(time_stamps, predator_list, color='r') + ax.set_xscale('log') + + plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators']) + plt.ylabel('Population') + plt.xlabel('Time') + if save: + plt.savefig(join(results_dir, f"{filename}.svg"), + format='svg', transparent=False, bbox_inches='tight') + else: + plt.show() + plt.close() \ No newline at end of file diff --git a/causal_inference/utils/writer.py b/causal_inference/utils/writer.py new file mode 100644 index 0000000..2d5da4e --- /dev/null +++ b/causal_inference/utils/writer.py @@ -0,0 +1,12 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +from os.path import join +import h5py + +def _save_population(prey_list, predator_list, time_stamps, results_dir): + filename = join(results_dir, 'populations.h5') + hf = h5py.File(filename, 'w') + hf.create_dataset('time_stamp', data=time_stamps) + hf.create_dataset('prey_pop', data=prey_list) + hf.create_dataset('pred_pop', data=predator_list) + hf.close()