diff --git a/bmtk/simulator/filternet/modules/record_rates.py b/bmtk/simulator/filternet/modules/record_rates.py index 39da776e..1c9cd485 100644 --- a/bmtk/simulator/filternet/modules/record_rates.py +++ b/bmtk/simulator/filternet/modules/record_rates.py @@ -4,22 +4,18 @@ import h5py import numpy as np import glob -import uuid -from pathlib import Path from .base import SimModule from bmtk.utils.io.ioutils import bmtk_world_comm -from bmtk.simulator.filternet.io_tools import io - +import tempfile class RecordRates(SimModule): def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='node_id', - compression='gzip', clean_temp_files=True): + compression='gzip'): self._tmp_dir = tmp_dir self._csv_file = csv_file if csv_file is None or os.path.isabs(csv_file) else os.path.join(tmp_dir, csv_file) self._save_to_csv = csv_file is not None self._tmp_rates_path = None - self._clean_tmp_files = clean_temp_files h5_file = h5_file if h5_file is None or os.path.isabs(h5_file) else os.path.join(tmp_dir, h5_file) self._save_to_h5 = h5_file is not None @@ -38,24 +34,6 @@ def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='no self._node_ids = {} self._firing_rates = {} self._node_counter = 0 - self._runtime_gid = bmtk_world_comm.global_uuid(default='filternet') # self.get_global_uuid() - - def get_global_uuid(self): - try: - if bmtk_world_comm.MPI_size == 1: - return str(uuid.uuid4().hex) - - if bmtk_world_comm.MPI_rank == 0: - bcast_data = str(uuid.uuid4().hex) - else: - bcast_data = None - - bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0) - return bcast_data - - except Exception as e: - return 'filternet' - def initialize(self, sim): self._node_counter = 0 @@ -75,12 +53,15 @@ def save(self, sim, cell, times, rates): self._node_ids[cell.population][self._node_counter] = cell.node_id self._node_counter += 1 - def finalize(self, sim): - io.log_debug('Writing rates to file(s)...') if bmtk_world_comm.MPI_size > 1: - self._tmp_rates_path = os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.{bmtk_world_comm.MPI_rank}.h5') - + fd, path = tempfile.mkstemp( + prefix='.rates.rank_{}.'.format(bmtk_world_comm.MPI_rank), + suffix='.h5', + dir=self._tmp_dir + ) + self._tmp_rates_path = path + os.close(fd) self._write_rates_on_rank() bmtk_world_comm.barrier() @@ -119,9 +100,7 @@ def finalize(self, sim): csv_writer.writerow([node_id, pop, ts, fr]) bmtk_world_comm.barrier() - if self._clean_tmp_files: - self._clean() - io.log_debug('Writing rates to file(s)... done.') + self._clean() def _write_rates_on_rank(self): with h5py.File(self._tmp_rates_path, 'w') as h5: @@ -133,13 +112,12 @@ def _write_rates_on_rank(self): def _combine_rates(self): n_cells = {} + rates_paths = bmtk_world_comm.gather(self._tmp_rates_path, root=0) if bmtk_world_comm.MPI_rank == 0: - rates_paths = glob.glob(os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.*.h5')) h5_handles = [] timestamps = None for rp in rates_paths: - assert(Path(rp).exists()) rates_h5 = h5py.File(rp, 'r') h5_handles.append(rates_h5) for pop, pop_grp in rates_h5.items(): diff --git a/bmtk/utils/io/ioutils.py b/bmtk/utils/io/ioutils.py index 1d3f2c06..8d6bfe3c 100644 --- a/bmtk/utils/io/ioutils.py +++ b/bmtk/utils/io/ioutils.py @@ -1,9 +1,6 @@ -import uuid - class BMTKWorldComm(object): def __init__(self): self._comm = None - self._global_uuid = None @property def comm(self): @@ -35,30 +32,15 @@ def MPI_size(self): else: return self.comm.Get_size() - def global_uuid(self, default='NA'): - if self._global_uuid is None: - try: - if bmtk_world_comm.MPI_size == 1: - self._global_uuid = str(uuid.uuid4().hex) - else: - if bmtk_world_comm.MPI_rank == 0: - bcast_data = str(uuid.uuid4().hex) - else: - bcast_data = None - - bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0) - self._global_uuid = bcast_data - - except Exception as e: - self._global_uuid = default - - return self._global_uuid - - def barrier(self): if self.comm is not None: self.comm.Barrier() + def gather(self, data, root=0): + if self.comm is not None: + return self.comm.gather(data, root=root) + + bmtk_world_comm = BMTKWorldComm()