forked from NicolasSlenko/CAI4104FinalProject
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_learning_curves.py
More file actions
116 lines (83 loc) · 4 KB
/
Copy pathplot_learning_curves.py
File metadata and controls
116 lines (83 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import matplotlib.pyplot as plt
import json
import os
import argparse
def plot_accuracy_curves(history_files, output_path='accuracy_curves.png'):
    """
    Plots accuracy learning curves from history JSON files.

    For each file, training accuracy is drawn as a solid line and validation
    accuracy as a dashed line in the same colour. The best (highest)
    validation accuracy is marked with a dot and annotated with its value.
    The figure is saved to ``output_path``. It is not closed, so a later
    ``plt.show()`` can still display it.

    Each history file must be a JSON object with equal-length lists under
    "epochs", "train_acc" and "val_acc". A file is reported and skipped, and
    contributes nothing to the figure, if any of these hold:
    it cannot be read, it lacks a key, its lists are mismatched or empty,
    or it otherwise fails to plot.

    Args:
        history_files: List of history JSON file paths to plot
        output_path: Path of the saved image (default 'accuracy_curves.png')
    """
    plt.figure(figsize=(12, 7))
    colors = ['b', 'r', 'g', 'm', 'c', 'y', 'k']
    suffix = '_history.json'
    plotted = 0
    for i, history_path in enumerate(history_files):
        # Legend name: file name with only a *trailing* '_history.json' removed.
        model_name = os.path.basename(history_path)
        if model_name.endswith(suffix):
            model_name = model_name[:-len(suffix)]
        # Colours cycle when there are more models than colours.
        color = colors[i % len(colors)]
        try:
            with open(history_path, 'r', encoding='utf-8') as f:
                history = json.load(f)
            epochs = history["epochs"]
            train_acc = history["train_acc"]
            val_acc = history["val_acc"]
            # Validate before drawing anything so a bad file never leaves a
            # half-plotted model (e.g. a train curve without its val curve).
            if not (len(epochs) == len(train_acc) == len(val_acc)):
                raise ValueError("'epochs', 'train_acc' and 'val_acc' differ in length")
            if not val_acc:
                raise ValueError("'val_acc' is empty")
            best_acc = max(val_acc)
            # index() selects the first epoch that reaches the best value.
            best_epoch = epochs[val_acc.index(best_acc)]

            plt.plot(epochs, train_acc, f'{color}-', label=f'{model_name} - Train Acc')
            plt.plot(epochs, val_acc, f'{color}--', label=f'{model_name} - Val Acc')
            plt.plot(best_epoch, best_acc, f'{color}o', markersize=8)
            plt.annotate(f'{best_acc:.4f}',
                         (best_epoch, best_acc),
                         textcoords="offset points",
                         xytext=(0, 10),
                         ha='center')
            plotted += 1
        except Exception as e:
            # Best-effort per file: report the problem and continue with the rest.
            print(f"Error loading or plotting history from {history_path}: {e}")
    plt.title('Model Accuracy During Training', fontsize=16)
    plt.xlabel('Epochs', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.grid(True, linestyle='--', alpha=0.7)
    # Only add a legend when at least one model produced labelled lines.
    if plotted:
        plt.legend(loc='lower right')
    # 5% headroom above the current top, presumably so the annotations placed
    # above the best points are not clipped.
    plt.ylim(top=plt.ylim()[1] * 1.05)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print(f"Accuracy curves saved to '{output_path}'")
def plot_loss_curves(history_files, output_path='loss_curves.png'):
    """
    Plots loss learning curves from history JSON files.

    For each file, training loss is drawn as a solid line and validation
    loss as a dashed line in the same colour. The best (lowest) validation
    loss is marked with a dot and annotated with its value. The figure is
    saved to ``output_path``. It is not closed, so a later ``plt.show()``
    can still display it.

    Each history file must be a JSON object with equal-length lists under
    "epochs", "train_loss" and "val_loss". A file is reported and skipped,
    and contributes nothing to the figure, if any of these hold:
    it cannot be read, it lacks a key, its lists are mismatched or empty,
    or it otherwise fails to plot.

    Args:
        history_files: List of history JSON file paths to plot
        output_path: Path of the saved image (default 'loss_curves.png')
    """
    plt.figure(figsize=(12, 7))
    colors = ['b', 'r', 'g', 'm', 'c', 'y', 'k']
    suffix = '_history.json'
    plotted = 0
    for i, history_path in enumerate(history_files):
        # Legend name: file name with only a *trailing* '_history.json' removed.
        model_name = os.path.basename(history_path)
        if model_name.endswith(suffix):
            model_name = model_name[:-len(suffix)]
        # Colours cycle when there are more models than colours.
        color = colors[i % len(colors)]
        try:
            with open(history_path, 'r', encoding='utf-8') as f:
                history = json.load(f)
            epochs = history["epochs"]
            train_loss = history["train_loss"]
            val_loss = history["val_loss"]
            # Validate before drawing anything so a bad file never leaves a
            # half-plotted model (e.g. a train curve without its val curve).
            if not (len(epochs) == len(train_loss) == len(val_loss)):
                raise ValueError("'epochs', 'train_loss' and 'val_loss' differ in length")
            if not val_loss:
                raise ValueError("'val_loss' is empty")
            best_loss = min(val_loss)
            # index() selects the first epoch that reaches the best value.
            best_epoch = epochs[val_loss.index(best_loss)]

            plt.plot(epochs, train_loss, f'{color}-', label=f'{model_name} - Train Loss')
            plt.plot(epochs, val_loss, f'{color}--', label=f'{model_name} - Val Loss')
            plt.plot(best_epoch, best_loss, f'{color}o', markersize=8)
            plt.annotate(f'{best_loss:.4f}',
                         (best_epoch, best_loss),
                         textcoords="offset points",
                         xytext=(0, 10),
                         ha='center')
            plotted += 1
        except Exception as e:
            # Best-effort per file: report the problem and continue with the rest.
            print(f"Error loading or plotting history from {history_path}: {e}")
    plt.title('Model Loss During Training', fontsize=16)
    plt.xlabel('Epochs', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    plt.grid(True, linestyle='--', alpha=0.7)
    # Only add a legend when at least one model produced labelled lines.
    if plotted:
        plt.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print(f"Loss curves saved to '{output_path}'")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Plot learning curves from model history files")
    # Required positional list of history files; both plot functions below
    # consume args.history_files, so at least one path must be given.
    parser.add_argument(
        "history_files",
        nargs="+",
        help="history JSON files (keys: epochs, train_acc, val_acc, train_loss, val_loss)",
    )
    args = parser.parse_args()
    print("Plotting learning curves for models:")
    for history_file in args.history_files:
        print(f" - {history_file}")
    plot_accuracy_curves(args.history_files)
    plot_loss_curves(args.history_files)
    # Show plots only after both figures have been saved to disk.
    plt.show()