Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/neuralnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,13 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
pub fn remove_cycles(&mut self) {
let mut visited = HashMap::new();
let mut edges_to_remove: HashSet<Connection> = HashSet::new();
let mut path = Vec::new();

for i in 0..I {
self.remove_cycles_dfs(
&mut visited,
&mut edges_to_remove,
None,
&mut path,
NeuronLocation::Input(i),
);
}
Expand All @@ -693,7 +694,7 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
for i in 0..self.hidden_layers.len() {
let loc = NeuronLocation::Hidden(i);
if !visited.contains_key(&loc) {
self.remove_cycles_dfs(&mut visited, &mut edges_to_remove, None, loc);
self.remove_cycles_dfs(&mut visited, &mut edges_to_remove, &mut path, loc);
}
}

Expand All @@ -704,22 +705,23 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
}
}

// colored dfs
// colored dfs using an explicit path stack so the back-edge (parent → current)
// is always identified correctly, regardless of HashMap iteration order.
fn remove_cycles_dfs(
&mut self,
visited: &mut HashMap<NeuronLocation, u8>,
edges_to_remove: &mut HashSet<Connection>,
prev: Option<NeuronLocation>,
path: &mut Vec<NeuronLocation>,
current: NeuronLocation,
) {
if let Some(&existing) = visited.get(&current) {
if existing == 0 {
// part of current dfs - found a cycle
// prev must exist here since visited would be empty on first call.
let prev = prev.unwrap();
if self[prev].outputs.contains_key(&current) {
// part of current dfs path - found a cycle.
// path.last() is the node that just tried to visit `current`,
// so path.last() → current is the back-edge to remove.
if let Some(&parent) = path.last() {
edges_to_remove.insert(Connection {
from: prev,
from: parent,
to: current,
});
}
Expand All @@ -730,12 +732,14 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
}

visited.insert(current, 0);
path.push(current);

let outputs = self[current].outputs.keys().cloned().collect::<Vec<_>>();
for loc in outputs {
self.remove_cycles_dfs(visited, edges_to_remove, Some(current), loc);
self.remove_cycles_dfs(visited, edges_to_remove, path, loc);
}

path.pop();
visited.insert(current, 1);
}

Expand Down
Loading