-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathmain.py
More file actions
103 lines (79 loc) · 2.44 KB
/
main.py
File metadata and controls
103 lines (79 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import math
import numpy as np
import tensorflow as tf
from model import RFDNNet
from tensorflow import keras
from IPython.display import display
from tensorflow.keras import Input, Model
from tensorflow.keras.preprocessing import image_dataset_from_directory
from utils import *
# Sizes used throughout the data pipeline and the model.
batch_size = 8
upscale_factor = 3
crop_size = 300
# Edge length of the low-resolution input: the crop divided by the upscale factor.
input_size = crop_size // upscale_factor

# Download and extract the BSDS500 archive via keras.utils.get_file, then
# point root_dir at the "BSDS500/data" folder below the returned path.
dataset_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
data_dir = keras.utils.get_file(fname="BSR", origin=dataset_url, untar=True)
root_dir = os.path.join(data_dir, "BSDS500/data")
# Options shared by the training and validation splits. Both calls use the
# same seed and split fraction and differ only in `subset`.
_split_options = dict(
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    seed=1337,
    label_mode=None,  # images only, no labels
)

train_ds = image_dataset_from_directory(root_dir, subset="training", **_split_options)
valid_ds = image_dataset_from_directory(root_dir, subset="validation", **_split_options)
# Scale from (0, 255) to (0, 1): apply `scaling` to the training split first,
# then to the validation split.
train_ds, valid_ds = [split.map(scaling) for split in (train_ds, valid_ds)]
# Full paths of every ".jpg" file directly inside <root_dir>/images/test,
# in sorted order. These are handed to the ESPCNCallback later in the script.
dataset = os.path.join(root_dir, "images")
test_path = os.path.join(dataset, "test")
test_img_paths = sorted(
    os.path.join(test_path, entry)
    for entry in os.listdir(test_path)
    if entry.endswith(".jpg")
)
def _to_training_pair(images):
    """Return the (model input, target) pair built from one dataset element.

    The input is `process_input(images, input_size, upscale_factor)` and the
    target is `process_target(images)`.
    """
    return process_input(images, input_size, upscale_factor), process_target(images)


# Turn both splits into (input, target) datasets and prefetch up to 32 elements.
train_ds = train_ds.map(_to_training_pair).prefetch(buffer_size=32)
valid_ds = valid_ds.map(_to_training_pair).prefetch(buffer_size=32)
# Build the network as a Keras functional model: a symbolic input is passed
# through RFDNNet.main_model and the result is wrapped in a Model.
rfanet_x = RFDNNet()
# The first two dimensions are left unspecified (None); the last dimension is
# fixed at 3 (presumably the colour channels -- confirm against utils.process_input).
x = Input(shape=(None, None, 3))
out = rfanet_x.main_model(x, upscale_factor)
model = Model(inputs=x, outputs=out)
# Print the layer summary of the assembled model.
model.summary()
# Stop training once the training loss has not improved for 10 epochs.
# NOTE(review): both callbacks monitor the training "loss" although fit() is
# given validation data; consider "val_loss" if generalisation is the goal.
early_stopping_callback = keras.callbacks.EarlyStopping(monitor="loss", patience=10)

# Directory that receives the best checkpoint. It is created up front so that
# saving does not depend on the callback creating it.
checkpoint_filepath = "weights"
os.makedirs(checkpoint_filepath, exist_ok=True)

# Keep only the checkpoint with the lowest monitored loss, in HDF5 format.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    os.path.join(checkpoint_filepath, "best.h5"),
    monitor="loss",
    mode="min",
    save_best_only=True,
    # Replaces the deprecated `period=1`: evaluate the checkpoint every epoch.
    save_freq="epoch",
)

# ESPCNCallback comes from utils and receives the test image paths
# (presumably for per-epoch evaluation -- confirm in utils).
callbacks = [
    ESPCNCallback(test_img_paths),
    early_stopping_callback,
    model_checkpoint_callback,
]
# Training configuration: mean squared error minimised with Adam (lr 1e-3)
# for at most 100 epochs.
epochs = 100
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss=loss_fn, optimizer=optimizer)

# Train on the training split and evaluate on the validation split;
# verbose=2 logs one line per epoch.
model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=epochs,
    callbacks=callbacks,
    verbose=2,
)