Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions arc/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A module for plotting and saving output files such as RMG libraries.
"""

import datetime
import matplotlib
# Force matplotlib to not use any Xwindows backend.
# This must be called before pylab, matplotlib.pyplot, or matplotlib.backends is imported.
Expand All @@ -12,10 +13,15 @@
import numpy as np
import os
import shutil
import textwrap
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.mplot3d import Axes3D
from typing import List, Optional, Tuple, Union

try:
import graphviz
except ImportError:
graphviz = None
import py3Dmol as p3D
from rdkit import Chem

Expand Down Expand Up @@ -54,6 +60,168 @@
# Logger used by the functions in this module (e.g., for the provenance rendering warnings below).
logger = get_logger()


def _sanitize_graphviz_id(value: str) -> str:
"""Return a Graphviz-safe identifier."""
return ''.join(ch if ch.isalnum() else '_' for ch in value)


def _wrap_graph_label(text: str, width: int = 24) -> str:
"""Wrap long labels so graph nodes stay readable, preserving intentional newlines."""
if not text:
return ''
return '\n'.join(line for part in str(text).split('\n')
for line in (textwrap.wrap(part, width=width) or ['']))


def _build_provenance_graph(provenance: dict, run_label: str) -> 'graphviz.Digraph':
    """
    Build the Graphviz directed graph describing an ARC run from its provenance events.

    Must only be called when the ``graphviz`` package was imported successfully.

    Args:
        provenance (dict): A provenance dictionary with an ``events`` list.
        run_label (str): The text shown in the run (root) node.

    Returns:
        graphviz.Digraph: The populated graph (not yet written or rendered).
    """
    graph = graphviz.Digraph(
        name='arc_provenance',
        comment=f'ARC provenance for {run_label}',
        graph_attr={'rankdir': 'LR', 'splines': 'true', 'overlap': 'false'},
        node_attr={'shape': 'box', 'style': 'rounded,filled', 'fillcolor': 'white', 'fontname': 'Helvetica'},
        edge_attr={'fontname': 'Helvetica'},
    )
    run_node_id = _sanitize_graphviz_id(f"run_{provenance.get('run_id', run_label)}")
    run_header = provenance.get('started_at', '')
    run_footer = provenance.get('ended_at', '')
    run_text = f'{run_label}'
    if run_header:
        run_text += f'\nstart: {run_header}'
    if run_footer:
        run_text += f'\nend: {run_footer}'
    graph.node(run_node_id, _wrap_graph_label(run_text, width=32), shape='oval', fillcolor='lightgoldenrod1')

    species_nodes = dict()  # species label -> node id
    job_nodes = dict()  # job key -> node id
    # Track the most recent decision node (troubleshoot / TS selection) per label,
    # so that follow-up jobs spawned by that decision connect from the diamond.
    last_decision_by_label = dict()

    for index, event in enumerate(provenance.get('events', list())):
        event_type = event.get('event_type', '')
        label = event.get('label')
        # Events without an ``event_id`` fall back to their position in the list,
        # so the decision / result node ids derived below remain unique.
        event_token = event.get('event_id', f'n{index}')

        # The first event mentioning a label creates its species node, whatever the event type.
        if label and label not in species_nodes:
            species_node_id = _sanitize_graphviz_id(f'species_{label}')
            species_text = label
            if event.get('is_ts'):
                species_text += '\nTS'
            graph.node(species_node_id, _wrap_graph_label(species_text), fillcolor='aliceblue')
            graph.edge(run_node_id, species_node_id)
            species_nodes[label] = species_node_id

        if event_type == 'job_started':
            job_key = event.get('job_key', event.get('job_name', 'job'))
            job_node_id = _sanitize_graphviz_id(f'job_{job_key}')
            job_text = f"{event.get('job_type', 'job')}\n{event.get('job_name', job_key)}"
            if event.get('job_adapter'):
                job_text += f"\n{event['job_adapter']}"
            if event.get('level'):
                job_text += f"\n{event['level']}"
            graph.node(job_node_id, _wrap_graph_label(job_text), fillcolor='white')

            # Determine the source node for this job's incoming edge.
            parent_job = event.get('provenance_parent_job')
            reason = event.get('provenance_reason', '')
            if parent_job and label in last_decision_by_label:
                # A decision (troubleshoot / TS selection) preceded this job — connect from it.
                source_node_id = last_decision_by_label.pop(label)
            elif parent_job:
                # Rerun or other child job — connect from the parent job node.
                # Jobs are stored under ``job_key`` (expected as 'label:job_name'), or under the bare
                # job name when the event had no ``job_key``, so both spellings are tried.
                source_node_id = job_nodes.get(f'{label}:{parent_job}') \
                    or job_nodes.get(parent_job) \
                    or species_nodes.get(label, run_node_id)
            else:
                # Normal first-launch job — connect from the species node.
                source_node_id = species_nodes.get(label, run_node_id)
            # Only label the edge when there is a reason, otherwise an empty ``label=""`` attribute is emitted.
            graph.edge(source_node_id, job_node_id, label=reason or None)
            job_nodes[job_key] = job_node_id

        elif event_type == 'job_finished':
            job_key = event.get('job_key')
            if job_key in job_nodes:
                status = event.get('status', 'unknown')
                fillcolor = {'done': 'honeydew', 'errored': 'mistyrose'}.get(status, 'lightyellow')
                # Re-declaring an existing node only updates its attributes (here, the colour).
                graph.node(job_nodes[job_key], fillcolor=fillcolor)

                result_node_id = _sanitize_graphviz_id(f'result_{event_token}_{job_key}')
                result_text = f"{status}"
                if event.get('run_time'):
                    result_text += f"\n{event['run_time']}"
                if event.get('keywords'):
                    result_text += f"\n{', '.join(str(keyword) for keyword in event['keywords'])}"
                graph.node(result_node_id, _wrap_graph_label(result_text), shape='note', fillcolor='cornsilk')
                graph.edge(job_nodes[job_key], result_node_id)

        elif event_type in ('ts_guess_selected', 'ts_guess_selection_failed', 'job_troubleshooting'):
            decision_node_id = _sanitize_graphviz_id(f'decision_{event_token}')
            if event_type == 'ts_guess_selected':
                decision_text = f"Select TS guess {event.get('selected_index')}"
                if event.get('method'):
                    decision_text += f"\n{event['method']}"
                fillcolor = 'lavender'
            elif event_type == 'ts_guess_selection_failed':
                decision_text = 'TS guess selection\nfailed'
                fillcolor = 'mistyrose'
            else:
                decision_text = f"Troubleshoot {event.get('job_name', '')}"
                if event.get('methods'):
                    decision_text += f"\n{', '.join(str(method) for method in event['methods'])}"
                fillcolor = 'moccasin'
            graph.node(decision_node_id, _wrap_graph_label(decision_text), shape='diamond', fillcolor=fillcolor)

            # Connect from the job that triggered the decision if it is known, else from the species node.
            source_job_key = event.get('job_key')
            source_node_id = job_nodes.get(source_job_key) if source_job_key else None
            if source_node_id is None:
                source_node_id = species_nodes.get(label)
            if source_node_id is not None:
                graph.edge(source_node_id, decision_node_id)
            if label is not None:
                last_decision_by_label[label] = decision_node_id

    return graph


def save_provenance_artifacts(project_directory: str,
                              provenance: dict,
                              ) -> dict:
    """
    Save provenance YAML and render Graphviz artifacts for an ARC run.

    The YAML file is always written, and it is written first so that a failure while building
    or rendering the graph cannot prevent it from being saved. The DOT file requires the
    ``graphviz`` Python package, and the SVG file additionally requires a working Graphviz executable.

    Note that ``provenance`` is modified in place: its ``updated_at`` entry is set to the current time.

    Args:
        project_directory (str): The ARC project directory.
        provenance (dict): A provenance dictionary with an ``events`` list.

    Returns:
        dict: Paths to generated artifacts, keys are ``'yml'``, ``'dot'``, and ``'svg'``.
              The value is ``None`` for an artifact that was not generated by this call.
    """
    output_directory = os.path.join(project_directory, 'output')
    os.makedirs(output_directory, exist_ok=True)
    yml_path = os.path.join(output_directory, 'provenance.yml')
    dot_path = os.path.join(output_directory, 'provenance.dot')
    svg_path = os.path.join(output_directory, 'provenance.svg')

    provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds')
    save_yaml_file(path=yml_path, content=provenance)

    if graphviz is None:
        logger.warning('The graphviz Python package is not available, so ARC will only save provenance.yml.')
        return {'yml': yml_path, 'dot': None, 'svg': None}

    run_label = provenance.get('project', 'ARC run')
    graph = _build_provenance_graph(provenance=provenance, run_label=run_label)

    with open(dot_path, 'w', encoding='utf-8') as f:
        f.write(graph.source)

    # Track success explicitly: an SVG left on disk by a previous run must not be reported as a fresh artifact.
    svg_written = False
    try:
        svg_data = graph.pipe(format='svg')
    except graphviz.ExecutableNotFound:
        logger.warning('Could not render ARC provenance SVG because Graphviz is not available on this system.')
    except graphviz.CalledProcessError as e:
        logger.warning(f'Could not render ARC provenance SVG because the Graphviz renderer failed:\n{e}')
    else:
        with open(svg_path, 'wb') as f:
            f.write(svg_data)
        svg_written = True

    return {'yml': yml_path, 'dot': dot_path, 'svg': svg_path if svg_written else None}


# *** Drawing species ***

def draw_structure(xyz=None, species=None, project_directory=None, method='show_sticks', show_atom_indices=False):
Expand Down
82 changes: 82 additions & 0 deletions arc/plotter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,88 @@ def test_save_irc_traj_animation(self):
plotter.save_irc_traj_animation(irc_f_path, irc_r_path, out_path)
self.assertTrue(os.path.isfile(out_path))

def test_wrap_graph_label(self):
    """Test newline preservation, wrapping of long lines, and empty input in _wrap_graph_label()."""
    # Explicit line breaks must come out as separate lines, in the original order.
    multi_line_label = "opt\nopt_a1\ngaussian\nwb97xd/def2tzvp"
    wrapped = plotter._wrap_graph_label(multi_line_label, width=30).split('\n')
    for position, expected_line in enumerate(['opt', 'opt_a1', 'gaussian', 'wb97xd/def2tzvp']):
        self.assertEqual(wrapped[position], expected_line)

    # A single long line is broken so that no output line exceeds the requested width.
    long_label = "this is a very long label that should be wrapped"
    wrapped = plotter._wrap_graph_label(long_label, width=20).split('\n')
    self.assertTrue(not any(len(line) > 20 for line in wrapped))

    # An empty label stays empty.
    self.assertEqual(plotter._wrap_graph_label(''), '')

def test_save_provenance_artifacts(self):
    """Test that save_provenance_artifacts() writes the YAML file and, if available, a correct DOT graph."""
    project = 'arc_project_for_testing_delete_after_usage'
    project_directory = os.path.join(ARC_PATH, 'Projects', project)

    def make_event(event_id, event_type, timestamp, label, **extra):
        """Assemble a single provenance event; the common keys come first, then the specific ones."""
        return {'event_id': event_id, 'event_type': event_type, 'timestamp': timestamp, 'label': label, **extra}

    # A small run: spc1 gets an opt job, a failed freq job which is troubleshot and rerun,
    # and TS0 gets a TS guess job followed by a guess selection.
    events = [
        make_event(1, 'species_initialized', '2026-03-15T10:00:00', 'spc1'),
        make_event(2, 'species_initialized', '2026-03-15T10:00:00', 'TS0', is_ts=True),
        make_event(3, 'job_started', '2026-03-15T10:00:01', 'spc1',
                   job_key='spc1:opt_a1', job_name='opt_a1', job_type='opt',
                   job_adapter='gaussian', level='b3lyp/6-31g(d)'),
        make_event(4, 'job_finished', '2026-03-15T10:01:00', 'spc1',
                   job_key='spc1:opt_a1', status='done', run_time='0:01:00'),
        make_event(5, 'job_started', '2026-03-15T10:01:01', 'spc1',
                   job_key='spc1:freq_a2', job_name='freq_a2', job_type='freq',
                   job_adapter='gaussian', level='b3lyp/6-31g(d)'),
        make_event(6, 'job_finished', '2026-03-15T10:01:30', 'spc1',
                   job_key='spc1:freq_a2', status='errored',
                   run_time='0:00:30', keywords=['memory']),
        make_event(7, 'job_troubleshooting', '2026-03-15T10:01:35', 'spc1',
                   job_key='spc1:freq_a2', job_name='freq_a2', job_type='freq',
                   methods=['memory']),
        make_event(8, 'job_started', '2026-03-15T10:01:40', 'spc1',
                   job_key='spc1:freq_a3', job_name='freq_a3', job_type='freq',
                   job_adapter='gaussian', provenance_parent_job='freq_a2',
                   provenance_reason='ess_troubleshoot'),
        make_event(9, 'job_finished', '2026-03-15T10:02:00', 'spc1',
                   job_key='spc1:freq_a3', status='done', run_time='0:00:20'),
        make_event(10, 'job_started', '2026-03-15T10:02:01', 'TS0',
                   job_key='TS0:tsg0', job_name='tsg0', job_type='tsg',
                   job_adapter='autotst'),
        make_event(11, 'job_finished', '2026-03-15T10:03:00', 'TS0',
                   job_key='TS0:tsg0', status='done'),
        make_event(12, 'ts_guess_selected', '2026-03-15T10:03:01', 'TS0',
                   selected_index=0, method='autotst', energy=-154.321),
    ]
    provenance = {
        'project': project,
        'run_id': 'run_1',
        'started_at': '2026-03-15T10:00:00',
        'ended_at': '2026-03-15T10:05:00',
        'events': events,
    }

    paths = plotter.save_provenance_artifacts(project_directory=project_directory, provenance=provenance)
    self.assertTrue(os.path.isfile(paths['yml']))

    # The DOT file is only generated when the graphviz Python package is installed.
    if paths['dot'] is None:
        return
    self.assertTrue(os.path.isfile(paths['dot']))
    with open(paths['dot'], 'r') as f:
        dot = f.read()

    expected_fragments = [
        # Species and job nodes are present.
        'spc1',
        'opt_a1',
        'TS0',
        # Troubleshoot diamond and edge label rendered.
        'Troubleshoot',
        'ess_troubleshoot',
        # TS guess selection diamond rendered.
        'Select TS guess 0',
        'autotst',
        # Errored job node coloured correctly.
        'mistyrose',
        # Normal jobs (opt_a1, freq_a2) connect from the species node, not from each other.
        'species_spc1 -> job_spc1_opt_a1',
        'species_spc1 -> job_spc1_freq_a2',
        # Troubleshoot follow-up connects from the decision diamond, not the species node.
        'decision_7 -> job_spc1_freq_a3',
    ]
    for fragment in expected_fragments:
        self.assertIn(fragment, dot)


@classmethod
def tearDownClass(cls):
Expand Down
Loading
Loading