-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path plot_training_logs.py
More file actions
39 lines (28 loc) · 884 Bytes
/
plot_training_logs.py
File metadata and controls
39 lines (28 loc) · 884 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# %%
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
# Name of the run folder under logs/lightning_logs/ whose metrics are inspected.
training_version = "version_18"
# %% load the run's metrics.csv (one row per logged entry)
df = pd.read_csv("logs/lightning_logs/{}/metrics.csv".format(training_version))
df
# %% epoch value of each logged row, in CSV row order
# Series.plot() draws onto the current axes whenever a figure already exists,
# so each plot gets its own axes here. Without this, running the file top to
# bottom as a plain script overlays the epoch and step curves on one set of axes.
fig, ax = plt.subplots()
df["epoch"].plot(ax=ax, title="epoch per logged row")
# %% step value of each logged row, in CSV row order
fig, ax = plt.subplots()
df["step"].plot(ax=ax, title="step per logged row")
# %% group by epoch and take the mean of every numeric column
# numeric_only=True gives the same result when every column is numeric. It also
# keeps this working on pandas >= 2.0 if a non-numeric column ever appears;
# plain mean() would raise TypeError there.
epoch_logs = df.groupby("epoch").mean(numeric_only=True)
# %% plot, choosing colors from tab10 colormap
colors = mpl.colormaps["tab10"].colors
# Skip columns that are NaN for every epoch, and the averaged "step" column,
# which is a counter rather than a metric.
epoch_logs.dropna(axis="columns", how="all").drop(columns=["step"]).plot(color=colors)
# %% count number of entries (non-null "step" values) per epoch
# Series.plot() would reuse the metrics plot's axes above when run as a plain
# script, so give it a fresh set of axes.
fig, ax = plt.subplots()
df.groupby("epoch")["step"].count().plot(ax=ax, title="logged rows per epoch")
# %% keep only columns that were logged at least once, and drop the step counter
# errors="ignore" makes this cell safe to re-run interactively. A second run
# would otherwise raise KeyError because "step" has already been removed.
epoch_logs = epoch_logs.dropna(axis="columns", how="all")
epoch_logs = epoch_logs.drop(columns=["step"], errors="ignore")
# %% drop columns with "step" in the name
# NOTE(review): presumably these are per-step variants of logged metrics
# (e.g. Lightning's "*_step" columns); the substring match also catches any
# other name containing "step". Confirm that is intended.
has_step = [col for col in epoch_logs.columns if "step" in col]
epoch_logs = epoch_logs.drop(columns=has_step)
# %% plot the remaining per-epoch metrics
epoch_logs.plot(color=colors)
# %%