diff --git a/src/power_grid_model_ds/_core/model/graphs/models/base.py b/src/power_grid_model_ds/_core/model/graphs/models/base.py index 6928e4b7..1d3263d2 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/base.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/base.py @@ -221,16 +221,21 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator: considering certain nodes. """ edge_list = [] - for node in nodes: - edge_list += list(self.in_branches(node)) - self.delete_node(node) + node_list = [] - yield + try: + for node in nodes: + edge_list += list(self.in_branches(node)) + + self.delete_node(node) + node_list.append(node) - for node in nodes: - self.add_node(int(node)) # convert to int to avoid type issues when input is e.g. a numpy array - for source, target in edge_list: - self.add_branch(source, target) + yield + finally: + for node in node_list: + self.add_node(int(node)) # convert to int to avoid type issues when input is e.g. a numpy array + for source, target in edge_list: + self.add_branch(source, target) @contextmanager def tmp_remove_branches(self, branches: list[tuple[int, int]]) -> Generator: diff --git a/tests/unit/model/graphs/test_graph_model.py b/tests/unit/model/graphs/test_graph_model.py index 5ce35314..6813f28f 100644 --- a/tests/unit/model/graphs/test_graph_model.py +++ b/tests/unit/model/graphs/test_graph_model.py @@ -143,39 +143,53 @@ def test_graph_in_branches(self, graph: BaseGraphModel): assert list(graph.in_branches(2)) == [(1, 2), (1, 2), (1, 2)] -def test_tmp_remove_nodes(graph_with_2_routes: BaseGraphModel) -> None: - graph = graph_with_2_routes +class TestTmpRemoveNodes: + def test_tmp_remove_nodes(self, graph_with_2_routes: BaseGraphModel) -> None: + graph = graph_with_2_routes + + assert graph.nr_branches == 4 - assert graph.nr_branches == 4 + # add parallel branches to test whether they are restored correctly + graph.add_branch(1, 5) + graph.add_branch(5, 1) - # add parallel branches to test whether they are restored correctly - graph.add_branch(1, 5) - graph.add_branch(5, 1) + assert graph.nr_nodes == 5 + assert graph.nr_branches == 6 - assert graph.nr_nodes == 5 - assert graph.nr_branches == 6 + before_sets = [frozenset(branch) for branch in graph.all_branches] + counter_before = Counter(before_sets) - before_sets = [frozenset(branch) for branch in graph.all_branches] - counter_before = Counter(before_sets) + with graph.tmp_remove_nodes([1, 2]): + assert graph.nr_nodes == 3 + assert list(graph.all_branches) == [(5, 4)] - with graph.tmp_remove_nodes([1, 2]): - assert graph.nr_nodes == 3 - assert list(graph.all_branches) == [(5, 4)] + assert graph.nr_nodes == 5 + assert graph.nr_branches == 6 - assert graph.nr_nodes == 5 - assert graph.nr_branches == 6 + after_sets = [frozenset(branch) for branch in graph.all_branches] + counter_after = Counter(after_sets) + assert counter_before == counter_after - after_sets = [frozenset(branch) for branch in graph.all_branches] - counter_after = Counter(after_sets) - assert counter_before == counter_after + def test_tmp_remove_nodes_array_input(self, graph_with_2_routes: BaseGraphModel) -> None: + with graph_with_2_routes.tmp_remove_nodes(np.array([1, 2])): # type: ignore[arg-type] + pass + + # check that the external ids are still all integers instead of e.g. np.int + assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids]) + def test_invalid_tmp_remove_nodes(self, graph_with_2_routes: BaseGraphModel) -> None: + original_graph = deepcopy(graph_with_2_routes) + assert graph_with_2_routes.nr_nodes == 5 + assert graph_with_2_routes.nr_branches == 4 -def test_tmp_remove_nodes_array_input(graph_with_2_routes: BaseGraphModel) -> None: - with graph_with_2_routes.tmp_remove_nodes(np.array([1, 2])): # type: ignore[arg-type] - pass + # When we remove node 1 and then an non-existing node that crashes the process + with pytest.raises(MissingNodeError), graph_with_2_routes.tmp_remove_nodes([1, 99]): + pass - # check that the external ids are still all integers instead of e.g. np.int - assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids]) + # The remaining graph object should still contain the same nodes and edges. + assert graph_with_2_routes.nr_nodes == 5 + assert graph_with_2_routes.nr_branches == 4 + assert graph_with_2_routes == original_graph class TestTmpRemoveBranches: