diff --git a/src/training/dcc_tf.py b/src/training/dcc_tf.py index 46d915f..918c7e1 100644 --- a/src/training/dcc_tf.py +++ b/src/training/dcc_tf.py @@ -7,6 +7,8 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +import numpy as np + from torchmetrics.functional import( scale_invariant_signal_noise_ratio as si_snr, signal_noise_ratio as snr, @@ -80,10 +82,7 @@ def __init__(self, channels, num_layers, kernel_size=3): for i in range(num_layers)] # Compute buffer start indices for each layer - self.buf_indices = [0] - for i in range(num_layers - 1): - self.buf_indices.append( - self.buf_indices[-1] + self.buf_lengths[i]) + self.buf_indices = [0] + np.cumsum(self.buf_lengths[:(num_layers - 1)]).tolist() # Dilated causal conv layers aggregate previous context to obtain # contexful encoded input.