-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotting.py
More file actions
364 lines (285 loc) · 11.5 KB
/
plotting.py
File metadata and controls
364 lines (285 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
## Functions used for plots. Not all are used in the final project.
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import pandas as pd
def plot_event_distribution(df, df_histogram, start, end, xaxis='time', window_size=100, filename=None, show=True,
                            output_dir='/Users/josephine/Library/CloudStorage/OneDrive-DanmarksTekniskeUniversitet/ThorDavis/histogrammer'):
    """
    Plots the event distribution over time or frame for a single cluster.

    The plotted window is [start - margin, end + margin], where the margin is
    0.5 for xaxis='time' and 2 otherwise. The returned statistics are computed
    on [start, end] only (margin excluded).

    Parameters:
        df (pd.DataFrame): Not used here; kept for call compatibility.
        df_histogram (pd.DataFrame): DataFrame with a 'count' column and a
            'timestamps_sec' column (xaxis='time') or 'frames' column (otherwise).
        start (float): Start of the cluster, in the units of the chosen x column.
        end (float): End of the cluster, in the units of the chosen x column.
        xaxis (str): 'time' selects 'timestamps_sec'; any other value selects 'frames'.
        window_size (int): Not used; kept for call compatibility (no moving
            average is drawn by this function).
        filename (str): Optional name fragment. When given, the figure is written to
            '<output_dir>/start_<start + 0.2>_<filename>.pdf'.
        show (bool): If True the figure is shown, otherwise it is closed.
        output_dir (str): Directory the PDF is written to. Defaults to the
            directory that was previously hard-coded.

    Returns:
        tuple: (total_events, max_event) -- the sum and the maximum of 'count'
        for rows whose x value lies in [start, end].
    """
    if xaxis == 'time':
        x_column, margin, xlabel = 'timestamps_sec', 0.5, "Time [s]"
    else:
        x_column, margin, xlabel = 'frames', 2, "Frame"

    # Constant shift applied to every x coordinate that is drawn (data, shaded
    # span, axis limits) and to the saved file name.
    # NOTE(review): the origin of the 0.2 offset is not visible here -- confirm.
    x_shift = 0.2

    x_values = df_histogram[x_column]

    # Rows to draw: cluster interval plus the margin on both sides.
    in_plot_window = (x_values >= start - margin) & (x_values <= end + margin)
    df_plot = (
        df_histogram[in_plot_window]
        .dropna(subset=[x_column, 'count'])
        .sort_values(by=x_column)
    )

    # Rows for the statistics: cluster interval only (margin excluded).
    in_cluster = (x_values >= start) & (x_values <= end)
    df_stats = df_histogram[in_cluster].dropna(subset=[x_column, 'count'])
    total_events = df_stats['count'].sum()
    max_event = df_stats['count'].max()

    # The reference line is the mean over the WHOLE histogram, not the cluster.
    avg_count = df_histogram['count'].mean()

    shifted_x = df_plot[x_column] + x_shift

    plt.figure(figsize=(8, 5))
    # Faint connecting line first, then the markers on top (only the markers get a legend entry).
    plt.plot(shifted_x, df_plot['count'], linestyle='-', color='#1d468b', alpha=0.4)
    plt.plot(shifted_x, df_plot['count'], linestyle='None', marker='o', color='#1d468b', label='Event Counts', alpha=0.7)
    plt.axhline(y=avg_count, color='#e89e42', linestyle='--', linewidth=2, label="Average")
    plt.axvspan(start + x_shift, end + x_shift, color='#e89e42', alpha=0.4, label='Cluster Period')
    plt.xlabel(xlabel, fontsize=14)
    plt.xlim(start + x_shift - margin, end + x_shift + margin)
    plt.ylabel("ON-Events")
    plt.grid(True)
    plt.legend(loc='upper right')

    if filename:
        save_path = f'{output_dir}/start_{start + x_shift}_{filename}.pdf'
        plt.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
        # Report the path that was actually written.
        print(f"Plot saved as {save_path}")

    if show:
        plt.show()
    else:
        plt.close()

    return total_events, max_event
def plot_cluster_locations(df, title="Cluster Locations", filename=None, show=True):
    """
    Plots cluster locations as a scatter plot, one colour per row of ``df``.

    Parameters:
        df (pd.DataFrame): One row per cluster with columns 'x', 'y' and
            'mean time'. 'x' and 'y' are passed straight to ``plt.scatter``
            (so each cell may be a scalar or a sequence); 'mean time' is
            formatted with two decimals in the legend label.
        title (str): Title of the plot.
        filename (str): If given, the figure is saved to this path.
        show (bool): If True the figure is shown, otherwise it is closed.
    """
    num_clusters = len(df)
    colors = cm.rainbow(np.linspace(0, 1, num_clusters))

    plt.figure(figsize=(8, 6))

    # Iterate rows positionally: label-based df['x'][i] raises KeyError (or
    # picks the wrong row) when the index is not 0..n-1, e.g. after filtering.
    for x_vals, y_vals, mean_time, color in zip(df['x'], df['y'], df['mean time'], colors):
        plt.scatter(x_vals, y_vals,
                    color=color,
                    s=25,
                    alpha=0.7,
                    label=f"Lightning @ {mean_time:.2f}s")

    plt.title(title)
    plt.xlabel("x [px]")
    plt.ylabel("y [px]")
    plt.xlim(0, 346)
    # Was plt.ylim(0.260): a single positional argument sets only the lower
    # limit (to 0.26). The intended range is 0..260 px, matching xlim above.
    plt.ylim(0, 260)

    # Limit to the first 15 legend entries if there are too many
    handles, labels = plt.gca().get_legend_handles_labels()
    if len(handles) > 15:
        plt.legend(handles[:15], labels[:15], fontsize='small')
    else:
        plt.legend()

    plt.grid(True)
    plt.tight_layout()

    if filename:
        plt.savefig(filename, dpi=300)
        print(f"✅ Plot saved as {filename}")

    if show:
        plt.show()
    else:
        plt.close()
def plot_event_count(df, x_res=None, y_res=None, vmax=None, filename=None, show=True):
    """
    Plots a heatmap of the number of events per pixel using numpy.histogram2d.

    Parameters:
        df (pd.DataFrame): DataFrame with one row per event and columns 'x' and 'y'.
        x_res (int): Width of the pixel grid. Defaults to df['x'].max() + 1.
        y_res (int): Height of the pixel grid. Defaults to df['y'].max() + 1.
        vmax (float): Optional upper limit of the colour scale (passed to imshow).
        filename (str): If given, the figure is saved as a PDF to exactly this path.
        show (bool): If True the figure is shown, otherwise it is closed.

    Notes:
        Bins are pinned to the pixel grid [0, x_res) x [0, y_res), one bin per
        pixel. Events outside that grid are not counted.
    """
    if x_res is None:
        x_res = df['x'].max() + 1
    if y_res is None:
        y_res = df['y'].max() + 1
    x_res = int(x_res)
    y_res = int(y_res)

    # y is passed first so that rows of the result correspond to y and columns
    # to x, as imshow expects. Without an explicit range, histogram2d spans
    # only the data's min..max, which shifts/stretches the image whenever the
    # events do not cover the whole grid.
    event_count_map, _, _ = np.histogram2d(
        df['y'], df['x'],
        bins=[y_res, x_res],
        range=[[0, y_res], [0, x_res]],
    )

    plt.figure(figsize=(8, 6))
    plt.imshow(event_count_map, cmap='hot', vmax=vmax)
    plt.gca().invert_yaxis()  # put y = 0 at the bottom of the image
    plt.colorbar(label='Event count per pixel')
    plt.xlabel('X')
    plt.ylabel('Y')

    if filename:
        plt.savefig(filename, format='pdf', bbox_inches='tight', dpi=300)
        # savefig writes to `filename` as given; no extension is appended.
        print(f"Plot saved as {filename}")

    if show:
        plt.show()
    else:
        plt.close()
def plot_event_timeline(df_histogram,
                        df=None,
                        window_size=100,
                        xaxis='time',
                        filename=None,
                        show=True):
    """
    Plots the event count over time (or frame) for the whole recording.

    Draws the raw counts, their moving average and the overall average, adds
    an inset zoomed on the sample with the maximum count, and optionally
    shades cluster intervals.

    Parameters:
        df_histogram (pd.DataFrame): DataFrame with a 'count' column and a
            'timestamps_sec' column (xaxis='time') or 'frames' column (otherwise).
        df (pd.DataFrame): Optional cluster table. For xaxis='time' the columns
            'start' and 'end' are used (coerced to numeric); otherwise
            'start frame' and 'end frame'. The DataFrame is not modified.
            Any non-DataFrame value is ignored.
        window_size (int): Window size for the moving average.
        xaxis (str): 'time' or anything else for frames.
        filename (str): If given, the figure is saved as a PDF to exactly this path.
        show (bool): If True the figure is shown, otherwise it is closed.

    Notes:
        Sets the global matplotlib default 'font.size' to 15 via plt.rcParams;
        this persists after the call.
    """
    use_time = xaxis == 'time'
    x_data = df_histogram['timestamps_sec'] if use_time else df_histogram['frames']
    counts = df_histogram['count']

    # Set default font size globally
    plt.rcParams.update({'font.size': 15})

    avg_count = counts.mean()
    moving_avg = counts.rolling(window=window_size).mean()

    fig, ax = plt.subplots(figsize=(15, 5))

    ax.plot(x_data, counts,
            marker='o', linestyle='-', label='Event Count', color='lightsteelblue')
    ax.plot(x_data, moving_avg,
            color='steelblue', linestyle='-', linewidth=2, label='Moving Average')
    ax.axhline(y=avg_count, color='darkblue', linestyle='--', linewidth=2, label="Average")

    ax.set_xlabel("Time [s]" if use_time else "Frame", fontsize=14)
    ax.set_ylabel("Event Count", fontsize=14)
    ax.tick_params(axis='both', labelsize=14)

    # --- Inset zoom around the sample with the maximum count ---
    max_x = x_data.loc[counts.idxmax()]
    half_width = 0.05 if use_time else 2
    x1 = max_x - half_width
    x2 = max_x + half_width

    # The inset's y range goes from 0 to the largest count inside the zoom window.
    in_zoom = (x_data >= x1) & (x_data <= x2)
    y1 = 0
    y2 = counts[in_zoom].max()

    axins = ax.inset_axes([0.65, 0.66, 0.13, 0.3], xlim=(x1, x2), ylim=(y1, y2))

    # Use the same x data as the main axes: plotting against 'timestamps_sec'
    # unconditionally left the inset empty/misaligned in frame mode.
    axins.plot(x_data, counts,
               marker='o', linestyle='-', color='lightsteelblue')
    axins.plot(x_data, moving_avg,
               color='steelblue', linestyle='-', linewidth=2)

    # --- Shade cluster intervals, if a cluster table was provided ---
    if isinstance(df, pd.DataFrame):
        if use_time:
            # Coerce into local Series instead of writing back into the
            # caller's DataFrame.
            starts = pd.to_numeric(df['start'], errors='coerce')
            ends = pd.to_numeric(df['end'], errors='coerce')
        else:
            starts = df['start frame']
            ends = df['end frame']

        # enumerate() rather than the index label: the legend entry must be
        # attached to the first row even when the index does not start at 0.
        for position, (cluster_start, cluster_end) in enumerate(zip(starts, ends)):
            label = 'Event Interval' if position == 0 else None
            ax.axvspan(cluster_start, cluster_end, color='red', alpha=0.2, label=label)
            axins.axvspan(cluster_start, cluster_end, color='red', alpha=0.2)

    # Hide tick labels on inset
    axins.set_xticklabels([])
    axins.set_yticklabels([])

    # Indicate zoom area on main plot
    ax.indicate_inset_zoom(axins, edgecolor="black", linewidth=4)

    ax.grid(True)
    ax.legend(fontsize=14, loc='upper right')

    if filename:
        fig.savefig(filename, format='pdf', bbox_inches='tight', dpi=300)
        # savefig writes to `filename` as given; no extension is appended.
        print(f"Plot saved as {filename}")

    if show:
        plt.show()
    else:
        plt.close(fig)
def plot_variable(df, xvariable, yvariable, xlabel, ylabel, title, filename=None, show=True):
    """
    Plots one column of a DataFrame against another as a line with markers.

    Parameters:
        df (pd.DataFrame): DataFrame holding the data.
        xvariable (str): Name of the column used for the x values.
        yvariable (str): Name of the column used for the y values.
        xlabel (str): Label of the x axis.
        ylabel (str): Label of the y axis.
        title (str): Title of the plot.
        filename (str): If given, the figure is saved as a PDF to exactly this path.
        show (bool): If True the figure is shown, otherwise it is closed.
    """
    plt.figure(figsize=(10, 4))
    plt.plot(df[xvariable], df[yvariable], marker='o', linestyle='-', color='darkorange')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.tight_layout()

    if filename:
        plt.savefig(filename, format='pdf', bbox_inches='tight', dpi=300)
        # savefig writes to `filename` as given; no extension is appended.
        print(f"Plot saved as {filename}")

    if show:
        plt.show()
    else:
        plt.close()