From 6b1f8a22f246f95ac515bdbfc41194ee55eb1b3f Mon Sep 17 00:00:00 2001 From: Kael Dai Date: Sat, 28 Mar 2026 19:29:31 -0700 Subject: [PATCH] NEURON 9.0.1 is missing PtrValue.ptr_update_callback() function. Allowing BioNet to run without it --- bmtk/simulator/bionet/biocell.py | 9 ++-- bmtk/simulator/bionet/biosimulator.py | 13 +++++- bmtk/simulator/bionet/modules/ecp.py | 2 +- .../cell_metrics/sus_sus_cells_v3.csv | 2 - .../cell_metrics/trans_sus_cells_v3.csv | 2 - bmtk/simulator/pointnet/pointnetwork.py | 4 +- .../spike_trains/spike_train_readers.py | 44 ++++++++++++------- .../reports/spike_trains/spike_trains.py | 6 ++- bmtk/utils/sonata/column_property.py | 2 +- bmtk/utils/sonata/group.py | 2 +- .../lgnmodel/test_subclass_metrics.py | 18 +++++--- 11 files changed, 68 insertions(+), 36 deletions(-) diff --git a/bmtk/simulator/bionet/biocell.py b/bmtk/simulator/bionet/biocell.py index ea5850915..e83fbe20c 100644 --- a/bmtk/simulator/bionet/biocell.py +++ b/bmtk/simulator/bionet/biocell.py @@ -32,6 +32,7 @@ pc = h.ParallelContext() # object to access MPI methods + class ConnectionStruct(object): def __init__(self, edge_prop, src_node, syn, connector, is_virtual=False, is_gap_junc=False): self._src_node = src_node @@ -378,10 +379,12 @@ def __set_extracell_mechanism(self): def setup_ecp(self): self.im_ptr = h.PtrVector(self.morphology.nseg) # pointer vector - # used for gathering an array of i_membrane values from the pointer vector - self.im_ptr.ptr_update_callback(self.set_im_ptr) - self.imVec = h.Vector(self.morphology.nseg) + try: + self.im_ptr.ptr_update_callback(self.set_im_ptr) + except AttributeError as e: + pass + self.imVec = h.Vector(self.morphology.nseg) self.__set_extracell_mechanism() # for sec in self.hobj.all: # sec.insert('extracellular') diff --git a/bmtk/simulator/bionet/biosimulator.py b/bmtk/simulator/bionet/biosimulator.py index 954e5f9b7..105d9bf29 100644 --- a/bmtk/simulator/bionet/biosimulator.py +++ b/bmtk/simulator/bionet/biosimulator.py @@ -38,6 +38,16 @@ pc = h.ParallelContext() # object to access MPI methods +if not hasattr(h.PtrVector(1), 'ptr_update_callback'): + io.log_warning( + f'NEURON {h.nrnversion()} is missing "ptr_update_callback" that may sometimes effect ECP results.' + ' If not seeing sensible results or simulation returns a pointer error please try a different version of neuron (ex 8.2.4)' + ) + cache_efficient = False +else: + cache_efficient = True + + class BioSimulator(Simulator): """Includes methods to run and control the simulation""" @@ -68,7 +78,8 @@ def __init__(self, network, dt, tstop, v_init, celsius, nsteps_block, start_from h.steps_per_ms = 1/h.dt pc.setup_transfer()#Sets up gap junctions. self._set_init_conditions() # call to save state - h.cvode.cache_efficient(1) + if cache_efficient: + h.cvode.cache_efficient(1) h.pysim = self # use this objref to be able to call postFadvance from proc advance in advance.hoc self._iclamps = [] diff --git a/bmtk/simulator/bionet/modules/ecp.py b/bmtk/simulator/bionet/modules/ecp.py index c03a1d04d..a5ed1fb38 100644 --- a/bmtk/simulator/bionet/modules/ecp.py +++ b/bmtk/simulator/bionet/modules/ecp.py @@ -135,7 +135,7 @@ def _create_cell_file(self, gid): file_name = os.path.join(self._contributions_dir, '{}.h5'.format(int(gid))) file_h5 = h5py.File(file_name, 'a') self._cell_var_files[gid] = file_h5 - file_h5.create_dataset('/ecp/data', (self._nsteps, self._rel_nsites), maxshape=(None, self._rel_nsites), chunks=True) + file_h5.create_dataset('/ecp/data', (self._nsteps, self._rel_nsites), maxshape=(None, self._rel_nsites), dtype=float, chunks=True) # self._cell_var_files[gid] = file_h5['ecp'] def _calculate_ecp(self, sim): diff --git a/bmtk/simulator/filternet/lgnmodel/cell_metrics/sus_sus_cells_v3.csv b/bmtk/simulator/filternet/lgnmodel/cell_metrics/sus_sus_cells_v3.csv index 240114eaf..23d02a762 100755 --- a/bmtk/simulator/filternet/lgnmodel/cell_metrics/sus_sus_cells_v3.csv +++ b/bmtk/simulator/filternet/lgnmodel/cell_metrics/sus_sus_cells_v3.csv @@ -7,5 +7,3 @@ Mouse_id,Shank,Clu,Area,Layer,WVF_ratio,WVF_duration,Bl_ctr_row,Bl_ctr_col,Wh_ct 75,3,12,1,NaN,0.045806832,0.35,7,9,8,9,3.652408635,1.029283045,74.5,105.5,50.48434961,19.19976844,0.57805952,0.496553402,0.067924028,45.16947841,0.001503759,0.463704924,4,151,152,5,7,3.9,2.75,90,0.16,4,50,45.83333333,17.58333333,0.441646919,0.181271486,0.464006938,0.547915143,0.595463138,0.746445498,0.836846909,0.991992685,0.001800185,11.04166667,11.5,17.58333333,11,8.041666667,6.885552531,6.495778681,14.71455815,11.77583318,5.816827922 78,2,8,1,NaN,0.337036581,0.3,6,11,7,10,4.865683305,4.845646089,148.5,108.5,25.39909748,20.21502353,0.4358264,0.550653522,31.38090079,116.7614036,0.268760908,0.209763145,5.656854249,186,169,9,5,3,2.25,180,0.16,15,50,33.33333333,13.70833333,0.488981763,0.27877874,0.602923264,0.728106756,0.27027027,0.425531915,0.35278986,0.42206496,0.001579724,5.708333333,7.375,7.25,7.541666667,13.70833333,2.465575603,3.101341176,3.666361028,2.655445907,4.836161001 78,3,13,1,NaN,0.146747351,0.3,5,7,5,8,5.36647229,1.800870186,79.5,109.5,40.72018458,16.63837297,0.497626777,0.479883721,10.18860415,90.47480487,0.112612613,0.398203514,4,113,131,5,7,4.1,2.25,315,0.08,4,62.5,37.5,13.5,0.602623457,0.21507263,0.367088608,0.613899614,0.033492823,0.064814815,1.364038002,1.636845602,6.12E-06,9.083333333,11.41666667,13.5,11.70833333,5.666666667,6.518626565,11.17129891,18.41451302,18.08773131,7.404301371 -,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, -,,,,,,,,,,,,,,,36,18,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, diff --git a/bmtk/simulator/filternet/lgnmodel/cell_metrics/trans_sus_cells_v3.csv b/bmtk/simulator/filternet/lgnmodel/cell_metrics/trans_sus_cells_v3.csv index a18e3155d..609f49cdb 100755 --- a/bmtk/simulator/filternet/lgnmodel/cell_metrics/trans_sus_cells_v3.csv +++ b/bmtk/simulator/filternet/lgnmodel/cell_metrics/trans_sus_cells_v3.csv @@ -4,5 +4,3 @@ Mouse_id,Shank,Clu,Area,Layer,WVF_ratio,WVF_duration,Bl_ctr_row,Bl_ctr_col,Wh_ct 75,1,2,1,NaN,0.179020336,0.3,7,12,7,13,5.870579009,5.231830094,64.5,87.5,14.71608404,10.38016037,0.416686052,0.30332155,31.99221704,162.2705021,0.197153621,0.435460466,4,205,223,5,6,1,0.5,270,0.16,2,37.5,12.5,3.291666667,0.674157303,0.064616927,0.244094488,0.451428571,0.244094488,0.392405063,1.063534332,1.254018092,0.001120253,3.166666667,3.291666667,3.25,1.75,2.291666667,2.326214764,3.500800508,2.853026994,2.125669008,2.369890296 78,3,7,1,NaN,0.25704279,0.3,6,6,8,5,8.514390322,3.457795816,69.5,68.5,24.97441569,17.92925808,0.255980641,0.343553303,14.87536206,250.4358901,0.059397884,0.774372866,8.94427191,96,80,5,5,3.3,2.5,225,0.02,15,50,25,5.708333333,0.682432432,0.051506564,0.033962264,0.37254902,-0.021428571,-0.04379562,1.492219197,2.654987403,0.005427301,3.208333333,3.458333333,3.791666667,5.125,5.708333333,4.43147768,4.621665853,5.537427652,8.081975,8.518084585 78,3,11,1,NaN,0.2538379,0.25,6,9,6,10,5.969189721,2.135885692,59.5,83.5,78.43921026,33.31181738,0.268908037,0.485015497,1.290556526,124.9802109,0.010326087,0.398062208,4,150,168,4,6,11.2,13,315,0.04,8,75,50,21.83333333,0.654341603,0.112935538,0.195664575,0.454123113,0.116080937,0.208015267,1.650983868,4.080733712,0.000335046,10.29166667,13.125,19.16666667,21.83333333,17.29166667,10.79237938,19.16833868,29.71094753,36.04648112,25.37906864 -,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, -,,,,,,,,,,,,,,,46,30,,,,,,,,,,,,5,,,,,,,,,,,,,,,,,,,,,,,,,, diff --git a/bmtk/simulator/pointnet/pointnetwork.py b/bmtk/simulator/pointnet/pointnetwork.py index 409a537c4..f2d5c1a96 100644 --- a/bmtk/simulator/pointnet/pointnetwork.py +++ b/bmtk/simulator/pointnet/pointnetwork.py @@ -180,7 +180,7 @@ def build_recurrent_edges(self, force_resolution=False): if np.isscalar(edge.nest_params['weight']): edge.nest_params['weight'] = np.full(shape=len(nest_srcs), fill_value=edge.nest_params['weight']) - self._nest_connect(nest_srcs, nest_trgs, conn_spec='one_to_one', syn_spec=edge.nest_params) + self._nest_connect(nest_srcs.copy(), nest_trgs.copy(), conn_spec='one_to_one', syn_spec=edge.nest_params) def find_edges(self, source_nodes=None, target_nodes=None): # TODO: Move to parent @@ -239,7 +239,7 @@ def add_spike_trains(self, spike_trains, node_set, sg_params={'precise_times': T def _nest_connect(self, nest_srcs, nest_trgs, conn_spec='one_to_one', syn_spec=None): """Calls nest.Connect but with some extra error logging and exception handling.""" try: - nest.Connect(nest_srcs, nest_trgs, conn_spec=conn_spec, syn_spec=syn_spec) + nest.Connect(nest_srcs.copy(), nest_trgs.copy(), conn_spec=conn_spec, syn_spec=syn_spec) except nest.kernel.NESTErrors.BadDelay as bde: # An occuring issue is when dt > delay, add some extra messaging in log to help users fix problem. diff --git a/bmtk/utils/reports/spike_trains/spike_train_readers.py b/bmtk/utils/reports/spike_trains/spike_train_readers.py index b1b3a949a..6377a98f0 100644 --- a/bmtk/utils/reports/spike_trains/spike_train_readers.py +++ b/bmtk/utils/reports/spike_trains/spike_train_readers.py @@ -30,6 +30,7 @@ from .spike_trains_api import SpikeTrainsReadOnlyAPI from .core import SortOrder, csv_headers, col_population, col_timestamps, col_node_ids, pop_na +from .core import MPI_rank, MPI_size, comm_barrier, comm GRP_spikes_root = 'spikes' @@ -47,6 +48,20 @@ } +def _open_h5(path, mode='r'): + if h5py.get_config().mpi: + return h5py.File(path, mode, driver='mpio', comm=comm) + else: + # If opening the spike-train h5 file independently across multiple ranks then stagger the + # opening. With some file-systems opening the same file across ranks can cause deadlocks. + h5_obj = None + for r in range(MPI_size): + if r == MPI_rank: + h5_obj = h5py.File(path, 'r') + comm_barrier() + return h5_obj + + def load_sonata_file(path, version=None, **kwargs): """Loads a Sonata file reader, making sure it matches the correct version. @@ -55,28 +70,27 @@ def load_sonata_file(path, version=None, **kwargs): :param kwargs: :return: """ + h5_handle = _open_h5(path, 'r') + try: - with h5py.File(path, 'r') as h5: - spikes_root = h5[GRP_spikes_root] - for name, h5_obj in spikes_root.items(): - if isinstance(h5_obj, h5py.Group): - # In case there exists a population subgroup - return SonataSTReader(path, **kwargs) + spikes_root = h5_handle[GRP_spikes_root] + for name, h5_obj in spikes_root.items(): + if isinstance(h5_obj, h5py.Group): + # In case there exists a population subgroup + return SonataSTReader(path, h5_handle=h5_handle, **kwargs) except Exception: pass try: - with h5py.File(path, 'r') as h5: - spikes_root = h5[GRP_spikes_root] - if 'gids' in spikes_root and 'timestamps' in spikes_root: - return SonataOldReader(path, **kwargs) + spikes_root = h5_handle[GRP_spikes_root] + if 'gids' in spikes_root and 'timestamps' in spikes_root: + return SonataOldReader(path, h5_handle=h5_handle, **kwargs) except Exception: pass try: - with h5py.File(path, 'r') as h5: - if '/spikes' in h5: - return EmptySonataReader(path, **kwargs) + if '/spikes' in h5_handle: + return EmptySonataReader(path, h5_handle=h5_handle, **kwargs) except Exception: pass @@ -91,9 +105,9 @@ def to_list(v): class SonataSTReader(SpikeTrainsReadOnlyAPI): - def __init__(self, path, **kwargs): + def __init__(self, path, h5_handle=None, **kwargs): self._path = path - self._h5_handle = h5py.File(self._path, 'r') + self._h5_handle = h5_handle or _open_h5(path, 'r') self._DATASET_node_ids = 'node_ids' self._n_spikes = None # TODO: Create a function for looking up population and can return errors if more than one diff --git a/bmtk/utils/reports/spike_trains/spike_trains.py b/bmtk/utils/reports/spike_trains/spike_trains.py index ca69319c6..128321b97 100644 --- a/bmtk/utils/reports/spike_trains/spike_trains.py +++ b/bmtk/utils/reports/spike_trains/spike_trains.py @@ -27,7 +27,11 @@ from .spike_train_readers import load_sonata_file, CSVSTReader, NWBSTReader from .spike_train_buffer import STMemoryBuffer, STCSVBuffer, STMPIBuffer, STCSVMPIBufferV2 from bmtk.utils.sonata.utils import get_node_ids -from scipy.stats import gamma + +try: + from scipy.stats import gamma +except ImportError as ie: + pass import warnings class SpikeTrains(object): diff --git a/bmtk/utils/sonata/column_property.py b/bmtk/utils/sonata/column_property.py index 34eaa5aee..d48368af5 100644 --- a/bmtk/utils/sonata/column_property.py +++ b/bmtk/utils/sonata/column_property.py @@ -85,7 +85,7 @@ def from_csv(cls, pd_obj, name=None): return cls(c_name, c_dtype, 1) elif isinstance(pd_obj, pd.DataFrame): - return [cls(name, pd_obj[name].dtype, 1) for name in pd_obj.columns] + return [cls(name, pd_obj[name].to_numpy().dtype, 1) for name in pd_obj.columns] else: raise Exception('Unable to convert pandas object {} to a property or list of properties.'.format(pd_obj)) diff --git a/bmtk/utils/sonata/group.py b/bmtk/utils/sonata/group.py index 4a7ab4189..38eaaa432 100644 --- a/bmtk/utils/sonata/group.py +++ b/bmtk/utils/sonata/group.py @@ -204,7 +204,7 @@ def get_values(self, property_name, filtered_indicies=True): # TODO: Need to performance test, I think this code could be optimized. node_types_table = self._parent.node_types_table nt_col = node_types_table.column(property_name) - tmp_array = np.empty(shape=len(self._parent_indicies), dtype=nt_col.dtype) + tmp_array = np.empty(shape=len(self._parent_indicies), dtype=type(nt_col.dtype)) for i, ntid in enumerate(self.node_type_ids): tmp_array[i] = node_types_table[ntid][property_name] diff --git a/tests/simulator/filternet/lgnmodel/test_subclass_metrics.py b/tests/simulator/filternet/lgnmodel/test_subclass_metrics.py index adc7b6163..189abcad4 100644 --- a/tests/simulator/filternet/lgnmodel/test_subclass_metrics.py +++ b/tests/simulator/filternet/lgnmodel/test_subclass_metrics.py @@ -19,9 +19,11 @@ def cmp_dicts(d1, d2): return False else: for k in d1.keys(): + if k in ['N_class']: + continue + if isinstance(d1[k], dict): if not cmp_dicts(d1[k], d2[k]): - print(k) return False elif not cmp_vals(d1[k], d2[k]): @@ -68,13 +70,14 @@ def cmp_dicts(d1, d2): trans_sus_expected = { - 'TF1': {'f0_exp': np.array([36., 21.04166667, 22.41666667, 23.54166667, 12.04166667]), 'f1_exp': np.array([ 8.85216752, 7.71744362, 8.79852287, 16.1074593 , 8.09474585]), 'spont_exp': np.array([5.]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 7}, - 'TF15': {'f0_exp': np.array([3.20833333, 3.45833333, 3.79166667, 5.125, 5.70833333]), 'f1_exp': np.array([4.43147768, 4.62166585, 5.53742765, 8.081975 , 8.51808459]), 'spont_exp': np.array([3.3]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 7}, - 'TF2': {'f0_exp': np.array([3.16666667, 3.29166667, 3.25, 1.75, 2.29166667]), 'f1_exp': np.array([2.32621476, 3.50080051, 2.85302699, 2.12566901, 2.3698903 ]), 'spont_exp': np.array([1.]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 7}, - 'TF4': {'f0_exp': np.array([ 5.91666667, 7.95833333, 13.375, 6.58333333, 5.41666667]), 'f1_exp': np.array([5.24514294, 3.8352691 , 4.99435957, 3.69495923, 1.75006614]), 'spont_exp': np.array([4.4]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 7}, - 'TF8': {'f0_exp': np.array([10.29166667, 13.125, 19.16666667, 21.83333333, 17.29166667]), 'f1_exp': np.array([10.79237938, 19.16833868, 29.71094753, 36.04648112, 25.37906864]), 'spont_exp': np.array([11.2]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 7} + 'TF1': {'f0_exp': np.array([36., 21.04166667, 22.41666667, 23.54166667, 12.04166667]), 'f1_exp': np.array([ 8.85216752, 7.71744362, 8.79852287, 16.1074593 , 8.09474585]), 'spont_exp': np.array([5.]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 5}, + 'TF15': {'f0_exp': np.array([3.20833333, 3.45833333, 3.79166667, 5.125, 5.70833333]), 'f1_exp': np.array([4.43147768, 4.62166585, 5.53742765, 8.081975 , 8.51808459]), 'spont_exp': np.array([3.3]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 5}, + 'TF2': {'f0_exp': np.array([3.16666667, 3.29166667, 3.25, 1.75, 2.29166667]), 'f1_exp': np.array([2.32621476, 3.50080051, 2.85302699, 2.12566901, 2.3698903 ]), 'spont_exp': np.array([1.]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 5}, + 'TF4': {'f0_exp': np.array([ 5.91666667, 7.95833333, 13.375, 6.58333333, 5.41666667]), 'f1_exp': np.array([5.24514294, 3.8352691 , 4.99435957, 3.69495923, 1.75006614]), 'spont_exp': np.array([4.4]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 5}, + 'TF8': {'f0_exp': np.array([10.29166667, 13.125, 19.16666667, 21.83333333, 17.29166667]), 'f1_exp': np.array([10.79237938, 19.16833868, 29.71094753, 36.04648112, 25.37906864]), 'spont_exp': np.array([11.2]), 'si_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_inf_exp': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'ttp_exp': np.array([[np.nan, np.nan]]), 'f0_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'f1_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'spont_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'si_std': np.array([[np.nan, np.nan, np.nan, np.nan, np.nan]]), 'nsub': 1, 'N_class': 5} } +from pprint import pprint @pytest.mark.parametrize("cell_subclass,expected_val", [ @@ -96,7 +99,8 @@ def test_get_data_metrics(cell_subclass, expected_val): if __name__ == '__main__': cell_metrics = get_data_metrics_for_each_subclass('tOFF') - print(cmp_dicts(cell_metrics, tOFF_expected)) + # print(cmp_dicts(cell_metrics, tOFF_expected)) + test_get_data_metrics('trans_sus', trans_sus_expected) #print(cell_metrics)