Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 11 additions & 33 deletions bmtk/simulator/filternet/modules/record_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,18 @@
import h5py
import numpy as np
import glob
import uuid
from pathlib import Path

from .base import SimModule
from bmtk.utils.io.ioutils import bmtk_world_comm
from bmtk.simulator.filternet.io_tools import io

import tempfile

class RecordRates(SimModule):
def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='node_id',
compression='gzip', clean_temp_files=True):
compression='gzip'):
self._tmp_dir = tmp_dir
self._csv_file = csv_file if csv_file is None or os.path.isabs(csv_file) else os.path.join(tmp_dir, csv_file)
self._save_to_csv = csv_file is not None
self._tmp_rates_path = None
self._clean_tmp_files = clean_temp_files

h5_file = h5_file if h5_file is None or os.path.isabs(h5_file) else os.path.join(tmp_dir, h5_file)
self._save_to_h5 = h5_file is not None
Expand All @@ -38,24 +34,6 @@ def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='no
self._node_ids = {}
self._firing_rates = {}
self._node_counter = 0
self._runtime_gid = bmtk_world_comm.global_uuid(default='filternet') # self.get_global_uuid()

def get_global_uuid(self):
try:
if bmtk_world_comm.MPI_size == 1:
return str(uuid.uuid4().hex)

if bmtk_world_comm.MPI_rank == 0:
bcast_data = str(uuid.uuid4().hex)
else:
bcast_data = None

bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0)
return bcast_data

except Exception as e:
return 'filternet'


def initialize(self, sim):
self._node_counter = 0
Expand All @@ -75,12 +53,15 @@ def save(self, sim, cell, times, rates):
self._node_ids[cell.population][self._node_counter] = cell.node_id
self._node_counter += 1


def finalize(self, sim):
io.log_debug('Writing rates to file(s)...')
if bmtk_world_comm.MPI_size > 1:
self._tmp_rates_path = os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.{bmtk_world_comm.MPI_rank}.h5')

fd, path = tempfile.mkstemp(
prefix='.rates.rank_{}.'.format(bmtk_world_comm.MPI_rank),
suffix='.h5',
dir=self._tmp_dir
)
self._tmp_rates_path = path
os.close(fd)
self._write_rates_on_rank()
bmtk_world_comm.barrier()

Expand Down Expand Up @@ -119,9 +100,7 @@ def finalize(self, sim):
csv_writer.writerow([node_id, pop, ts, fr])

bmtk_world_comm.barrier()
if self._clean_tmp_files:
self._clean()
io.log_debug('Writing rates to file(s)... done.')
self._clean()

def _write_rates_on_rank(self):
with h5py.File(self._tmp_rates_path, 'w') as h5:
Expand All @@ -133,13 +112,12 @@ def _write_rates_on_rank(self):

def _combine_rates(self):
n_cells = {}
rates_paths = bmtk_world_comm.gather(self._tmp_rates_path, root=0)
if bmtk_world_comm.MPI_rank == 0:
rates_paths = glob.glob(os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.*.h5'))
h5_handles = []
timestamps = None

for rp in rates_paths:
assert(Path(rp).exists())
rates_h5 = h5py.File(rp, 'r')
h5_handles.append(rates_h5)
for pop, pop_grp in rates_h5.items():
Expand Down
28 changes: 5 additions & 23 deletions bmtk/utils/io/ioutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import uuid

class BMTKWorldComm(object):
def __init__(self):
self._comm = None
self._global_uuid = None

@property
def comm(self):
Expand Down Expand Up @@ -35,30 +32,15 @@ def MPI_size(self):
else:
return self.comm.Get_size()

def global_uuid(self, default='NA'):
if self._global_uuid is None:
try:
if bmtk_world_comm.MPI_size == 1:
self._global_uuid = str(uuid.uuid4().hex)
else:
if bmtk_world_comm.MPI_rank == 0:
bcast_data = str(uuid.uuid4().hex)
else:
bcast_data = None

bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0)
self._global_uuid = bcast_data

except Exception as e:
self._global_uuid = default

return self._global_uuid


def barrier(self):
if self.comm is not None:
self.comm.Barrier()

def gather(self, data, root=0):
if self.comm is not None:
return self.comm.gather(data, root=root)



bmtk_world_comm = BMTKWorldComm()

Expand Down
Loading