diff --git a/bmtk/simulator/filternet/filtersimulator.py b/bmtk/simulator/filternet/filtersimulator.py index 488101876..2438a58de 100644 --- a/bmtk/simulator/filternet/filtersimulator.py +++ b/bmtk/simulator/filternet/filtersimulator.py @@ -206,10 +206,11 @@ def run(self): for mod in self._sim_mods: mod.save(self, cell, ts, f_rates) # io.log_info('Max firing rate: {}'.format(np.max(max_fr))) - io.log_info('Done.') for mod in self._sim_mods: mod.finalize(self) + io.log_info('Done.') + def local_cells(self): return self._network.cells() diff --git a/bmtk/simulator/filternet/modules/create_spikes.py b/bmtk/simulator/filternet/modules/create_spikes.py index 1f1e0e44c..cf56fe65a 100644 --- a/bmtk/simulator/filternet/modules/create_spikes.py +++ b/bmtk/simulator/filternet/modules/create_spikes.py @@ -7,11 +7,12 @@ from bmtk.utils.reports.spike_trains import SpikeTrains, pop_na, sort_order, sort_order_lu from bmtk.simulator.filternet.lgnmodel import poissongeneration as pg from bmtk.utils.io.ioutils import bmtk_world_comm +from bmtk.simulator.filternet.io_tools import io class SpikesGenerator(SimModule): def __init__(self, spikes_file_csv=None, spikes_file=None, spikes_file_nwb=None, tmp_dir='output', - sort_order='node_id', compression='gzip'): + sort_order='node_id', compression='gzip', clean_temp_files=True): def _get_file_path(file_name): if file_name is None or os.path.isabs(file_name): return file_name @@ -35,10 +36,12 @@ def _get_file_path(file_name): self._save_nwb = spikes_file_nwb is not None self._tmpdir = tmp_dir + self._clean_temp_files = clean_temp_files # self._spike_writer = SpikeTrainWriter(tmp_dir=tmp_dir) self._spike_writer = SpikeTrains(cache_dir=tmp_dir) self._sort_order = sort_order_lu[sort_order] + self._runtime_gid = bmtk_world_comm.global_uuid(default='filternet') def save(self, sim, cell, times, rates): try: @@ -52,6 +55,7 @@ def save(self, sim, cell, times, rates): self._spike_writer.add_spikes(node_ids=cell.gid, timestamps=spike_trains, population=cell.population) def finalize(self, sim): + io.log_debug('Writing spikes to file(s)...') self._spike_writer.flush() if self._save_csv: @@ -64,6 +68,7 @@ def finalize(self, sim): self._spike_writer.to_nwb(self._nwb_fname, sort_order=self._sort_order) self._spike_writer.close() + io.log_debug('Writing spikes to file(s)... done.') def f_rate_to_spike_train(t, f_rate, random_seed, t_window_start, t_window_end, p_spike_max): diff --git a/bmtk/simulator/filternet/modules/record_rates.py b/bmtk/simulator/filternet/modules/record_rates.py index d92f473a2..39da776ed 100644 --- a/bmtk/simulator/filternet/modules/record_rates.py +++ b/bmtk/simulator/filternet/modules/record_rates.py @@ -4,18 +4,22 @@ import h5py import numpy as np import glob +import uuid +from pathlib import Path from .base import SimModule from bmtk.utils.io.ioutils import bmtk_world_comm +from bmtk.simulator.filternet.io_tools import io class RecordRates(SimModule): def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='node_id', - compression='gzip'): + compression='gzip', clean_temp_files=True): self._tmp_dir = tmp_dir self._csv_file = csv_file if csv_file is None or os.path.isabs(csv_file) else os.path.join(tmp_dir, csv_file) self._save_to_csv = csv_file is not None self._tmp_rates_path = None + self._clean_tmp_files = clean_temp_files h5_file = h5_file if h5_file is None or os.path.isabs(h5_file) else os.path.join(tmp_dir, h5_file) self._save_to_h5 = h5_file is not None @@ -34,6 +38,24 @@ def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='no self._node_ids = {} self._firing_rates = {} self._node_counter = 0 + self._runtime_gid = bmtk_world_comm.global_uuid(default='filternet') # self.get_global_uuid() + + def get_global_uuid(self): + try: + if bmtk_world_comm.MPI_size == 1: + return str(uuid.uuid4().hex) + + if bmtk_world_comm.MPI_rank == 0: + bcast_data = str(uuid.uuid4().hex) + else: + bcast_data = None + + bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0) + return bcast_data + + except Exception as e: + return 'filternet' + def initialize(self, sim): self._node_counter = 0 @@ -53,9 +75,12 @@ def save(self, sim, cell, times, rates): self._node_ids[cell.population][self._node_counter] = cell.node_id self._node_counter += 1 + def finalize(self, sim): + io.log_debug('Writing rates to file(s)...') if bmtk_world_comm.MPI_size > 1: - self._tmp_rates_path = os.path.join(self._tmp_dir, '.rates.{}.h5'.format(bmtk_world_comm.MPI_rank)) + self._tmp_rates_path = os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.{bmtk_world_comm.MPI_rank}.h5') + self._write_rates_on_rank() bmtk_world_comm.barrier() @@ -94,7 +119,9 @@ def finalize(self, sim): csv_writer.writerow([node_id, pop, ts, fr]) bmtk_world_comm.barrier() - self._clean() + if self._clean_tmp_files: + self._clean() + io.log_debug('Writing rates to file(s)... done.') def _write_rates_on_rank(self): with h5py.File(self._tmp_rates_path, 'w') as h5: @@ -107,11 +134,12 @@ def _write_rates_on_rank(self): def _combine_rates(self): n_cells = {} if bmtk_world_comm.MPI_rank == 0: - rates_paths = glob.glob(os.path.join(self._tmp_dir, '.rates.*.h5')) + rates_paths = glob.glob(os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.*.h5')) h5_handles = [] timestamps = None for rp in rates_paths: + assert(Path(rp).exists()) rates_h5 = h5py.File(rp, 'r') h5_handles.append(rates_h5) for pop, pop_grp in rates_h5.items(): diff --git a/bmtk/utils/io/ioutils.py b/bmtk/utils/io/ioutils.py index 1471723e2..1d3f2c066 100644 --- a/bmtk/utils/io/ioutils.py +++ b/bmtk/utils/io/ioutils.py @@ -1,6 +1,9 @@ +import uuid + class BMTKWorldComm(object): def __init__(self): self._comm = None + self._global_uuid = None @property def comm(self): @@ -32,6 +35,26 @@ def MPI_size(self): else: return self.comm.Get_size() + def global_uuid(self, default='NA'): + if self._global_uuid is None: + try: + if bmtk_world_comm.MPI_size == 1: + self._global_uuid = str(uuid.uuid4().hex) + else: + if bmtk_world_comm.MPI_rank == 0: + bcast_data = str(uuid.uuid4().hex) + else: + bcast_data = None + + bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0) + self._global_uuid = bcast_data + + except Exception as e: + self._global_uuid = default + + return self._global_uuid + + def barrier(self): if self.comm is not None: self.comm.Barrier()