Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bmtk/simulator/filternet/filtersimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ def run(self):
for mod in self._sim_mods:
mod.save(self, cell, ts, f_rates)
# io.log_info('Max firing rate: {}'.format(np.max(max_fr)))
io.log_info('Done.')
for mod in self._sim_mods:
mod.finalize(self)

io.log_info('Done.')

def local_cells(self):
return self._network.cells()

Expand Down
7 changes: 6 additions & 1 deletion bmtk/simulator/filternet/modules/create_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from bmtk.utils.reports.spike_trains import SpikeTrains, pop_na, sort_order, sort_order_lu
from bmtk.simulator.filternet.lgnmodel import poissongeneration as pg
from bmtk.utils.io.ioutils import bmtk_world_comm
from bmtk.simulator.filternet.io_tools import io


class SpikesGenerator(SimModule):
def __init__(self, spikes_file_csv=None, spikes_file=None, spikes_file_nwb=None, tmp_dir='output',
sort_order='node_id', compression='gzip'):
sort_order='node_id', compression='gzip', clean_temp_files=True):
def _get_file_path(file_name):
if file_name is None or os.path.isabs(file_name):
return file_name
Expand All @@ -35,10 +36,12 @@ def _get_file_path(file_name):
self._save_nwb = spikes_file_nwb is not None

self._tmpdir = tmp_dir
self._clean_temp_files = clean_temp_files

# self._spike_writer = SpikeTrainWriter(tmp_dir=tmp_dir)
self._spike_writer = SpikeTrains(cache_dir=tmp_dir)
self._sort_order = sort_order_lu[sort_order]
self._runtime_gid = bmtk_world_comm.global_uuid(default='filternet')

def save(self, sim, cell, times, rates):
try:
Expand All @@ -52,6 +55,7 @@ def save(self, sim, cell, times, rates):
self._spike_writer.add_spikes(node_ids=cell.gid, timestamps=spike_trains, population=cell.population)

def finalize(self, sim):
io.log_debug('Writing spikes to file(s)...')
self._spike_writer.flush()

if self._save_csv:
Expand All @@ -64,6 +68,7 @@ def finalize(self, sim):
self._spike_writer.to_nwb(self._nwb_fname, sort_order=self._sort_order)

self._spike_writer.close()
io.log_debug('Writing spikes to file(s)... done.')


def f_rate_to_spike_train(t, f_rate, random_seed, t_window_start, t_window_end, p_spike_max):
Expand Down
36 changes: 32 additions & 4 deletions bmtk/simulator/filternet/modules/record_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,22 @@
import h5py
import numpy as np
import glob
import uuid
from pathlib import Path

from .base import SimModule
from bmtk.utils.io.ioutils import bmtk_world_comm
from bmtk.simulator.filternet.io_tools import io


class RecordRates(SimModule):
def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='node_id',
compression='gzip'):
compression='gzip', clean_temp_files=True):
self._tmp_dir = tmp_dir
self._csv_file = csv_file if csv_file is None or os.path.isabs(csv_file) else os.path.join(tmp_dir, csv_file)
self._save_to_csv = csv_file is not None
self._tmp_rates_path = None
self._clean_tmp_files = clean_temp_files

h5_file = h5_file if h5_file is None or os.path.isabs(h5_file) else os.path.join(tmp_dir, h5_file)
self._save_to_h5 = h5_file is not None
Expand All @@ -34,6 +38,24 @@ def __init__(self, csv_file=None, h5_file=None, tmp_dir='output', sort_order='no
self._node_ids = {}
self._firing_rates = {}
self._node_counter = 0
self._runtime_gid = bmtk_world_comm.global_uuid(default='filternet') # self.get_global_uuid()

def get_global_uuid(self):
try:
if bmtk_world_comm.MPI_size == 1:
return str(uuid.uuid4().hex)

if bmtk_world_comm.MPI_rank == 0:
bcast_data = str(uuid.uuid4().hex)
else:
bcast_data = None

bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0)
return bcast_data

except Exception as e:
return 'filternet'


def initialize(self, sim):
self._node_counter = 0
Expand All @@ -53,9 +75,12 @@ def save(self, sim, cell, times, rates):
self._node_ids[cell.population][self._node_counter] = cell.node_id
self._node_counter += 1


def finalize(self, sim):
io.log_debug('Writing rates to file(s)...')
if bmtk_world_comm.MPI_size > 1:
self._tmp_rates_path = os.path.join(self._tmp_dir, '.rates.{}.h5'.format(bmtk_world_comm.MPI_rank))
self._tmp_rates_path = os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.{bmtk_world_comm.MPI_rank}.h5')

self._write_rates_on_rank()
bmtk_world_comm.barrier()

Expand Down Expand Up @@ -94,7 +119,9 @@ def finalize(self, sim):
csv_writer.writerow([node_id, pop, ts, fr])

bmtk_world_comm.barrier()
self._clean()
if self._clean_tmp_files:
self._clean()
io.log_debug('Writing rates to file(s)... done.')

def _write_rates_on_rank(self):
with h5py.File(self._tmp_rates_path, 'w') as h5:
Expand All @@ -107,11 +134,12 @@ def _write_rates_on_rank(self):
def _combine_rates(self):
n_cells = {}
if bmtk_world_comm.MPI_rank == 0:
rates_paths = glob.glob(os.path.join(self._tmp_dir, '.rates.*.h5'))
rates_paths = glob.glob(os.path.join(self._tmp_dir, f'.rates.{self._runtime_gid}.*.h5'))
h5_handles = []
timestamps = None

for rp in rates_paths:
assert(Path(rp).exists())
rates_h5 = h5py.File(rp, 'r')
h5_handles.append(rates_h5)
for pop, pop_grp in rates_h5.items():
Expand Down
23 changes: 23 additions & 0 deletions bmtk/utils/io/ioutils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import uuid

class BMTKWorldComm(object):
def __init__(self):
self._comm = None
self._global_uuid = None

@property
def comm(self):
Expand Down Expand Up @@ -32,6 +35,26 @@ def MPI_size(self):
else:
return self.comm.Get_size()

def global_uuid(self, default='NA'):
if self._global_uuid is None:
try:
if bmtk_world_comm.MPI_size == 1:
self._global_uuid = str(uuid.uuid4().hex)
else:
if bmtk_world_comm.MPI_rank == 0:
bcast_data = str(uuid.uuid4().hex)
else:
bcast_data = None

bcast_data = bmtk_world_comm.comm.bcast(bcast_data, root=0)
self._global_uuid = bcast_data

except Exception as e:
self._global_uuid = default

return self._global_uuid


def barrier(self):
if self.comm is not None:
self.comm.Barrier()
Expand Down
Loading