forked from ming-liuyi/Extreme-Video-Compression-With-Prediction-Using-Pre-trainded-Diffusion-Models-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNetwork.py
More file actions
executable file
·664 lines (544 loc) · 30.6 KB
/
Network.py
File metadata and controls
executable file
·664 lines (544 loc) · 30.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
"""
A simple test algorithm to rewrite the network
"""
import math
import torch
import torch.nn as nn
from torch import Tensor
from timm.models.layers import trunc_normal_
from ELICUtilis.layers import (
AttentionBlock,
conv3x3,
CheckboardMaskedConv2d,
)
from compressai.models.priors import CompressionModel, GaussianConditional
from compressai.ops import ste_round
from compressai.models.utils import conv, deconv, update_registered_buffers
from thop import profile
from ptflops import get_model_complexity_info
# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return ``levels`` scales spaced evenly in log-space from ``min`` to ``max``.

    The parameter names shadow the ``min``/``max`` builtins; they are kept as-is
    because callers may pass them by keyword.

    Returns:
        A 1-D float tensor of length ``levels`` (endpoints included).
    """
    log_lo, log_hi = math.log(min), math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()
def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Build a pointwise (kernel size 1) ``nn.Conv2d`` mapping ``in_ch`` to ``out_ch`` channels.

    ``stride`` is forwarded unchanged; padding, dilation and bias keep the
    ``nn.Conv2d`` defaults.
    """
    return nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=1,
        stride=stride,
    )
