-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
105 lines (79 loc) · 4.11 KB
/
Copy pathtrainer.py
File metadata and controls
105 lines (79 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import torch
import json
from src.data.build import build_dataloader
from src.config.config import save_cfg
from src.engine.base import _BaseEngine
class Trainer():
    """Run the train / validation / test loops of an engine and persist results.

    The trainer owns no model logic itself: every epoch is delegated to
    ``engine.train_model`` / ``engine.eval_model``.  Its job is to sequence
    those calls, write the checkpoint and append one JSON line per call to
    the log files under ``cfg["OUTPUT"]``.

    Output layout created under ``cfg["OUTPUT"]``:
        model/       -> ``checkpoint.pth``
        log/         -> ``train_log.txt``, ``val_log.txt``, ``test_log.txt``
        train/       -> passed to the engine as training ``output_dir``
        validation/  -> passed to the engine as validation ``output_dir``
        test/        -> passed to the engine as testing ``output_dir``
    """

    def __init__(self, engine: "_BaseEngine", cfg: dict, device: str = "cuda", test_only: bool = False):
        """Store the configuration, build the dataloader and create output dirs.

        Args:
            engine: Object exposing ``start_epoch``, ``train_model`` and
                ``eval_model``.
            cfg: Configuration mapping.  Keys read here: ``SYSTEM.SEED``,
                ``DATA.BATCH_SIZE``, ``DATA.EPOCHS`` and ``OUTPUT``;
                ``EVALUATION.TRAIN_FRQ`` / ``EVALUATION.VAL_FRQ`` are read
                later by :meth:`train`.
            device: Device identifier; only stored on the instance here.
            test_only: When true, :meth:`run` skips training.
        """
        # Configuration parameters
        self.cfg = cfg
        self.device = device
        self.seed = self.cfg["SYSTEM"]["SEED"]
        self.test_only = test_only
        self.engine = engine
        self.dataloader = build_dataloader(cfg)

        # Optimization configuration
        self.batch_size = self.cfg["DATA"]["BATCH_SIZE"]
        self.epochs = self.cfg["DATA"]["EPOCHS"]
        self.start_epoch = self.engine.start_epoch

        # Saving configuration
        self.model_output = self._make_output_dir("model")
        self.log_output = self._make_output_dir("log")
        self.train_output = self._make_output_dir("train")
        self.val_output = self._make_output_dir("validation")
        self.test_output = self._make_output_dir("test")
        save_cfg(self.cfg, self.cfg["OUTPUT"])

    def _make_output_dir(self, name: str) -> str:
        """Create ``cfg["OUTPUT"]/<name>`` if needed and return its path."""
        path = os.path.join(self.cfg["OUTPUT"], name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        return path

    def _append_log(self, filename: str, record) -> None:
        """Append ``record`` as one JSON line to ``filename`` in the log dir."""
        with open(os.path.join(self.log_output, filename), "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + '\n')

    def run(self):
        """Train (unless ``test_only`` is set) and then always run testing."""
        if not self.test_only:
            self.train()
        self.test()

    def train(self):
        """Train from ``start_epoch`` up to ``epochs``, validating periodically.

        Each epoch the checkpoint returned by the engine is written to
        ``model/checkpoint.pth`` and the training metrics are appended to
        ``train_log.txt``.  Validation runs on the first epoch of this run,
        on every epoch divisible by ``EVALUATION.VAL_FRQ`` and on the last
        epoch; its metrics are appended to ``val_log.txt``.
        """
        checkpoint_path = os.path.join(self.model_output, 'checkpoint.pth')

        print("Beginning of the training process...")
        for epoch in range(self.start_epoch, self.epochs):
            # One epoch training
            train_dict, save_dict = self.engine.train_model(
                epoch=epoch,
                train_data=self.dataloader.train_dataloader(),
                output_dir=self.train_output,
                visualization=True,
                save_frq=self.cfg["EVALUATION"]["TRAIN_FRQ"],
            )
            print("TRAINING: [Epoch {}/{}] || {}".format(epoch, self.epochs, train_dict))
            # Persist before validation so a failure there does not lose the epoch.
            torch.save(save_dict, checkpoint_path)

            # Saving training logs
            self._append_log("train_log.txt", train_dict)

            # One epoch validation.  The first-epoch test stays first so the
            # modulo is not evaluated on the starting epoch (as before).
            if (
                epoch == self.start_epoch
                or epoch % self.cfg["EVALUATION"]["VAL_FRQ"] == 0
                or epoch == self.epochs - 1
            ):
                val_dict = self.engine.eval_model(
                    epoch=epoch,
                    val_data=self.dataloader.val_dataloader(),
                    Test=False,
                    output_dir=self.val_output,
                )
                # Saving validation logs
                self._append_log("val_log.txt", val_dict)

            # Save current checkpoint
            # NOTE(review): this rewrites the same object to the same path as
            # the save above; it looks redundant unless validation mutates
            # save_dict in place -- confirm before removing.
            torch.save(save_dict, checkpoint_path)
        print("... End of the training process")

    def test(self):
        """Evaluate once on the test split and append metrics to ``test_log.txt``.

        The engine is called with ``Test=True``, ``epoch=self.epochs`` and the
        dedicated ``test`` output directory.
        """
        # One epoch testing
        print("Beginning of the testing process...")
        test_dict = self.engine.eval_model(
            epoch=self.epochs,
            val_data=self.dataloader.test_dataloader(),
            Test=True,
            # Test artefacts belong in test/, not validation/.
            output_dir=self.test_output,
        )
        # Saving testing logs
        self._append_log("test_log.txt", test_dict)
        print("... End of the testing process")