-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·102 lines (80 loc) · 2.79 KB
/
main.py
File metadata and controls
executable file
·102 lines (80 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
"""
Controller for the Recurrent CNN model for image super-resolution.
This file contains the main function to run the model.
"""
__author__ = "Mir Sazzat Hossain"
import argparse
import random
import numpy as np
import torch
from train.recur_cnn_trainer import RecurrentCNNTrainer
from utils.config import load_config
def clear_cache() -> None:
    """Release cached GPU memory held by PyTorch's caching allocator.

    Guarded with ``torch.cuda.is_available()`` so the call is an explicit
    no-op on CPU-only machines and never touches the CUDA runtime there.
    On CUDA hosts this frees unused cached blocks back to the driver.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def set_seed(seed: int) -> None:
    """
    Seed every random-number generator used by the project.

    Covers the stdlib ``random`` module, NumPy's legacy global generator,
    and PyTorch on both CPU and all visible CUDA devices, so repeated runs
    with the same seed are reproducible.

    :param seed: The seed value.
    :type seed: int
    """
    # Apply the same seed to each framework's seeding entry point.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed value for reproducibility.",
)
parser.add_argument(
"--config",
type=str,
default="recursive_cnn",
)
parser.add_argument(
"--test_only",
action="store_true",
)
args = parser.parse_args()
clear_cache()
set_seed(args.seed)
config = load_config(args.config)
if args.test_only and config["log_params"]["test_model_path"] is None:
raise ValueError(
"Please provide the path to the model to be tested."
)
trainer = RecurrentCNNTrainer(
data_dir=config["data_params"]["data_dir"],
train_dataset_name=config["data_params"]["train_dataset_name"],
test_dataset_names=config["data_params"]["test_dataset_names"],
scale_factor=config["exp_params"]["scale_factor"],
patch_size=config["exp_params"]["patch_size"],
batch_size=config["exp_params"]["batch_size"],
overlap_height_ratio=config["data_params"]["overlap_height_ratio"],
overlap_width_ratio=config["data_params"]["overlap_width_ratio"],
width=config["model_params"]["width"],
depth=config["model_params"]["depth"],
num_epochs=config["exp_params"]["epochs"],
lr=config["exp_params"]["lr"],
num_workers=config["exp_params"]["num_workers"],
device=torch.device(config["exp_params"]["device"]),
log_dir=config["log_params"]["log_dir"],
result_dir=config["log_params"]["result_dir"],
save_interval=config["log_params"]["save_interval"],
add_discriminator=config["exp_params"]["add_discriminator"],
)
if not args.test_only:
trainer.train()
test_model_path = config["log_params"]["test_model_path"] \
if args.test_only else None
print("Testing model: ", test_model_path)
for dataset_name in config["data_params"]["test_dataset_names"]:
trainer.test(
dataset_name=dataset_name,
model_path=test_model_path,
)