From c0982b3464371bef9851a4aafe74685965d5ba62 Mon Sep 17 00:00:00 2001 From: kaeldai Date: Mon, 26 Jan 2026 14:15:31 -0800 Subject: [PATCH] adding converter for networkx multigraph --- bmtk/simulator/core/simulator_network.py | 71 ++++++++++++++++++++++++ bmtk/utils/sonata/population.py | 8 ++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/bmtk/simulator/core/simulator_network.py b/bmtk/simulator/core/simulator_network.py index 1d86f1f30..60b64c0ad 100644 --- a/bmtk/simulator/core/simulator_network.py +++ b/bmtk/simulator/core/simulator_network.py @@ -204,6 +204,77 @@ def build_recurrent_edges(self, **opts): def build_virtual_connections(self): raise NotImplementedError() + def to_multigraph(self, source=None, target=None, edge_attrs=['nsyns'], **filter_attrs): + import networkx as nx + + # Get node population(s) and node_id(s) that match "source" query. If no source filter + # then just leave blank and assume all nodes should be returned. + src_pops = [e._edge_pop.source_population for e in self._edge_populations] + src_node_ids = None + if source is not None: + src_node_set = self.get_node_set(source) + src_pops = src_node_set.population_names() + src_node_ids = src_node_set.node_ids + + # Like above get node population(s)/node_id(s) for the given "target" query. + trg_pops = [e._edge_pop.target_population for e in self._edge_populations] + trg_node_ids = None + if target is not None: + trg_node_set = self.get_node_set(target) + trg_pops = trg_node_set.population_names() + trg_node_ids = trg_node_set.node_ids + + # Iterative through all SONATA edge populations in network + edges_df = pd.DataFrame() + for epop in self._edge_populations: + # Skip if not valid source and target names. + if epop._edge_pop.source_population not in src_pops or epop._edge_pop.target_population not in trg_pops: + continue + + # convert edges to pd dataframe + pop_df = epop._edge_pop.to_dataframe() + if 'nsyns' not in pop_df.columns: + pop_df['nsyns'] = 1 + + # Filter by source and/or target nodes, if applicable. + if src_node_ids: + pop_df = pop_df[pop_df['source_node_id'].isin(src_node_ids)] + + if trg_node_ids: + pop_df = pop_df[pop_df['target_node_id'].isin(trg_node_ids)] + + # Filter by edge attributes, if applicable. + for k, v in filter_attrs.items(): + if pop_df is None or len(pop_df) == 0: + break + + if k not in pop_df.columns: + pop_df = None + elif isinstance(v, (list, tuple)): + pop_df = pop_df[pop_df[k].isin(v)] + else: + pop_df = pop_df[pop_df[k] == v] + + if pop_df is not None: + pop_df = pop_df[['source_population', 'source_node_id', 'target_population', 'target_node_id'] + edge_attrs] + edges_df = pd.concat([edges_df, pop_df], sort=False) + + if edges_df is None or len(edges_df) == 0: + return nx.MultiDiGraph() + + edges_df = edges_df.assign( + source_nodes=lambda df: list(zip(df['source_population'], df['source_node_id'].astype(int))), + target_nodes=lambda df: list(zip(df['target_population'], df['target_node_id'].astype(int))), + ) + + return nx.from_pandas_edgelist( + edges_df, + source='source_nodes', + target='target_nodes', + edge_attr=edge_attrs, + create_using=nx.MultiDiGraph() + ) + @classmethod def from_config(cls, conf, **properties): """Generates a graph structure from a json config file or dictionary. diff --git a/bmtk/utils/sonata/population.py b/bmtk/utils/sonata/population.py index a439d6664..433448976 100644 --- a/bmtk/utils/sonata/population.py +++ b/bmtk/utils/sonata/population.py @@ -410,8 +410,14 @@ def edge_types_table(self): return self._types_table def to_dataframe(self): - raise NotImplementedError() + ret_df = pd.DataFrame() + for grp_id in self.group_ids: + grp_df = self.get_group(grp_id).to_dataframe() + ret_df = pd.concat([ret_df, grp_df], sort=False) + ret_df['source_population'] = self.source_population + ret_df['target_population'] = self.target_population + return ret_df def build_indicies(self): indicies_grp = None