class ResidualBottleneckBlock(nn.Module):
    """Residual bottleneck block with an identity skip connection.

    The residual branch is ``conv1x1`` (``in_ch`` -> ``in_ch // 2``), ReLU,
    ``conv3x3`` (``in_ch // 2`` -> ``in_ch // 2``), ReLU, ``conv1x1``
    (``in_ch // 2`` -> ``in_ch``). Its output is added to the input, so the
    block preserves the channel count.

    Args:
        in_ch (int): number of input (and output) channels
    """

    def __init__(self, in_ch: int):
        super().__init__()
        hidden = in_ch // 2
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the default parameter-initialization order.
        self.conv1 = conv1x1(in_ch, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(hidden, hidden)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(hidden, in_ch)

    def forward(self, x: Tensor) -> Tensor:
        branch = self.relu(self.conv1(x))
        branch = self.relu2(self.conv2(branch))
        return self.conv3(branch) + x
class Quantizer:
    """Stateless helper offering three latent-quantization modes."""

    def quantize(self, inputs, quantize_type="noise"):
        """Quantize ``inputs`` according to ``quantize_type``.

        * ``"noise"`` (default): add uniform noise drawn between -0.5 and 0.5,
          a differentiable training proxy for rounding.
        * ``"ste"``: straight-through estimator; the forward value is
          ``round(inputs)``, while the gradient w.r.t. ``inputs`` is identity.
        * any other value: plain ``torch.round``.
        """
        if quantize_type == "ste":
            return torch.round(inputs) - inputs.detach() + inputs
        if quantize_type != "noise":
            return torch.round(inputs)
        jitter = torch.empty_like(inputs).uniform_(-0.5, 0.5)
        return inputs + jitter
class TestModel(CompressionModel):
    """Learned image codec with channel-wise and checkerboard context models.

    The analysis transform ``g_a`` maps an image to a latent ``y`` with ``M``
    channels. ``y`` is split along channels into slices of widths
    ``self.groups[1:]`` (16, 16, 32, 64, 192). A hyperprior (``h_a`` / ``h_s``)
    produces ``latent_means`` / ``latent_scales``. Each slice is then coded in
    two checkerboard passes:

    * anchor positions ``(even row, even col)`` and ``(odd row, odd col)``,
      predicted from the hyperprior and the channel context only;
    * non-anchor positions ``(even row, odd col)`` and ``(odd row, even col)``,
      additionally conditioned on the decoded anchors via ``context_prediction``.

    The channel context for slice ``i > 0`` comes from ``cc_transforms[i - 1]``
    applied to the first decoded slice and, from slice 2 on, the previous
    decoded slice.

    Args:
        N: channel number of the main and hyperprior networks.
        M: channel number of the latent space. ``forward`` splits ``y`` by
            ``self.groups[1:]``, whose widths sum to 320, so ``M`` must be 320.
        num_slices: number of channel slices. ``decompress`` always iterates
            ``len(self.groups) - 1`` (5) slices, so this must be 5.
        **kwargs: accepted and ignored.

    NOTE(review): the timings reported by ``compress`` / ``decompress`` /
    ``inference`` are ``time.time()`` wall-clock deltas with no device
    synchronization; on asynchronous backends they may not include all
    kernel time.
    """

    def __init__(self, N=192, M=320, num_slices=5, **kwargs):
        # The entropy bottleneck models z = h_a(y), which has N channels, so its
        # channel count must follow N (it used to be hard-coded to 192, which
        # broke every N != 192).
        super().__init__(entropy_bottleneck_channels=N)
        self.N = int(N)
        self.M = int(M)
        self.num_slices = num_slices
        # Slice widths. groups[0] == 0 is a sentinel, so groups[i + 1] is the
        # width of slice i.
        self.groups = [0, 16, 16, 32, 64, 192]
        self.g_a = nn.Sequential(
            conv(3, N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            conv(N, N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            AttentionBlock(N),
            conv(N, N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            conv(N, M),
            AttentionBlock(M),
        )
        self.g_s = nn.Sequential(
            AttentionBlock(M),
            deconv(M, N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            deconv(N, N),
            AttentionBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            deconv(N, N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            ResidualBottleneckBlock(N),
            deconv(N, 3),
        )
        self.h_a = nn.Sequential(
            conv3x3(M, N),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )
        # Outputs 2*M channels, split later into latent means and scales.
        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N * 3 // 2),
            nn.ReLU(inplace=True),
            conv3x3(N * 3 // 2, 2 * M),
        )
        # Channel-context transforms for slices 1..num_slices-1. Input: the
        # first decoded slice, plus the previous decoded slice from slice 2 on.
        # Adapted from https://github.com/tensorflow/compression/blob/master/models/ms2020.py
        self.cc_transforms = nn.ModuleList(
            nn.Sequential(
                conv(self.groups[1] + (self.groups[i] if i > 1 else 0), 224,
                     stride=1, kernel_size=5),
                nn.ReLU(inplace=True),
                conv(224, 128, stride=1, kernel_size=5),
                nn.ReLU(inplace=True),
                conv(128, self.groups[i + 1] * 2, stride=1, kernel_size=5),
            )
            for i in range(1, num_slices)
        )
        # Checkerboard context model: one masked conv per slice, producing
        # 2 * width context channels from the decoded anchors.
        # From https://github.com/JiangWeibeta/Checkerboard-Context-Model-for-Efficient-Learned-Image-Compression/blob/main/version2/layers/CheckerboardContext.py
        self.context_prediction = nn.ModuleList(
            CheckboardMaskedConv2d(
                self.groups[i + 1], 2 * self.groups[i + 1], kernel_size=5, padding=2, stride=1
            )
            for i in range(num_slices)
        )
        # Entropy-parameter aggregation network, from "Checkerboard Context
        # Model for Efficient Learned Image Compression". Input channels:
        #     2 * M              hyperprior means + scales
        #   + 2 * groups[i+1]    checkerboard context (zeros in the anchor pass)
        #   + 2 * groups[i+1]    channel-context means + scales (slices i > 0)
        # The first term used to be hard-coded as 640 (== 2 * 320).
        self.ParamAggregation = nn.ModuleList(
            nn.Sequential(
                conv1x1(2 * self.M + self.groups[i + 1] * (4 if i > 0 else 2), 640),
                nn.ReLU(inplace=True),
                conv1x1(640, 512),
                nn.ReLU(inplace=True),
                conv1x1(512, self.groups[i + 1] * 2),
            )
            for i in range(num_slices)
        )
        self.quantizer = Quantizer()
        self.gaussian_conditional = GaussianConditional(None)

    @property
    def downsampling_factor(self) -> int:
        # 2 ** (4 + 2) == 64. Presumably 4 downsampling stages in g_a plus 2 in
        # h_a (decompress assumes y is 4x the spatial size of z) — TODO confirm
        # against compressai's conv/deconv defaults.
        return 2 ** (4 + 2)

    def init_weights(self):
        """Re-initialize conv / linear / layer-norm parameters in place."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    # ------------------------------------------------------------------ #
    # Private helpers shared by forward / compress / decompress / inference
    # ------------------------------------------------------------------ #

    def _slice_support(self, slice_index, y_hat_slices, latent_means, latent_scales):
        """Return the entropy-parameter support tensor for slice ``slice_index``.

        Slice 0 uses only the hyperprior output. Every later slice prepends
        the channel-context means/scales from ``cc_transforms``.
        """
        hyper = [latent_means, latent_scales]
        if slice_index == 0:
            return torch.cat(hyper, dim=1)
        if slice_index == 1:
            support_slices = y_hat_slices[0]
        else:
            support_slices = torch.cat([y_hat_slices[0], y_hat_slices[slice_index - 1]], dim=1)
        ch_mean, ch_scale = self.cc_transforms[slice_index - 1](support_slices).chunk(2, 1)
        return torch.cat([ch_mean, ch_scale] + hyper, dim=1)

    @staticmethod
    def _copy_positions(dst, src, anchor):
        """Copy ``src`` into ``dst`` (in place) at anchor or non-anchor positions; return ``dst``."""
        even_off, odd_off = (0, 1) if anchor else (1, 0)
        dst[:, :, 0::2, even_off::2] = src[:, :, 0::2, even_off::2]
        dst[:, :, 1::2, odd_off::2] = src[:, :, 1::2, odd_off::2]
        return dst

    @staticmethod
    def _zero_positions(t, anchor):
        """Set ``t`` to zero (in place) at anchor or non-anchor positions; return ``t``."""
        even_off, odd_off = (0, 1) if anchor else (1, 0)
        t[:, :, 0::2, even_off::2] = 0
        t[:, :, 1::2, odd_off::2] = 0
        return t

    @staticmethod
    def _pack_positions(t, anchor):
        """Gather one checkerboard half of a 4-D tensor into shape (B, C, H, W // 2)."""
        B, C, H, W = t.shape
        even_off, odd_off = (0, 1) if anchor else (1, 0)
        packed = t.new_zeros(B, C, H, W // 2)
        packed[:, :, 0::2, :] = t[:, :, 0::2, even_off::2]
        packed[:, :, 1::2, :] = t[:, :, 1::2, odd_off::2]
        return packed

    @staticmethod
    def _unpack_positions(packed, like, anchor):
        """Scatter a packed half back to the full shape of ``like`` (zeros elsewhere)."""
        even_off, odd_off = (0, 1) if anchor else (1, 0)
        out = torch.zeros_like(like)
        out[:, :, 0::2, even_off::2] = packed[:, :, 0::2, :]
        out[:, :, 1::2, odd_off::2] = packed[:, :, 1::2, :]
        return out

    def _compress_positions(self, y_slice, means, scales, anchor):
        """Entropy-code one checkerboard half of ``y_slice``.

        Returns ``(strings, y_decoded)``. ``y_decoded`` is the dequantized half
        scattered back to full resolution, with zeros at the other half.
        """
        means_packed = self._pack_positions(means, anchor)
        indexes = self.gaussian_conditional.build_indexes(self._pack_positions(scales, anchor))
        strings = self.gaussian_conditional.compress(
            self._pack_positions(y_slice, anchor), indexes, means=means_packed)
        decoded = self.gaussian_conditional.decompress(strings, indexes, means=means_packed)
        return strings, self._unpack_positions(decoded, means, anchor)

    def _decompress_positions(self, strings, means, scales, anchor):
        """Decode one checkerboard half and scatter it back to full resolution."""
        means_packed = self._pack_positions(means, anchor)
        indexes = self.gaussian_conditional.build_indexes(self._pack_positions(scales, anchor))
        decoded = self.gaussian_conditional.decompress(strings, indexes, means=means_packed)
        return self._unpack_positions(decoded, means, anchor)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def forward(self, x, noisequant=False):
        """Forward pass with estimated likelihoods.

        Args:
            x: input image batch (presumably NCHW in [0, 1] — not checked here).
            noisequant: if True, the latents fed to the context model use
                additive uniform noise, and ``z_hat`` comes straight from the
                entropy bottleneck. Otherwise mean-shifted STE rounding is used,
                and ``z`` is STE-rounded around the bottleneck medians. The
                synthesis transform always receives an STE-quantized latent.

        Returns:
            ``{"x_hat": ..., "likelihoods": {"y": ..., "z": ...}}``
        """
        y = self.g_a(x)
        B, C, H, W = y.size()
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        if not noisequant:
            z_offset = self.entropy_bottleneck._get_medians()
            z_hat = ste_round(z - z_offset) + z_offset
        latent_means, latent_scales = self.h_s(z_hat).chunk(2, 1)

        anchor = self._copy_positions(torch.zeros_like(y), y, anchor=True)
        non_anchor = self._copy_positions(torch.zeros_like(y), y, anchor=False)
        y_slices = torch.split(y, self.groups[1:], 1)
        anchor_split = torch.split(anchor, self.groups[1:], 1)
        non_anchor_split = torch.split(non_anchor, self.groups[1:], 1)
        # The anchor pass has no spatial context yet: feed zeros in its place.
        ctx_params_anchor_split = torch.split(
            y.new_zeros(B, C * 2, H, W), [2 * g for g in self.groups[1:]], 1)

        y_hat_slices = []         # feed the channel context of later slices
        y_hat_slices_for_gs = []  # STE-quantized, feed the synthesis transform
        y_likelihood = []
        for slice_index, y_slice in enumerate(y_slices):
            support = self._slice_support(slice_index, y_hat_slices, latent_means, latent_scales)
            param_aggregation = self.ParamAggregation[slice_index]

            # Checkerboard pass 1: anchors.
            y_anchor = anchor_split[slice_index]
            means_anchor, scales_anchor = param_aggregation(
                torch.cat([ctx_params_anchor_split[slice_index], support], dim=1)).chunk(2, 1)
            scales_hat_split = self._copy_positions(torch.zeros_like(y_anchor), scales_anchor, anchor=True)
            means_hat_split = self._copy_positions(torch.zeros_like(y_anchor), means_anchor, anchor=True)
            if noisequant:
                y_anchor_q = self.quantizer.quantize(y_anchor, "noise")
                y_anchor_q_for_gs = self.quantizer.quantize(y_anchor, "ste")
            else:
                # The same STE estimate serves both the context model and g_s.
                y_anchor_q = self.quantizer.quantize(y_anchor - means_anchor, "ste") + means_anchor
                y_anchor_q_for_gs = y_anchor_q
            self._zero_positions(y_anchor_q, anchor=False)
            self._zero_positions(y_anchor_q_for_gs, anchor=False)

            # Checkerboard pass 2: non-anchors, conditioned on decoded anchors.
            masked_context = self.context_prediction[slice_index](y_anchor_q)
            means_non_anchor, scales_non_anchor = param_aggregation(
                torch.cat([masked_context, support], dim=1)).chunk(2, 1)
            self._copy_positions(scales_hat_split, scales_non_anchor, anchor=False)
            self._copy_positions(means_hat_split, means_non_anchor, anchor=False)
            # Entropy estimation over the whole slice.
            _, y_slice_likelihood = self.gaussian_conditional(
                y_slice, scales_hat_split, means=means_hat_split)

            y_non_anchor = non_anchor_split[slice_index]
            if noisequant:
                y_non_anchor_q = self.quantizer.quantize(y_non_anchor, "noise")
                y_non_anchor_q_for_gs = self.quantizer.quantize(y_non_anchor, "ste")
            else:
                y_non_anchor_q = self.quantizer.quantize(
                    y_non_anchor - means_non_anchor, "ste") + means_non_anchor
                y_non_anchor_q_for_gs = y_non_anchor_q
            self._zero_positions(y_non_anchor_q, anchor=True)
            self._zero_positions(y_non_anchor_q_for_gs, anchor=True)

            y_hat_slices.append(y_anchor_q + y_non_anchor_q)
            y_hat_slices_for_gs.append(y_anchor_q_for_gs + y_non_anchor_q_for_gs)
            y_likelihood.append(y_slice_likelihood)

        y_likelihoods = torch.cat(y_likelihood, dim=1)
        # Use STE(y) as the input of the synthesis transform.
        x_hat = self.g_s(torch.cat(y_hat_slices_for_gs, dim=1))
        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def load_state_dict(self, state_dict, strict=True):
        """Resize the Gaussian-conditional CDF buffers to match ``state_dict``, then load it.

        ``strict`` is forwarded to the parent implementation, and its result is returned.
        """
        update_registered_buffers(
            self.gaussian_conditional,
            "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict,
        )
        return super().load_state_dict(state_dict, strict=strict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance (default hyper-parameters) from `state_dict`."""
        net = cls()
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        """Update the Gaussian scale table (default: ``get_scale_table()``) and the parent's CDFs.

        Returns:
            True if anything was updated.
        """
        if scale_table is None:
            scale_table = get_scale_table()
        updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
        updated |= super().update(force=force)
        return updated

    def compress(self, x):
        """Encode ``x`` to bit strings.

        Returns:
            ``{"strings": [y_strings, z_strings], "shape": z spatial size,
            "time": {...}}``. ``y_strings[i]`` is
            ``[anchor_strings, non_anchor_strings]`` for slice ``i``.
        """
        import time

        y_enc_start = time.time()
        y = self.g_a(x)
        y_enc = time.time() - y_enc_start
        B, C, H, W = y.size()

        z_enc_start = time.time()
        z = self.h_a(y)
        z_enc = time.time() - z_enc_start

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        z_dec_start = time.time()
        latent_means, latent_scales = self.h_s(z_hat).chunk(2, 1)
        z_dec = time.time() - z_dec_start

        y_slices = torch.split(y, self.groups[1:], 1)
        ctx_params_anchor_split = torch.split(
            y.new_zeros(B, C * 2, H, W), [2 * g for g in self.groups[1:]], 1)
        y_strings = []
        y_hat_slices = []
        params_start = time.time()
        for slice_index, y_slice in enumerate(y_slices):
            support = self._slice_support(slice_index, y_hat_slices, latent_means, latent_scales)
            param_aggregation = self.ParamAggregation[slice_index]

            # Checkerboard pass 1: anchors.
            means_anchor, scales_anchor = param_aggregation(
                torch.cat([ctx_params_anchor_split[slice_index], support], dim=1)).chunk(2, 1)
            anchor_strings, y_anchor_decode = self._compress_positions(
                y_slice, means_anchor, scales_anchor, anchor=True)

            # Checkerboard pass 2: non-anchors, conditioned on decoded anchors.
            masked_context = self.context_prediction[slice_index](y_anchor_decode)
            means_non_anchor, scales_non_anchor = param_aggregation(
                torch.cat([masked_context, support], dim=1)).chunk(2, 1)
            non_anchor_strings, y_non_anchor_decode = self._compress_positions(
                y_slice, means_non_anchor, scales_non_anchor, anchor=False)

            y_hat_slices.append(y_anchor_decode + y_non_anchor_decode)
            y_strings.append([anchor_strings, non_anchor_strings])
        params_time = time.time() - params_start
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:],
                "time": {'y_enc': y_enc, "z_enc": z_enc, "z_dec": z_dec, "params": params_time}}

    def decompress(self, strings, shape):
        """Decode the output of ``compress``.

        Args:
            strings: ``[y_strings, z_strings]`` as returned by ``compress``.
            shape: spatial size of ``z``.

        Returns:
            ``{"x_hat": reconstruction clamped to [0, 1], "time": {"y_dec": ...}}``
        """
        import time

        assert isinstance(strings, list) and len(strings) == 2
        # FIXME: we don't respect the default entropy coder and directly call the
        # range ANS decoder
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        B = z_hat.size(0)
        latent_means, latent_scales = self.h_s(z_hat).chunk(2, 1)
        y_strings = strings[0]
        # y is 4x the spatial size of z.
        ctx_params_anchor = z_hat.new_zeros(B, self.M * 2, z_hat.shape[2] * 4, z_hat.shape[3] * 4)
        ctx_params_anchor_split = torch.split(ctx_params_anchor, [2 * g for g in self.groups[1:]], 1)

        y_hat_slices = []
        for slice_index in range(len(self.groups) - 1):
            support = self._slice_support(slice_index, y_hat_slices, latent_means, latent_scales)
            param_aggregation = self.ParamAggregation[slice_index]

            # Checkerboard pass 1: anchors.
            means_anchor, scales_anchor = param_aggregation(
                torch.cat([ctx_params_anchor_split[slice_index], support], dim=1)).chunk(2, 1)
            y_anchor_decode = self._decompress_positions(
                y_strings[slice_index][0], means_anchor, scales_anchor, anchor=True)

            # Checkerboard pass 2: non-anchors, conditioned on decoded anchors.
            masked_context = self.context_prediction[slice_index](y_anchor_decode)
            means_non_anchor, scales_non_anchor = param_aggregation(
                torch.cat([masked_context, support], dim=1)).chunk(2, 1)
            y_non_anchor_decode = self._decompress_positions(
                y_strings[slice_index][1], means_non_anchor, scales_non_anchor, anchor=False)

            y_hat_slices.append(y_anchor_decode + y_non_anchor_decode)

        y_hat = torch.cat(y_hat_slices, dim=1)
        y_dec_start = time.time()
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        y_dec = time.time() - y_dec_start
        return {"x_hat": x_hat, "time": {"y_dec": y_dec}}

    def inference(self, x):
        """Like ``forward`` in STE mode, but also reports per-stage timings.

        Returns:
            ``{"x_hat", "likelihoods": {"y", "z"}, "time": {...}}``
        """
        import time

        y_enc_start = time.time()
        y = self.g_a(x)
        y_enc = time.time() - y_enc_start
        B, C, H, W = y.size()

        z_enc_start = time.time()
        z = self.h_a(y)
        z_enc = time.time() - z_enc_start
        # Only the likelihoods are used; z_hat is recomputed by STE rounding
        # around the bottleneck medians.
        _, z_likelihoods = self.entropy_bottleneck(z)
        z_offset = self.entropy_bottleneck._get_medians()
        z_hat = ste_round(z - z_offset) + z_offset

        z_dec_start = time.time()
        latent_means, latent_scales = self.h_s(z_hat).chunk(2, 1)
        z_dec = time.time() - z_dec_start

        anchor = self._copy_positions(torch.zeros_like(y), y, anchor=True)
        non_anchor = self._copy_positions(torch.zeros_like(y), y, anchor=False)
        y_slices = torch.split(y, self.groups[1:], 1)
        anchor_split = torch.split(anchor, self.groups[1:], 1)
        non_anchor_split = torch.split(non_anchor, self.groups[1:], 1)
        ctx_params_anchor_split = torch.split(
            y.new_zeros(B, C * 2, H, W), [2 * g for g in self.groups[1:]], 1)

        y_hat_slices = []
        y_likelihood = []
        params_start = time.time()
        for slice_index, y_slice in enumerate(y_slices):
            support = self._slice_support(slice_index, y_hat_slices, latent_means, latent_scales)
            param_aggregation = self.ParamAggregation[slice_index]

            # Checkerboard pass 1: anchors.
            y_anchor = anchor_split[slice_index]
            means_anchor, scales_anchor = param_aggregation(
                torch.cat([ctx_params_anchor_split[slice_index], support], dim=1)).chunk(2, 1)
            scales_hat_split = self._copy_positions(torch.zeros_like(y_anchor), scales_anchor, anchor=True)
            means_hat_split = self._copy_positions(torch.zeros_like(y_anchor), means_anchor, anchor=True)
            y_anchor_hat = self.quantizer.quantize(y_anchor - means_anchor, "ste") + means_anchor
            self._zero_positions(y_anchor_hat, anchor=False)

            # Checkerboard pass 2: non-anchors, conditioned on decoded anchors.
            masked_context = self.context_prediction[slice_index](y_anchor_hat)
            means_non_anchor, scales_non_anchor = param_aggregation(
                torch.cat([masked_context, support], dim=1)).chunk(2, 1)
            self._copy_positions(scales_hat_split, scales_non_anchor, anchor=False)
            self._copy_positions(means_hat_split, means_non_anchor, anchor=False)
            # Entropy estimation over the whole slice.
            _, y_slice_likelihood = self.gaussian_conditional(
                y_slice, scales_hat_split, means=means_hat_split)

            y_non_anchor_hat = self.quantizer.quantize(
                non_anchor_split[slice_index] - means_non_anchor, "ste") + means_non_anchor
            self._zero_positions(y_non_anchor_hat, anchor=True)

            y_hat_slices.append(y_anchor_hat + y_non_anchor_hat)
            y_likelihood.append(y_slice_likelihood)
        params_time = time.time() - params_start
        y_likelihoods = torch.cat(y_likelihood, dim=1)

        # Use STE(y) as the input of the synthesis transform.
        y_hat = torch.cat(y_hat_slices, dim=1)
        y_dec_start = time.time()
        x_hat = self.g_s(y_hat)
        y_dec = time.time() - y_dec_start
        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
            "time": {'y_enc': y_enc, "y_dec": y_dec, "z_enc": z_enc, "z_dec": z_dec, "params": params_time}
        }
if __name__ == "__main__":
    # Smoke test: build the default model, run one forward pass and report
    # model complexity with ptflops and thop.
    model = TestModel(N=192, M=320, num_slices=5)
    # torch.Tensor(*sizes) returns uninitialized memory (possibly NaN/inf),
    # so use a well-defined random image in [0, 1) instead. Named `x` to avoid
    # shadowing the `input` builtin.
    x = torch.rand(1, 3, 256, 256)
    out = model(x)
    print(out["x_hat"].shape)
    flops, params = get_model_complexity_info(model, (3, 256, 256), as_strings=True, print_per_layer_stat=True)
    print('flops: ', flops, 'params: ', params)
    flops, params = profile(model, (x,))
    print('flops: ', flops, 'params: ', params)