-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresults.py
More file actions
32 lines (24 loc) · 930 Bytes
/
results.py
File metadata and controls
32 lines (24 loc) · 930 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from dataclasses import dataclass
import numpy as np
from base import Model
@dataclass
class Snapshot:
    """Record of the five model value arrays captured together.

    The field names match the keys ``'x'``, ``'z'``, ``'p'``, ``'fixed'``
    and ``'free'`` that ``SnapshotManager.take_snapshot`` requests from the
    model. Field order is part of the positional constructor signature.
    """

    # Values the model reported under the 'x', 'z' and 'p' keys.
    x: np.ndarray
    z: np.ndarray
    p: np.ndarray

    # Values the model reported under the 'fixed' and 'free' keys.
    fixed: np.ndarray
    free: np.ndarray
class SnapshotManager:
    """Keep named snapshots of a model's variable values and restore them.

    Snapshots live in ``self.snapshots``, keyed by the name given to
    ``take_snapshot``. Array values are copied both when a snapshot is taken
    and when it is restored, so a stored snapshot never shares memory with
    the arrays the model is working on.
    """

    # Keys requested from the model; they coincide with the Snapshot fields.
    _FIELDS = ('x', 'z', 'p', 'fixed', 'free')
    # Fields that are pushed back through ``Model.load_var_values`` on rewind.
    _LOADED_FIELDS = ('x', 'z', 'p')

    def __init__(self, model: Model):
        """Bind the manager to ``model`` and start with no snapshots."""
        self.model = model
        self.snapshots: dict[str, Snapshot] = {}

    @staticmethod
    def _detached(value):
        """Return a copy of ``value`` if it is an ndarray, else ``value`` itself."""
        # Defensive: guards against the model mutating its arrays in place,
        # which would otherwise silently alter the stored snapshot.
        return value.copy() if isinstance(value, np.ndarray) else value

    def take_snapshot(self, name: str) -> None:
        """Capture the model's current values and store them under ``name``.

        An existing snapshot with the same name is overwritten.
        """
        # The model is handed a list, exactly as before.
        collected = self.model.collect_var_val(list(self._FIELDS))
        self.snapshots[name] = Snapshot(
            **{field: self._detached(collected[field]) for field in self._FIELDS}
        )

    def rewind2snapshot(self, name: str) -> None:
        """Restore the model to the snapshot stored under ``name``.

        ``x``, ``z`` and ``p`` are loaded through ``Model.load_var_values``;
        afterwards ``model.vars_val`` is rebuilt with the snapshot value for
        each key it currently holds.

        Raises:
            KeyError: no snapshot was stored under ``name``.
            AttributeError: ``model.vars_val`` has a key that is not a
                ``Snapshot`` field.
        """
        stored = self.snapshots[name]
        # Work on a private copy so the stored snapshot stays pristine and the
        # same name can be rewound to any number of times.
        restored = Snapshot(
            **{field: self._detached(getattr(stored, field))
               for field in self._FIELDS}
        )
        for field in self._LOADED_FIELDS:
            self.model.load_var_values(getattr(restored, field), field)
        # Rebuilt only after the loads above, in case load_var_values itself
        # touches vars_val; the same array objects are reused so the cache and
        # the loaded values refer to one another as they did originally.
        self.model.vars_val = {
            key: getattr(restored, key) for key in self.model.vars_val
        }