From b95efb13c5266dfe5b8939a9a6ea25ca4e23ca93 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Mon, 17 Aug 2020 14:10:09 +0200 Subject: [PATCH 01/13] first version of fitzhughnagumo --- tests/fitzhughnagumo.json | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/fitzhughnagumo.json diff --git a/tests/fitzhughnagumo.json b/tests/fitzhughnagumo.json new file mode 100644 index 00000000..d2498aee --- /dev/null +++ b/tests/fitzhughnagumo.json @@ -0,0 +1,30 @@ +{ + "parameters": { + "C_m": "1", + "g_Ca": "1.1", + "g_K": "2", + "g_L": ".5", + "E_Ca": "100", + "E_K": "-70", + "E_L": "-50", + "I_ext": "30" + }, + + "dynamics": [ + { + "expression": "V' = (I_ext - g_Ca * (.5 + .5*tanh((V + 1) / 15)) * (V - E_Ca) - g_K * W * (V - E_K) - g_L * (V - E_L)) / C_m * 1E3", + "initial_value": "-25" + }, + { + "expression": "W' = (.5 + .5 * tanh(V / 30) - W) / (5 / cosh(V / 60)) * 1E3", + "initial_value": ".15" + } + ], + + "options": { + "sim_time": "45E-3", + "max_step_size": ".25E-3", + "integration_accuracy_abs" : "1E-9", + "integration_accuracy_rel" : "1E-9" + } +} From 0da69955a4b5f024056e543737c829d7611f891f Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Tue, 1 Sep 2020 16:56:49 +0200 Subject: [PATCH 02/13] fitzhughnagumo model --- tests/test_fitzhughnagumo.py | 164 +++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 tests/test_fitzhughnagumo.py diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py new file mode 100644 index 00000000..1dc79f3d --- /dev/null +++ b/tests/test_fitzhughnagumo.py @@ -0,0 +1,164 @@ +# +# test_mixed_integrator_numeric.py +# +# This file is part of the NEST ODE toolbox. +# +# Copyright (C) 2017 The NEST Initiative +# +# The NEST ODE toolbox is free software: you can redistribute it +# and/or modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation, either version 2 of +# the License, or (at your option) any later version. +# +# The NEST ODE toolbox is distributed in the hope that it will be +# useful, but WITHOUT ANY WARRANTY; without even the implied warranty +# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . +# + + +import json +import os +import pytest +import unittest +import sympy +import numpy as np +import matplotlib.pyplot as plt +from scipy.signal import find_peaks +from sympy.parsing.sympy_parser import parse_expr +from sympy.solvers import solve +from sympy import Symbol +#np.seterr(under="warn") + +try: + import matplotlib as mpl + mpl.use('Agg') + import matplotlib.pyplot as plt + INTEGRATION_TEST_DEBUG_PLOTS = True +except: + INTEGRATION_TEST_DEBUG_PLOTS = False + + +import odetoolbox +from odetoolbox.mixed_integrator import MixedIntegrator + +from math import e +from sympy import exp, sympify +import sympy.parsing.sympy_parser + +import scipy +import scipy.special +import scipy.linalg + +try: + import pygsl.odeiv as odeiv + PYGSL_AVAILABLE = True +except ImportError as ie: + PYGSL_AVAILABLE = False + + +def open_json(fname): + absfname = os.path.join(os.path.abspath(os.path.dirname(__file__)), fname) + with open(absfname) as infile: + indict = json.load(infile) + return indict + + +class TestMixedIntegrationNumeric(unittest.TestCase): + + """ + Numerical validation of MixedIntegrator. Note that this test uses all-numeric (no analytic part) integration to test for time grid aliasing effects of spike times. + + Simulate a conductance-based integrate-and-fire neuron which is receiving spikes. Check for a match of the final system state with a numerical reference value that was validated by hand. + """ + + def initial__values(self, curr): + I_ext = Symbol("I_ext") + V = Symbol("V") + expr = solve((sympy.parsing.sympy_parser.parse_expr("8*V**3 + 6*V + 21 - 24*I_ext")), V) # expr gives a list of three roots for V: first two are complex, third one is real + final_val_V = (expr[2].subs(I_ext,curr)).evalf() + final_val_W = ((10*final_val_V) + 7)/8 + return float(final_val_V), float(final_val_W) + + @pytest.mark.skipif(not PYGSL_AVAILABLE, reason="Need GSL integrator to perform test") + def test_mixed_integrator_numeric(self): + debug = True + + h = 1 # [ms] #time steps + T = 1000 # [ms] #total simulation time + n = 25 #total number of current values between 0 and 1 + I_ext = np.linspace(0,1,n) + time_analysis_start = 200 + N1 = (int)((time_analysis_start)/h) #index of the starting time + num_peaks = np.zeros(n) + indict = open_json("fitzhughnagumo.json") + analysis_json, shape_sys, shapes = odetoolbox._analysis(indict, disable_stiffness_check=True, disable_analytic_solver=True) + print("Got analysis result from ode-toolbox: ") + print(json.dumps(analysis_json, indent=2)) + assert len(analysis_json) == 1 + assert analysis_json[0]["solver"].startswith("numeric") + alias_spikes = True + integrator = odeiv.step_rk4 + + for j in range(n): + #loop over current values + initial_values = { "V" : (self.initial__values(I_ext[j])[0] + 0.001), "W": self.initial__values(I_ext[j])[1]} + initial_values = { sympy.Symbol(k) : v for k, v in initial_values.items() } + mixed_integrator = MixedIntegrator( + integrator, + shape_sys, + shapes, + analytic_solver_dict=None, + parameters={"I_ext":str(I_ext[j])}, + random_seed=123, + max_step_size=h, + integration_accuracy_abs=1E-5, + integration_accuracy_rel=1E-5, + sim_time=T, + alias_spikes=alias_spikes) + h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list = mixed_integrator.integrate_ode( + initial_values=initial_values, + h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values + peaks, _ = find_peaks(np.array(y_log)[N1:,0], height = 1.5 ) #finding peaks above 1.5 microvolts ignoring the first 200 ms + num_peaks[j] = (int)(len(peaks)/((T-200)*0.001)) #frequency (in Hz) of the peaks for every value of current + if(I_ext[j] >(1/3)): + assert(num_peaks[j]>20) + self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j]), title_snip= " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j])) + + self._FI_curve(I_ext,num_peaks,basedir="",fn_snip = "FI curve", title_snip = "FI curve") + + def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): + fig, ax = plt.subplots(len(y_log[0]), sharex=True) + for i, sym in enumerate(sym_list): + ax[i].plot(1E3 * np.array(t_log)[N1:], np.array(y_log)[N1:, i], label=str(sym)) + + for _ax in ax: + _ax.legend() + _ax.grid(True) + + ax[-1].set_xlabel("Time [ms]") + fig.suptitle("V vs time" + title_snip) + + fn = os.path.join(basedir, "test_fitzhughnagumo" + fn_snip + ".png") + print("Saving to " + fn) + plt.savefig(fn, dpi=600) + plt.close(fig) + + def _FI_curve(self,I_ext,num_peaks,basedir="",fn_snip="",title_snip=""): + plt.title(title_snip) + plt.xlabel("External current (arbitrary units)") + plt.ylabel("Frequency of spikes in Hz") + plt.plot(I_ext, num_peaks) #plotting the frequency of peaks vs external current + fn = os.path.join(basedir, "test_fitzhughnagumo " + fn_snip + ".png") + print("Saving to " + fn) + plt.savefig(fn,dpi=600) + plt.close() + + + + +if __name__ == '__main__': + unittest.main() From b1c84551ae51d31432a2e2ee6b4c6e8e98b702cc Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Tue, 1 Sep 2020 22:11:20 +0200 Subject: [PATCH 03/13] added test for matplotlib --- tests/test_fitzhughnagumo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index 1dc79f3d..8a33139e 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -126,9 +126,10 @@ def test_mixed_integrator_numeric(self): num_peaks[j] = (int)(len(peaks)/((T-200)*0.001)) #frequency (in Hz) of the peaks for every value of current if(I_ext[j] >(1/3)): assert(num_peaks[j]>20) - self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j]), title_snip= " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j])) - - self._FI_curve(I_ext,num_peaks,basedir="",fn_snip = "FI curve", title_snip = "FI curve") + if(INTEGRATION_TEST_DEBUG_PLOTS==True): + self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j]), title_snip= " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j])) + if(INTEGRATION_TEST_DEBUG_PLOTS==True): + self._FI_curve(I_ext,num_peaks,basedir="",fn_snip = "FI curve", title_snip = "FI curve") def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): fig, ax = plt.subplots(len(y_log[0]), sharex=True) From a6747ea00129b4adc8b9103be2480f403e38626d Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Wed, 2 Sep 2020 10:22:27 +0200 Subject: [PATCH 04/13] removed the duplicate import of matplotlib --- tests/test_fitzhughnagumo.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index 8a33139e..d8eb4643 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -26,12 +26,11 @@ import unittest import sympy import numpy as np -import matplotlib.pyplot as plt from scipy.signal import find_peaks from sympy.parsing.sympy_parser import parse_expr from sympy.solvers import solve from sympy import Symbol -#np.seterr(under="warn") + try: import matplotlib as mpl @@ -126,9 +125,9 @@ def test_mixed_integrator_numeric(self): num_peaks[j] = (int)(len(peaks)/((T-200)*0.001)) #frequency (in Hz) of the peaks for every value of current if(I_ext[j] >(1/3)): assert(num_peaks[j]>20) - if(INTEGRATION_TEST_DEBUG_PLOTS==True): + if INTEGRATION_TEST_DEBUG_PLOTS: self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j]), title_snip= " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j])) - if(INTEGRATION_TEST_DEBUG_PLOTS==True): + if INTEGRATION_TEST_DEBUG_PLOTS: self._FI_curve(I_ext,num_peaks,basedir="",fn_snip = "FI curve", title_snip = "FI curve") def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): From c819926ed2eda703dec84dd08a366014882ee520 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Thu, 3 Sep 2020 01:41:45 +0200 Subject: [PATCH 05/13] adding json file and fixing the dictionary bug --- odetoolbox/integrator.py | 12 ++++++++---- tests/fitzhughnagumo.json | 18 +++++------------- tests/test_fitzhughnagumo.py | 14 +++++++------- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/odetoolbox/integrator.py b/odetoolbox/integrator.py index da492095..6e75f449 100644 --- a/odetoolbox/integrator.py +++ b/odetoolbox/integrator.py @@ -23,6 +23,7 @@ import sympy import sympy.matrices import numpy as np +from typing import Dict class Integrator(): @@ -30,18 +31,21 @@ class Integrator(): Integrate a dynamical system by means of the propagators returned by ODE-toolbox (base class). """ - def set_spike_times(self, spike_times): + def set_spike_times(self, spike_times: Dict[str,float]): #spike_times is a dictionary r""" Internally converts to a global, sorted list of spike times. :param spike_times: For each variable, used as a key, the list of spike times associated with it. """ - + if spike_times is None: - self.spike_times = [] + self.spike_times = {} + else: self.spike_times = spike_times.copy() - + + + assert all([type(sym) is str for sym in self.spike_times.keys()]), "Spike time keys need to be of type str" self.all_spike_times = [] diff --git a/tests/fitzhughnagumo.json b/tests/fitzhughnagumo.json index d2498aee..c3500c91 100644 --- a/tests/fitzhughnagumo.json +++ b/tests/fitzhughnagumo.json @@ -1,30 +1,22 @@ { "parameters": { - "C_m": "1", - "g_Ca": "1.1", - "g_K": "2", - "g_L": ".5", - "E_Ca": "100", - "E_K": "-70", - "E_L": "-50", - "I_ext": "30" + "I_ext": "1" }, "dynamics": [ { - "expression": "V' = (I_ext - g_Ca * (.5 + .5*tanh((V + 1) / 15)) * (V - E_Ca) - g_K * W * (V - E_K) - g_L * (V - E_L)) / C_m * 1E3", + "expression": "V' = V - (V**3/3) - W + I_ext", "initial_value": "-25" }, { - "expression": "W' = (.5 + .5 * tanh(V / 30) - W) / (5 / cosh(V / 60)) * 1E3", + "expression": "W' = .08*(V + .7 - (.8 * W))", "initial_value": ".15" } ], "options": { "sim_time": "45E-3", - "max_step_size": ".25E-3", - "integration_accuracy_abs" : "1E-9", - "integration_accuracy_rel" : "1E-9" + "integration_accuracy_abs" : "1E-6", + "integration_accuracy_rel" : "1E-6" } } diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index d8eb4643..cf8647c9 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -66,12 +66,12 @@ def open_json(fname): return indict -class TestMixedIntegrationNumeric(unittest.TestCase): +class TestFitxhughNagumo(unittest.TestCase): """ - Numerical validation of MixedIntegrator. Note that this test uses all-numeric (no analytic part) integration to test for time grid aliasing effects of spike times. - - Simulate a conductance-based integrate-and-fire neuron which is receiving spikes. Check for a match of the final system state with a numerical reference value that was validated by hand. + Implementing the fitzhughNagumo model starting from equilibrium values, and performing a test that if the external current crosses a certain threshold value, regular spikes are obtained. + This function tests if the number of spikes cross 20 in that case. + Additionally, plots of V and W vs time are obtained for different values of current, and a FI curve is also plotted. """ def initial__values(self, curr): @@ -83,14 +83,14 @@ def initial__values(self, curr): return float(final_val_V), float(final_val_W) @pytest.mark.skipif(not PYGSL_AVAILABLE, reason="Need GSL integrator to perform test") - def test_mixed_integrator_numeric(self): + def test_fitzhugh_nagumo(self): debug = True h = 1 # [ms] #time steps T = 1000 # [ms] #total simulation time n = 25 #total number of current values between 0 and 1 - I_ext = np.linspace(0,1,n) - time_analysis_start = 200 + I_ext = np.linspace(0,1,n) #external current + time_analysis_start = 200 #starting our ananlysis after 200 ms N1 = (int)((time_analysis_start)/h) #index of the starting time num_peaks = np.zeros(n) indict = open_json("fitzhughnagumo.json") From 6f7b6865894bb3573739940c2955be9cc3c5010e Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Wed, 9 Sep 2020 15:35:30 +0200 Subject: [PATCH 06/13] removed the multiplication of time axis by 1000 --- tests/test_fitzhughnagumo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index cf8647c9..e6fcd880 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -121,6 +121,7 @@ def test_fitzhugh_nagumo(self): h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list = mixed_integrator.integrate_ode( initial_values=initial_values, h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values + import pdb;pdb.set_trace() peaks, _ = find_peaks(np.array(y_log)[N1:,0], height = 1.5 ) #finding peaks above 1.5 microvolts ignoring the first 200 ms num_peaks[j] = (int)(len(peaks)/((T-200)*0.001)) #frequency (in Hz) of the peaks for every value of current if(I_ext[j] >(1/3)): @@ -133,7 +134,7 @@ def test_fitzhugh_nagumo(self): def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): fig, ax = plt.subplots(len(y_log[0]), sharex=True) for i, sym in enumerate(sym_list): - ax[i].plot(1E3 * np.array(t_log)[N1:], np.array(y_log)[N1:, i], label=str(sym)) + ax[i].plot(np.array(t_log)[N1:], np.array(y_log)[N1:, i], label=str(sym)) for _ax in ax: _ax.legend() From 79d0323246b7a0f0b9204681b28cd45286d26e48 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Wed, 9 Sep 2020 15:39:38 +0200 Subject: [PATCH 07/13] removed the pdb break point --- tests/test_fitzhughnagumo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index e6fcd880..bb22edd2 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -121,7 +121,6 @@ def test_fitzhugh_nagumo(self): h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list = mixed_integrator.integrate_ode( initial_values=initial_values, h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values - import pdb;pdb.set_trace() peaks, _ = find_peaks(np.array(y_log)[N1:,0], height = 1.5 ) #finding peaks above 1.5 microvolts ignoring the first 200 ms num_peaks[j] = (int)(len(peaks)/((T-200)*0.001)) #frequency (in Hz) of the peaks for every value of current if(I_ext[j] >(1/3)): From 4063a50a9ff6f4346d102a9653bc273ef8189f6a Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Thu, 10 Sep 2020 23:05:01 +0200 Subject: [PATCH 08/13] removed extra lines --- odetoolbox/integrator.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/odetoolbox/integrator.py b/odetoolbox/integrator.py index 6e75f449..493b3c00 100644 --- a/odetoolbox/integrator.py +++ b/odetoolbox/integrator.py @@ -18,36 +18,27 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . # - import logging import sympy import sympy.matrices import numpy as np from typing import Dict - class Integrator(): r""" Integrate a dynamical system by means of the propagators returned by ODE-toolbox (base class). """ - def set_spike_times(self, spike_times: Dict[str,float]): #spike_times is a dictionary r""" Internally converts to a global, sorted list of spike times. - :param spike_times: For each variable, used as a key, the list of spike times associated with it. """ - if spike_times is None: self.spike_times = {} - else: self.spike_times = spike_times.copy() - - assert all([type(sym) is str for sym in self.spike_times.keys()]), "Spike time keys need to be of type str" - self.all_spike_times = [] self.all_spike_times_sym = [] for sym, sym_spike_times in self.spike_times.items(): @@ -65,21 +56,17 @@ def set_spike_times(self, spike_times: Dict[str,float]): #spike_times is a dicti self.all_spike_times = [ self.all_spike_times[i] for i in idx ] self.all_spike_times_sym = [ self.all_spike_times_sym[i] for i in idx ] - def get_spike_times(self): r""" Get spike times. - :return spike_times: For each variable, used as a key, the list of spike times associated with it. """ return self.spike_times - - + def get_sorted_spike_times(self): r""" Returns a global, sorted list of spike times. - :return all_spike_times: A sorted list of all spike times for all variables. :return all_spike_times_sym: For the spike at time :python:`all_spike_times[i]`, the variables to which that spike applies are listed in :python:`all_spike_times_sym[i]`. """ - return self.all_spike_times, self.all_spike_times_sym + return self.all_spike_times, self.all_spike_times_sym \ No newline at end of file From df01b55102367c0b73bfa058771affcf4b26e207 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Thu, 10 Sep 2020 23:08:21 +0200 Subject: [PATCH 09/13] made all the changes suggested in the review --- tests/fitzhughnagumo.json | 10 ++-- tests/test_fitzhughnagumo.py | 93 +++++++++++++++++++----------------- 2 files changed, 52 insertions(+), 51 deletions(-) diff --git a/tests/fitzhughnagumo.json b/tests/fitzhughnagumo.json index c3500c91..1ce5ddf6 100644 --- a/tests/fitzhughnagumo.json +++ b/tests/fitzhughnagumo.json @@ -1,4 +1,6 @@ { + "__info" : "This is the FitzHugh-Nagumo model [http://www.scholarpedia.org/article/FitzHugh-Nagumo_model, https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model]", + "parameters": { "I_ext": "1" }, @@ -12,11 +14,5 @@ "expression": "W' = .08*(V + .7 - (.8 * W))", "initial_value": ".15" } - ], - - "options": { - "sim_time": "45E-3", - "integration_accuracy_abs" : "1E-6", - "integration_accuracy_rel" : "1E-6" - } + ] } diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index bb22edd2..b1589f39 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -1,5 +1,5 @@ # -# test_mixed_integrator_numeric.py +# test_fitzhughnagumo.py # # This file is part of the NEST ODE toolbox. # @@ -30,8 +30,6 @@ from sympy.parsing.sympy_parser import parse_expr from sympy.solvers import solve from sympy import Symbol - - try: import matplotlib as mpl mpl.use('Agg') @@ -39,15 +37,11 @@ INTEGRATION_TEST_DEBUG_PLOTS = True except: INTEGRATION_TEST_DEBUG_PLOTS = False - - import odetoolbox from odetoolbox.mixed_integrator import MixedIntegrator - from math import e from sympy import exp, sympify import sympy.parsing.sympy_parser - import scipy import scipy.special import scipy.linalg @@ -57,54 +51,62 @@ PYGSL_AVAILABLE = True except ImportError as ie: PYGSL_AVAILABLE = False - - def open_json(fname): absfname = os.path.join(os.path.abspath(os.path.dirname(__file__)), fname) with open(absfname) as infile: indict = json.load(infile) return indict - - -class TestFitxhughNagumo(unittest.TestCase): + +class TestFitxhughNagumo(unittest.TestCase): """ + This is the FitzHugh-Nagumo model [http://www.scholarpedia.org/article/FitzHugh-Nagumo_model, https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model] Implementing the fitzhughNagumo model starting from equilibrium values, and performing a test that if the external current crosses a certain threshold value, regular spikes are obtained. This function tests if the number of spikes cross 20 in that case. Additionally, plots of V and W vs time are obtained for different values of current, and a FI curve is also plotted. - """ - + """ def initial__values(self, curr): - I_ext = Symbol("I_ext") - V = Symbol("V") - expr = solve((sympy.parsing.sympy_parser.parse_expr("8*V**3 + 6*V + 21 - 24*I_ext")), V) # expr gives a list of three roots for V: first two are complex, third one is real - final_val_V = (expr[2].subs(I_ext,curr)).evalf() - final_val_W = ((10*final_val_V) + 7)/8 - return float(final_val_V), float(final_val_W) - + """ + This function returns the initial values(for every value of external current), i.e, the equilibrium values of V and W where the conditon dV/dt = dW/dt = 0 is staisfied. + Hence, V and W are the roots of the following equations: + V - V**3/3 - W + I = 0 + 0.08*(V + 0.7 - 0.8*W) = 0 + Sympy is used for the calculation of the roots. + """ + I_ext = Symbol("I_ext") + V = Symbol("V") + expr = solve((sympy.parsing.sympy_parser.parse_expr("8*V**3 + 6*V + 21 - 24*I_ext")), V) # expr gives a list of three roots for V: first two are complex, third one is real + final_val_V = (expr[2].subs(I_ext,curr)).evalf() + final_val_W = ((10*final_val_V) + 7)/8 + return float(final_val_V), float(final_val_W) #since sympy returns objects, we convert final_val_v and final_val_w to float numbers + @pytest.mark.skipif(not PYGSL_AVAILABLE, reason="Need GSL integrator to perform test") def test_fitzhugh_nagumo(self): debug = True - h = 1 # [ms] #time steps T = 1000 # [ms] #total simulation time - n = 25 #total number of current values between 0 and 1 + n = 10 #total number of current values between 0 and 1 I_ext = np.linspace(0,1,n) #external current + small_perturb = 0.001 #this value is the slight disturbance that we introduce to the equilibrium value of V returned from the initial__values() function. + threshold_V_for_peak = 1.5 #the minimum value of V for it to be counted as a peak + """ + Since about the first 200 ms correspond to a transient state of the neuron from exhibiting no spikes to gradually spiking (if the current is sufficient), + We start our analysis after ignoring the initial 200 ms and count the peaks appearing in the rest of the simulation time. N1 is therefore the index of the starting time. + """ time_analysis_start = 200 #starting our ananlysis after 200 ms - N1 = (int)((time_analysis_start)/h) #index of the starting time - num_peaks = np.zeros(n) + N1 = int(np.ceil(time_analysis_start/h)) #index of the starting time + peak_freq = np.zeros(n) + indict = open_json("fitzhughnagumo.json") analysis_json, shape_sys, shapes = odetoolbox._analysis(indict, disable_stiffness_check=True, disable_analytic_solver=True) print("Got analysis result from ode-toolbox: ") print(json.dumps(analysis_json, indent=2)) assert len(analysis_json) == 1 assert analysis_json[0]["solver"].startswith("numeric") - alias_spikes = True integrator = odeiv.step_rk4 - for j in range(n): #loop over current values - initial_values = { "V" : (self.initial__values(I_ext[j])[0] + 0.001), "W": self.initial__values(I_ext[j])[1]} + initial_values = { "V" : (self.initial__values(I_ext[j])[0] + small_perturb), "W": self.initial__values(I_ext[j])[1]} initial_values = { sympy.Symbol(k) : v for k, v in initial_values.items() } mixed_integrator = MixedIntegrator( integrator, @@ -112,36 +114,42 @@ def test_fitzhugh_nagumo(self): shapes, analytic_solver_dict=None, parameters={"I_ext":str(I_ext[j])}, - random_seed=123, max_step_size=h, integration_accuracy_abs=1E-5, integration_accuracy_rel=1E-5, - sim_time=T, - alias_spikes=alias_spikes) + sim_time=T) h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list = mixed_integrator.integrate_ode( initial_values=initial_values, h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values - peaks, _ = find_peaks(np.array(y_log)[N1:,0], height = 1.5 ) #finding peaks above 1.5 microvolts ignoring the first 200 ms - num_peaks[j] = (int)(len(peaks)/((T-200)*0.001)) #frequency (in Hz) of the peaks for every value of current - if(I_ext[j] >(1/3)): - assert(num_peaks[j]>20) + peak_freq[j] = self.peak_detection(y_log,N1,threshold_V_for_peak,time_analysis_start,T) + if I_ext[j] > 1/3: #this is actual unit testing part. + """ + I = 0.333333..: In the plot we see that the system gradually gets to a state where it starts spiking regularly. Therefore this current can + be regarded as the threshold current where the equilibrium shifts from a stable one to an unstable one. The threshold theoretically is (1/3) + which is a non terminating number. However, the computer cannot store an infinitely long number, and hence it rounds up the numbers. + One possibility is that (1/3) is rounded up to 0.33333333...4. Which is slightly above the threshold and hence we see that the system exhibits regular spikes + after a long transient state. Therefore, at I_ext = 1/3, we don't see a peak frequency of above 20 due to the long transient state. + """ + assert peak_freq[j]>20 if INTEGRATION_TEST_DEBUG_PLOTS: - self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j]), title_snip= " I= " + str(I_ext[j]) + " peaks= " + str(num_peaks[j])) + self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks freq = " + str(peak_freq[j]), title_snip= " I= " + str(I_ext[j]) + " peaks freq= " + str(peak_freq[j])) if INTEGRATION_TEST_DEBUG_PLOTS: - self._FI_curve(I_ext,num_peaks,basedir="",fn_snip = "FI curve", title_snip = "FI curve") - + self._FI_curve(I_ext,peak_freq,basedir="",fn_snip = "FI curve", title_snip = "FI curve") + + def peak_detection(self,y_log,N1,threshold_V_for_peak,time_analysis_start,T): #function that determines the frequency of peaks in the plot for V vs time + peaks, _ = find_peaks(np.array(y_log)[N1:,0], height = threshold_V_for_peak ) #finding peaks above 1.5 microvolts ignoring the first 200 ms + frequency = int(len(peaks)/((T-time_analysis_start)*0.001)) #frequency (in Hz) of the peaks for every value of current + return frequency + def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): fig, ax = plt.subplots(len(y_log[0]), sharex=True) for i, sym in enumerate(sym_list): ax[i].plot(np.array(t_log)[N1:], np.array(y_log)[N1:, i], label=str(sym)) - for _ax in ax: _ax.legend() _ax.grid(True) - ax[-1].set_xlabel("Time [ms]") fig.suptitle("V vs time" + title_snip) - fn = os.path.join(basedir, "test_fitzhughnagumo" + fn_snip + ".png") print("Saving to " + fn) plt.savefig(fn, dpi=600) @@ -157,8 +165,5 @@ def _FI_curve(self,I_ext,num_peaks,basedir="",fn_snip="",title_snip=""): plt.savefig(fn,dpi=600) plt.close() - - - if __name__ == '__main__': unittest.main() From 1fced8ede374f6264de056dfcf37c8fe41bb03e5 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Mon, 14 Sep 2020 15:57:56 +0200 Subject: [PATCH 10/13] adding test_analystic_integrator.py --- tests/test_analytic_integrator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_analytic_integrator.py b/tests/test_analytic_integrator.py index a1a229ef..710a232a 100644 --- a/tests/test_analytic_integrator.py +++ b/tests/test_analytic_integrator.py @@ -121,11 +121,17 @@ def test_analytic_integrator_iaf_psc_alpha(self): _ax.grid(True) ax[-1].set_xlabel("Time [ms]") - + fn = os.path.join("", "test_analytic_integrator.png") + print("Saving to " + fn) + plt.savefig(fn, dpi=600) + plt.close(fig) + + """ fn = "/tmp/test_analytic_integrator.png" print("Saving to " + fn) plt.savefig(fn, dpi=600) plt.close(fig) + """ np.testing.assert_allclose(state[True]["timevec"], timevec) np.testing.assert_allclose(state[True]["timevec"], state[False]["timevec"]) From 0807120898b6f1a3c623abed3c253b170f517ce8 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Tue, 15 Sep 2020 00:21:35 +0200 Subject: [PATCH 11/13] made changes suggested by pycodestyle --- odetoolbox/integrator.py | 12 ++-- tests/test_analytic_integrator.py | 1 - tests/test_fitzhughnagumo.py | 115 ++++++++++++++---------------- tests/test_iaf_analytic.py | 112 +++++++++++++++++++++++++++++ 4 files changed, 173 insertions(+), 67 deletions(-) create mode 100644 tests/test_iaf_analytic.py diff --git a/odetoolbox/integrator.py b/odetoolbox/integrator.py index 89f53a51..978f883a 100644 --- a/odetoolbox/integrator.py +++ b/odetoolbox/integrator.py @@ -18,26 +18,28 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . # + import logging import sympy import sympy.matrices import numpy as np from typing import Dict + class Integrator(): r""" Integrate a dynamical system by means of the propagators returned by ODE-toolbox (base class). """ - def set_spike_times(self, spike_times: Dict[str,float]): #spike_times is a dictionary + def set_spike_times(self, spike_times: Dict[str, float]): # spike_times is a dictionary r""" Internally converts to a global, sorted list of spike times. + :param spike_times: For each variable, used as a key, the list of spike times associated with it. """ if spike_times is None: self.spike_times = {} else: self.spike_times = spike_times.copy() - assert all([type(sym) is str for sym in self.spike_times.keys()]), "Spike time keys need to be of type str" self.all_spike_times = [] self.all_spike_times_sym = [] @@ -59,14 +61,16 @@ def set_spike_times(self, spike_times: Dict[str,float]): #spike_times is a dicti def get_spike_times(self): r""" Get spike times. + :return spike_times: For each variable, used as a key, the list of spike times associated with it. """ return self.spike_times - + def get_sorted_spike_times(self): r""" Returns a global, sorted list of spike times. + :return all_spike_times: A sorted list of all spike times for all variables. :return all_spike_times_sym: For the spike at time :python:`all_spike_times[i]`, the variables to which that spike applies are listed in :python:`all_spike_times_sym[i]`. """ - return self.all_spike_times, self.all_spike_times_sym \ No newline at end of file + return self.all_spike_times, self.all_spike_times_sym diff --git a/tests/test_analytic_integrator.py b/tests/test_analytic_integrator.py index 3d4b53ef..d1077344 100644 --- a/tests/test_analytic_integrator.py +++ b/tests/test_analytic_integrator.py @@ -125,7 +125,6 @@ def test_analytic_integrator_iaf_psc_alpha(self): print("Saving to " + fn) plt.savefig(fn, dpi=600) plt.close(fig) - """ fn = "/tmp/test_analytic_integrator.png" print("Saving to " + fn) diff --git a/tests/test_fitzhughnagumo.py b/tests/test_fitzhughnagumo.py index b1589f39..7805b606 100644 --- a/tests/test_fitzhughnagumo.py +++ b/tests/test_fitzhughnagumo.py @@ -35,7 +35,7 @@ mpl.use('Agg') import matplotlib.pyplot as plt INTEGRATION_TEST_DEBUG_PLOTS = True -except: +except Exception: INTEGRATION_TEST_DEBUG_PLOTS = False import odetoolbox from odetoolbox.mixed_integrator import MixedIntegrator @@ -51,52 +51,53 @@ PYGSL_AVAILABLE = True except ImportError as ie: PYGSL_AVAILABLE = False + + def open_json(fname): absfname = os.path.join(os.path.abspath(os.path.dirname(__file__)), fname) with open(absfname) as infile: indict = json.load(infile) return indict - -class TestFitxhughNagumo(unittest.TestCase): - + + +class TestFitxhughNagumo(unittest.TestCase): """ This is the FitzHugh-Nagumo model [http://www.scholarpedia.org/article/FitzHugh-Nagumo_model, https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model] Implementing the fitzhughNagumo model starting from equilibrium values, and performing a test that if the external current crosses a certain threshold value, regular spikes are obtained. - This function tests if the number of spikes cross 20 in that case. - Additionally, plots of V and W vs time are obtained for different values of current, and a FI curve is also plotted. - """ + This function tests if the number of spikes cross 20 in that case. + Additionally, plots of V and W vs time are obtained for different values of current, and a FI curve is also plotted. + """ def initial__values(self, curr): """ This function returns the initial values(for every value of external current), i.e, the equilibrium values of V and W where the conditon dV/dt = dW/dt = 0 is staisfied. Hence, V and W are the roots of the following equations: V - V**3/3 - W + I = 0 0.08*(V + 0.7 - 0.8*W) = 0 - Sympy is used for the calculation of the roots. + Sympy is used for the calculation of the roots. """ I_ext = Symbol("I_ext") - V = Symbol("V") - expr = solve((sympy.parsing.sympy_parser.parse_expr("8*V**3 + 6*V + 21 - 24*I_ext")), V) # expr gives a list of three roots for V: first two are complex, third one is real - final_val_V = (expr[2].subs(I_ext,curr)).evalf() - final_val_W = ((10*final_val_V) + 7)/8 - return float(final_val_V), float(final_val_W) #since sympy returns objects, we convert final_val_v and final_val_w to float numbers - + V = Symbol("V") + expr = solve((sympy.parsing.sympy_parser.parse_expr("8*V**3 + 6*V + 21 - 24*I_ext")), V) # expr gives a list of three roots for V: first two are complex, third one is real + final_val_V = (expr[2].subs(I_ext, curr)).evalf() + final_val_W = ((10 * final_val_V) + 7) / 8 + return float(final_val_V), float(final_val_W) # since sympy returns objects, we convert final_val_v and final_val_w to float numbers + @pytest.mark.skipif(not PYGSL_AVAILABLE, reason="Need GSL integrator to perform test") def test_fitzhugh_nagumo(self): debug = True - h = 1 # [ms] #time steps - T = 1000 # [ms] #total simulation time - n = 10 #total number of current values between 0 and 1 - I_ext = np.linspace(0,1,n) #external current - small_perturb = 0.001 #this value is the slight disturbance that we introduce to the equilibrium value of V returned from the initial__values() function. - threshold_V_for_peak = 1.5 #the minimum value of V for it to be counted as a peak + h = 1 # [ms] #time steps + T = 1000 # [ms] #total simulation time + n = 10 # total number of current values between 0 and 1 + I_ext = np.linspace(0, 1, n) # external current + small_perturb = 0.001 # this value is the slight disturbance that we introduce to the equilibrium value of V returned from the initial__values() function + threshold_V_for_peak = 1.5 # the minimum value of V for it to be counted as a peak """ - Since about the first 200 ms correspond to a transient state of the neuron from exhibiting no spikes to gradually spiking (if the current is sufficient), - We start our analysis after ignoring the initial 200 ms and count the peaks appearing in the rest of the simulation time. N1 is therefore the index of the starting time. + Since about the first 200 ms correspond to a transient state of the neuron from exhibiting no spikes to gradually spiking (if the current is sufficient), + We start our analysis after ignoring the initial 200 ms and count the peaks appearing in the rest of the simulation time. N1 is therefore the index of the starting time. """ - time_analysis_start = 200 #starting our ananlysis after 200 ms - N1 = int(np.ceil(time_analysis_start/h)) #index of the starting time + time_analysis_start = 200 # starting our ananlysis after 200 ms + N1 = int(np.ceil(time_analysis_start / h)) # index of the starting time peak_freq = np.zeros(n) - indict = open_json("fitzhughnagumo.json") analysis_json, shape_sys, shapes = odetoolbox._analysis(indict, disable_stiffness_check=True, disable_analytic_solver=True) print("Got analysis result from ode-toolbox: ") @@ -105,43 +106,32 @@ def test_fitzhugh_nagumo(self): assert analysis_json[0]["solver"].startswith("numeric") integrator = odeiv.step_rk4 for j in range(n): - #loop over current values - initial_values = { "V" : (self.initial__values(I_ext[j])[0] + small_perturb), "W": self.initial__values(I_ext[j])[1]} - initial_values = { sympy.Symbol(k) : v for k, v in initial_values.items() } - mixed_integrator = MixedIntegrator( - integrator, - shape_sys, - shapes, - analytic_solver_dict=None, - parameters={"I_ext":str(I_ext[j])}, - max_step_size=h, - integration_accuracy_abs=1E-5, - integration_accuracy_rel=1E-5, - sim_time=T) - h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list = mixed_integrator.integrate_ode( - initial_values=initial_values, - h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values - peak_freq[j] = self.peak_detection(y_log,N1,threshold_V_for_peak,time_analysis_start,T) - if I_ext[j] > 1/3: #this is actual unit testing part. + # loop over current values + initial_values = {"V": (self.initial__values(I_ext[j])[0] + small_perturb), "W": self.initial__values(I_ext[j])[1]} + initial_values = {sympy.Symbol(k): v for k, v in initial_values.items()} + mixed_integrator = MixedIntegrator(integrator, shape_sys, shapes, analytic_solver_dict=None, parameters={"I_ext": str(I_ext[j])}, max_step_size=h, integration_accuracy_abs=1E-5, integration_accuracy_rel=1E-5, sim_time=T) + h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list = mixed_integrator.integrate_ode(initial_values=initial_values, h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values + peak_freq[j] = self.peak_detection(y_log, N1, threshold_V_for_peak, time_analysis_start, T) + if I_ext[j] > 1 / 3: # this is actual unit testing part. """ - I = 0.333333..: In the plot we see that the system gradually gets to a state where it starts spiking regularly. Therefore this current can + I = 0.333333..: In the plot we see that the system gradually gets to a state where it starts spiking regularly. Therefore this current can be regarded as the threshold current where the equilibrium shifts from a stable one to an unstable one. The threshold theoretically is (1/3) - which is a non terminating number. However, the computer cannot store an infinitely long number, and hence it rounds up the numbers. - One possibility is that (1/3) is rounded up to 0.33333333...4. Which is slightly above the threshold and hence we see that the system exhibits regular spikes - after a long transient state. Therefore, at I_ext = 1/3, we don't see a peak frequency of above 20 due to the long transient state. + which is a non terminating number. However, the computer cannot store an infinitely long number, and hence it rounds up the numbers. + One possibility is that (1/3) is rounded up to 0.33333333...4. Which is slightly above the threshold and hence we see that the system exhibits regular spikes + after a long transient state. Therefore, at I_ext = 1/3, we don't see a peak frequency of above 20 due to the long transient state. """ - assert peak_freq[j]>20 + assert peak_freq[j] > 20 if INTEGRATION_TEST_DEBUG_PLOTS: - self._timeseries_plot(N1,t_log, h_log, y_log, sym_list, basedir="", fn_snip = " I= " + str(I_ext[j]) + " peaks freq = " + str(peak_freq[j]), title_snip= " I= " + str(I_ext[j]) + " peaks freq= " + str(peak_freq[j])) + self._timeseries_plot(N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip=" I= " + str(I_ext[j]) + " peaks freq = " + str(peak_freq[j]), title_snip=" I= " + str(I_ext[j]) + " peaks freq= " + str(peak_freq[j])) if INTEGRATION_TEST_DEBUG_PLOTS: - self._FI_curve(I_ext,peak_freq,basedir="",fn_snip = "FI curve", title_snip = "FI curve") - - def peak_detection(self,y_log,N1,threshold_V_for_peak,time_analysis_start,T): #function that determines the frequency of peaks in the plot for V vs time - peaks, _ = find_peaks(np.array(y_log)[N1:,0], height = threshold_V_for_peak ) #finding peaks above 1.5 microvolts ignoring the first 200 ms - frequency = int(len(peaks)/((T-time_analysis_start)*0.001)) #frequency (in Hz) of the peaks for every value of current + self._FI_curve(I_ext, peak_freq, basedir="", fn_snip="FI curve", title_snip="FI curve") + + def peak_detection(self, y_log, N1, threshold_V_for_peak, time_analysis_start, T): # function that determines the frequency of peaks in the plot for V vs time + peaks, _ = find_peaks(np.array(y_log)[N1:, 0], height=threshold_V_for_peak) # finding peaks above 1.5 microvolts ignoring the first 200 ms + frequency = int(len(peaks) / ((T - time_analysis_start) * 0.001)) # frequency (in Hz) of the peaks for every value of current return frequency - - def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): + + def _timeseries_plot(self, N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip="", title_snip=""): fig, ax = plt.subplots(len(y_log[0]), sharex=True) for i, sym in enumerate(sym_list): ax[i].plot(np.array(t_log)[N1:], np.array(y_log)[N1:, i], label=str(sym)) @@ -150,20 +140,21 @@ def _timeseries_plot(self,N1, t_log, h_log, y_log, sym_list, basedir="", fn_snip _ax.grid(True) ax[-1].set_xlabel("Time [ms]") fig.suptitle("V vs time" + title_snip) - fn = os.path.join(basedir, "test_fitzhughnagumo" + fn_snip + ".png") + fn = os.path.join(basedir, "test_fitzhughnagumo" + fn_snip + ".png") print("Saving to " + fn) plt.savefig(fn, dpi=600) plt.close(fig) - - def _FI_curve(self,I_ext,num_peaks,basedir="",fn_snip="",title_snip=""): + + def _FI_curve(self, I_ext, num_peaks, basedir="", fn_snip="", title_snip=""): plt.title(title_snip) plt.xlabel("External current (arbitrary units)") plt.ylabel("Frequency of spikes in Hz") - plt.plot(I_ext, num_peaks) #plotting the frequency of peaks vs external current + plt.plot(I_ext, num_peaks) # plotting the frequency of peaks vs external current fn = os.path.join(basedir, "test_fitzhughnagumo " + fn_snip + ".png") print("Saving to " + fn) - plt.savefig(fn,dpi=600) + plt.savefig(fn, dpi=600) plt.close() - + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_iaf_analytic.py b/tests/test_iaf_analytic.py new file mode 100644 index 00000000..941ff6f6 --- /dev/null +++ b/tests/test_iaf_analytic.py @@ -0,0 +1,112 @@ +# +# test_iaf_analytic.py +# +# This file is part of the NEST ODE toolbox. +# +# Copyright (C) 2017 The NEST Initiative +# +# The NEST ODE toolbox is free software: you can redistribute it +# and/or modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation, either version 2 of +# the License, or (at your option) any later version. +# +# The NEST ODE toolbox is distributed in the hope that it will be +# useful, but WITHOUT ANY WARRANTY; without even the implied warranty +# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . +# + +import json +import os +import unittest +import sympy +import numpy as np + +try: + import matplotlib as mpl + mpl.use('Agg') + import matplotlib.pyplot as plt + INTEGRATION_TEST_DEBUG_PLOTS = True +except Exception: + INTEGRATION_TEST_DEBUG_PLOTS = False + +import odetoolbox +from odetoolbox.analytic_integrator import AnalyticIntegrator +from odetoolbox.spike_generator import SpikeGenerator +from math import e +from sympy import exp, sympify +import scipy +import scipy.special +import scipy.linalg + + +def open_json(fname): + absfname = os.path.join(os.path.abspath(os.path.dirname(__file__)), fname) + with open(absfname) as infile: + indict = json.load(infile) + return indict + + +class TestAnalyticIntegratorIAF(unittest.TestCase): + """ + Test that analytic integrator returns the same result when caching is disabled and enabled. + """ + + def test_analytic_integrator_iaf(self): + debug = True + h = 1 # [ms] + T = 1000 # [ms] + indict = open_json("iaf.json") + solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True) + print("Got solver_dict from ode-toolbox: ") + print(json.dumps(solver_dict, indent=2)) + assert len(solver_dict) == 1 + solver_dict = solver_dict[0] + assert solver_dict["solver"] == "analytical" + ODE_INITIAL_VALUES = {"H_s": 1, "I_s": 0, "V_m": 0} + _parms = {"C": 1, "tau_s": 10, "tau_m": 10} + + if not "parameters" in solver_dict.keys(): + solver_dict["parameters"] = {} + solver_dict["parameters"].update(_parms) + N = int(np.ceil(T / h) + 1) + timevec = np.linspace(0., T, N) + state = {True: {}, False: {}} + state[True] = {sym: [] for sym in solver_dict["state_variables"]} + state[True]["timevec"] = [] + analytic_integrator = AnalyticIntegrator(solver_dict, spike_times=None, enable_caching=True) # spike_times, enable_caching=use_caching) + analytic_integrator.set_initial_values(ODE_INITIAL_VALUES) + analytic_integrator.reset() + for step, t in enumerate(timevec): + state_ = analytic_integrator.get_value(t) + state[True]["timevec"].append(t) + for sym, val in state_.items(): + state[True][sym].append(val) + for k, v in state[True].items(): + state[True][k] = np.array(v) + + if INTEGRATION_TEST_DEBUG_PLOTS: + fig, ax = plt.subplots(3, sharex=True) + ax[0].plot(timevec, state[True]["H_s"], label="H_s") + ax[1].plot(timevec, state[True]["I_s"], label="I_s") + ax[2].plot(timevec, state[True]["V_m"], label="V_m") + for _ax in ax: + _ax.legend() + _ax.grid(True) + ax[-1].set_xlabel("Time [ms]") + fn = os.path.join("", "test_iaf_analytic.png") + print("Saving to " + fn) + plt.savefig(fn, dpi=600) + plt.close(fig) + np.testing.assert_allclose(state[True]["timevec"], timevec) + np.testing.assert_allclose(state[True]["timevec"], state[False]["timevec"]) + for sym, val in state_.items(): + np.testing.assert_allclose(state[True][sym], state[False][sym]) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) From 61d98e4a05eecbe35c073f8cad4dfa0de2356de8 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Tue, 15 Sep 2020 00:32:23 +0200 Subject: [PATCH 12/13] removed test_iaf_analystic.py --- tests/test_iaf_analytic.py | 112 ------------------------------------- 1 file changed, 112 deletions(-) delete mode 100644 tests/test_iaf_analytic.py diff --git a/tests/test_iaf_analytic.py b/tests/test_iaf_analytic.py deleted file mode 100644 index 941ff6f6..00000000 --- a/tests/test_iaf_analytic.py +++ /dev/null @@ -1,112 +0,0 @@ -# -# test_iaf_analytic.py -# -# This file is part of the NEST ODE toolbox. -# -# Copyright (C) 2017 The NEST Initiative -# -# The NEST ODE toolbox is free software: you can redistribute it -# and/or modify it under the terms of the GNU General Public License -# as published by the Free Software Foundation, either version 2 of -# the License, or (at your option) any later version. -# -# The NEST ODE toolbox is distributed in the hope that it will be -# useful, but WITHOUT ANY WARRANTY; without even the implied warranty -# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . -# - -import json -import os -import unittest -import sympy -import numpy as np - -try: - import matplotlib as mpl - mpl.use('Agg') - import matplotlib.pyplot as plt - INTEGRATION_TEST_DEBUG_PLOTS = True -except Exception: - INTEGRATION_TEST_DEBUG_PLOTS = False - -import odetoolbox -from odetoolbox.analytic_integrator import AnalyticIntegrator -from odetoolbox.spike_generator import SpikeGenerator -from math import e -from sympy import exp, sympify -import scipy -import scipy.special -import scipy.linalg - - -def open_json(fname): - absfname = os.path.join(os.path.abspath(os.path.dirname(__file__)), fname) - with open(absfname) as infile: - indict = json.load(infile) - return indict - - -class TestAnalyticIntegratorIAF(unittest.TestCase): - """ - Test that analytic integrator returns the same result when caching is disabled and enabled. - """ - - def test_analytic_integrator_iaf(self): - debug = True - h = 1 # [ms] - T = 1000 # [ms] - indict = open_json("iaf.json") - solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True) - print("Got solver_dict from ode-toolbox: ") - print(json.dumps(solver_dict, indent=2)) - assert len(solver_dict) == 1 - solver_dict = solver_dict[0] - assert solver_dict["solver"] == "analytical" - ODE_INITIAL_VALUES = {"H_s": 1, "I_s": 0, "V_m": 0} - _parms = {"C": 1, "tau_s": 10, "tau_m": 10} - - if not "parameters" in solver_dict.keys(): - solver_dict["parameters"] = {} - solver_dict["parameters"].update(_parms) - N = int(np.ceil(T / h) + 1) - timevec = np.linspace(0., T, N) - state = {True: {}, False: {}} - state[True] = {sym: [] for sym in solver_dict["state_variables"]} - state[True]["timevec"] = [] - analytic_integrator = AnalyticIntegrator(solver_dict, spike_times=None, enable_caching=True) # spike_times, enable_caching=use_caching) - analytic_integrator.set_initial_values(ODE_INITIAL_VALUES) - analytic_integrator.reset() - for step, t in enumerate(timevec): - state_ = analytic_integrator.get_value(t) - state[True]["timevec"].append(t) - for sym, val in state_.items(): - state[True][sym].append(val) - for k, v in state[True].items(): - state[True][k] = np.array(v) - - if INTEGRATION_TEST_DEBUG_PLOTS: - fig, ax = plt.subplots(3, sharex=True) - ax[0].plot(timevec, state[True]["H_s"], label="H_s") - ax[1].plot(timevec, state[True]["I_s"], label="I_s") - ax[2].plot(timevec, state[True]["V_m"], label="V_m") - for _ax in ax: - _ax.legend() - _ax.grid(True) - ax[-1].set_xlabel("Time [ms]") - fn = os.path.join("", "test_iaf_analytic.png") - print("Saving to " + fn) - plt.savefig(fn, dpi=600) - plt.close(fig) - np.testing.assert_allclose(state[True]["timevec"], timevec) - np.testing.assert_allclose(state[True]["timevec"], state[False]["timevec"]) - for sym, val in state_.items(): - np.testing.assert_allclose(state[True][sym], state[False][sym]) - - -if __name__ == "__main__": - import pytest - pytest.main([__file__]) From 6e23d97bfece99e3e9ed220d0736a0fb9c0c5745 Mon Sep 17 00:00:00 2001 From: Shraddha P Jain Date: Tue, 15 Sep 2020 17:40:14 +0200 Subject: [PATCH 13/13] added an extra line between functions --- odetoolbox/integrator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/odetoolbox/integrator.py b/odetoolbox/integrator.py index 978f883a..97c2ea47 100644 --- a/odetoolbox/integrator.py +++ b/odetoolbox/integrator.py @@ -58,6 +58,7 @@ def set_spike_times(self, spike_times: Dict[str, float]): # spike_times is a di self.all_spike_times = [self.all_spike_times[i] for i in idx] self.all_spike_times_sym = [self.all_spike_times_sym[i] for i in idx] + def get_spike_times(self): r""" Get spike times. @@ -66,6 +67,7 @@ def get_spike_times(self): """ return self.spike_times + def get_sorted_spike_times(self): r""" Returns a global, sorted list of spike times.