from pathlib import Path


def add_duration_column(filename, output_filename):
    """Insert a duration-file column into a pipe-separated LJSpeech filelist.

    Each input line has the form ``wav_path|pitch_path|transcript``; the
    corresponding output line is
    ``wav_path|pitch_path|durations/<stem>.pt|transcript``.

    Args:
        filename: path of the existing filelist to read.
        output_filename: path the augmented filelist is written to.
    """
    all_info = []
    with open(filename) as f:
        for line in f:
            # maxsplit=2 keeps any '|' characters inside the transcript intact
            file_path, pitch_path, transcript = line.strip().split('|', maxsplit=2)
            # Path.stem already discards any leading directories, so no
            # os.path.basename() is needed.
            name_stem = Path(file_path).stem
            # Fixed column layout: audio, pitch, durations, transcript
            # (these filelists carry no mel or speaker columns).
            all_info.append('|'.join([file_path,
                                      pitch_path,
                                      f'durations/{name_stem}.pt',
                                      transcript]))

    with open(output_filename, 'w') as f:
        # Single write of the joined lines; terminate the file with a
        # newline (the original omitted it, producing a truncated last line
        # for tools that expect POSIX text files).
        f.write('\n'.join(all_info) + '\n')


if __name__ == '__main__':
    filelists = {
        'filelists/ljs_audio_pitch_text_test.txt':
            'filelists/ljs_audio_pitch_durs_text_test.txt',
        'filelists/ljs_audio_pitch_text_train_v3.txt':
            'filelists/ljs_audio_pitch_durs_text_train_v3.txt',
        'filelists/ljs_audio_pitch_text_val.txt':
            'filelists/ljs_audio_pitch_durs_text_val.txt',
    }
    for file_name, output_name in filelists.items():
        add_duration_column(file_name, output_name)
torch.from_numpy(mel_basis).float() self.register_buffer('mel_basis', mel_basis) diff --git a/PyTorch/SpeechSynthesis/FastPitch/common/stft.py b/PyTorch/SpeechSynthesis/FastPitch/common/stft.py index 4084dc68e..bc140c142 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/common/stft.py +++ b/PyTorch/SpeechSynthesis/FastPitch/common/stft.py @@ -64,7 +64,7 @@ def __init__(self, filter_length=800, hop_length=200, win_length=800, assert(filter_length >= win_length) # get window and zero center pad it to filter_length fft_window = get_window(window, win_length, fftbins=True) - fft_window = pad_center(fft_window, filter_length) + fft_window = pad_center(fft_window, size=filter_length) fft_window = torch.from_numpy(fft_window).float() # window the bases diff --git a/PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py b/PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py index cfdb5755a..7262b1284 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py +++ b/PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py @@ -9,6 +9,8 @@ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): _arpabet = ['@' + s for s in valid_symbols] +# In phones extracted from MFA TextGrid +_silences = ['@sp', '@sil'] def get_symbols(symbol_set='english_basic'): @@ -17,20 +19,20 @@ def get_symbols(symbol_set='english_basic'): _punctuation = '!\'(),.:;? ' _special = '-' _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - symbols = list(_pad + _special + _punctuation + _letters) + _arpabet + symbols = list(_pad + _special + _punctuation + _letters) + _arpabet + _silences elif symbol_set == 'english_basic_lowercase': _pad = '_' _punctuation = '!\'"(),.:;? 
def arpabet_list_to_sequence(self, text):
    """Encode a list of ARPAbet phone strings as a symbol-id sequence.

    Each phone is prefixed with '@' (the ARPAbet namespace marker used in
    the symbol table) before lookup via ``symbols_to_sequence``.
    """
    prefixed = ['@' + phone for phone in text]
    return self.symbols_to_sequence(prefixed)
import argparse
import os
import pathlib

from common.utils import load_filepaths_and_text


def create_lab_files(dataset_path, filelist, n_speakers):
    """Write one ``<stem>.lab`` transcript file per wav, for MFA alignment.

    Args:
        dataset_path: root directory of the dataset; .lab files are written
            directly into this directory.
        filelist: a filelist path or a list of filelist paths with
            ``wav|...|transcript`` entries.
        n_speakers: number of speakers; >1 means the filelists carry a
            speaker-id column that must be parsed.
    """
    # Accept a single filename as well as a list of filenames.
    if isinstance(filelist, str):
        filelist = [filelist]

    # The speaker column is only present in multi-speaker filelists.
    dataset_entries = load_filepaths_and_text(filelist, dataset_path,
                                              (n_speakers > 1))

    for filepath, text in dataset_entries:
        wav_name = pathlib.Path(filepath).stem
        # NOTE(review): the '.lab' extension is hardcoded, and the files are
        # written next to dataset_path rather than into a wavs/ subdirectory
        # -- confirm this matches where the aligner expects them.
        lab_filepath = os.path.join(dataset_path, f'{wav_name}.lab')
        with open(lab_filepath, 'w') as f:
            f.write(text)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True,
                        help='Path to dataset')
    parser.add_argument('--filelist', type=str, required=True, nargs='+',
                        help='List of wavs with transcript')
    parser.add_argument('--n-speakers', type=int, default=1,
                        help='Number of speakers in dataset')
    args = parser.parse_args()

    create_lab_files(args.dataset, args.filelist, args.n_speakers)
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - - -class ConvNorm(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, - padding=None, dilation=1, bias=True, w_init_gain='linear'): - super(ConvNorm, self).__init__() - if padding is None: - assert(kernel_size % 2 == 1) - padding = int(dilation * (kernel_size - 1) / 2) - - self.conv = torch.nn.Conv1d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, - bias=bias) - - torch.nn.init.xavier_uniform_( - self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) - - def forward(self, signal): - conv_signal = self.conv(signal) - return conv_signal - - -class Invertible1x1ConvLUS(torch.nn.Module): - def __init__(self, c): - super(Invertible1x1ConvLUS, self).__init__() - # Sample a random orthonormal matrix to initialize weights - W, _ = torch.linalg.qr(torch.randn(c, c)) - # Ensure determinant is 1.0 not -1.0 - if torch.det(W) < 0: - W[:, 0] = -1*W[:, 0] - p, lower, upper = torch.lu_unpack(*torch.lu(W)) - - self.register_buffer('p', p) - # diagonals of lower will always be 1s anyway - lower = torch.tril(lower, -1) - lower_diag = torch.diag(torch.eye(c, c)) - self.register_buffer('lower_diag', lower_diag) - self.lower = nn.Parameter(lower) - self.upper_diag = nn.Parameter(torch.diag(upper)) - self.upper = 
nn.Parameter(torch.triu(upper, 1)) - - def forward(self, z, reverse=False): - U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag) - L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag) - W = torch.mm(self.p, torch.mm(L, U)) - if reverse: - if not hasattr(self, 'W_inverse'): - # Reverse computation - W_inverse = W.float().inverse() - if z.type() == 'torch.cuda.HalfTensor': - W_inverse = W_inverse.half() - - self.W_inverse = W_inverse[..., None] - z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) - return z - else: - W = W[..., None] - z = F.conv1d(z, W, bias=None, stride=1, padding=0) - log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag))) - return z, log_det_W - - -class ConvAttention(torch.nn.Module): - def __init__(self, n_mel_channels=80, n_speaker_dim=128, - n_text_channels=512, n_att_channels=80, temperature=1.0, - n_mel_convs=2, align_query_enc_type='3xconv', - use_query_proj=True): - super(ConvAttention, self).__init__() - self.temperature = temperature - self.att_scaling_factor = np.sqrt(n_att_channels) - self.softmax = torch.nn.Softmax(dim=3) - self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.query_proj = Invertible1x1ConvLUS(n_mel_channels) - self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1) - self.align_query_enc_type = align_query_enc_type - self.use_query_proj = bool(use_query_proj) - - self.key_proj = nn.Sequential( - ConvNorm(n_text_channels, - n_text_channels * 2, - kernel_size=3, - bias=True, - w_init_gain='relu'), - torch.nn.ReLU(), - ConvNorm(n_text_channels * 2, - n_att_channels, - kernel_size=1, - bias=True)) - - self.align_query_enc_type = align_query_enc_type - - if align_query_enc_type == "inv_conv": - self.query_proj = Invertible1x1ConvLUS(n_mel_channels) - elif align_query_enc_type == "3xconv": - self.query_proj = nn.Sequential( - ConvNorm(n_mel_channels, - n_mel_channels * 2, - kernel_size=3, - bias=True, - w_init_gain='relu'), - torch.nn.ReLU(), - ConvNorm(n_mel_channels * 2, - 
n_mel_channels, - kernel_size=1, - bias=True), - torch.nn.ReLU(), - ConvNorm(n_mel_channels, - n_att_channels, - kernel_size=1, - bias=True)) - else: - raise ValueError("Unknown query encoder type specified") - - def run_padded_sequence(self, sorted_idx, unsort_idx, lens, padded_data, - recurrent_model): - """Sorts input data by previded ordering (and un-ordering) and runs the - packed data through the recurrent model - - Args: - sorted_idx (torch.tensor): 1D sorting index - unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx) - lens: lengths of input data (sorted in descending order) - padded_data (torch.tensor): input sequences (padded) - recurrent_model (nn.Module): recurrent model to run data through - Returns: - hidden_vectors (torch.tensor): outputs of the RNN, in the original, - unsorted, ordering - """ - - # sort the data by decreasing length using provided index - # we assume batch index is in dim=1 - padded_data = padded_data[:, sorted_idx] - padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens) - hidden_vectors = recurrent_model(padded_data)[0] - hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors) - # unsort the results at dim=1 and return - hidden_vectors = hidden_vectors[:, unsort_idx] - return hidden_vectors - - def encode_query(self, query, query_lens): - query = query.permute(2, 0, 1) # seq_len, batch, feature dim - lens, ids = torch.sort(query_lens, descending=True) - original_ids = [0] * lens.size(0) - for i in range(len(ids)): - original_ids[ids[i]] = i - - query_encoded = self.run_padded_sequence(ids, original_ids, lens, - query, self.query_lstm) - query_encoded = query_encoded.permute(1, 2, 0) - return query_encoded - - def forward(self, queries, keys, query_lens, mask=None, key_lens=None, - keys_encoded=None, attn_prior=None): - """Attention mechanism for flowtron parallel - Unlike in Flowtron, we have no restrictions such as causality etc, - since we only need this during training. 
- - Args: - queries (torch.tensor): B x C x T1 tensor - (probably going to be mel data) - keys (torch.tensor): B x C2 x T2 tensor (text data) - query_lens: lengths for sorting the queries in descending order - mask (torch.tensor): uint8 binary mask for variable length entries - (should be in the T2 domain) - Output: - attn (torch.tensor): B x 1 x T1 x T2 attention mask. - Final dim T2 should sum to 1 - """ - keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 - - # Beware can only do this since query_dim = attn_dim = n_mel_channels - if self.use_query_proj: - if self.align_query_enc_type == "inv_conv": - queries_enc, log_det_W = self.query_proj(queries) - elif self.align_query_enc_type == "3xconv": - queries_enc = self.query_proj(queries) - log_det_W = 0.0 - else: - queries_enc, log_det_W = self.query_proj(queries) - else: - queries_enc, log_det_W = queries, 0.0 - - # different ways of computing attn, - # one is isotopic gaussians (per phoneme) - # Simplistic Gaussian Isotopic Attention - - # B x n_attn_dims x T1 x T2 - attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 - # compute log likelihood from a gaussian - attn = -0.0005 * attn.sum(1, keepdim=True) - if attn_prior is not None: - attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]+1e-8) - - attn_logprob = attn.clone() - - if mask is not None: - attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), - -float("inf")) - - attn = self.softmax(attn) # Softmax along T2 - return attn, attn_logprob diff --git a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py deleted file mode 100644 index a653504fd..000000000 --- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class AttentionCTCLoss(torch.nn.Module): - def __init__(self, blank_logprob=-1): - super(AttentionCTCLoss, self).__init__() - self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.blank_logprob = blank_logprob - self.CTCLoss = nn.CTCLoss(zero_infinity=True) - - def forward(self, attn_logprob, in_lens, out_lens): - key_lens = in_lens - query_lens = out_lens - attn_logprob_padded = F.pad(input=attn_logprob, - pad=(1, 0, 0, 0, 0, 0, 0, 0), - value=self.blank_logprob) - cost_total = 0.0 - for bid in range(attn_logprob.shape[0]): - target_seq = torch.arange(1, key_lens[bid]+1).unsqueeze(0) - curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2) - curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]+1] - curr_logprob = self.log_softmax(curr_logprob[None])[0] - ctc_cost = self.CTCLoss( - curr_logprob, target_seq, input_lengths=query_lens[bid:bid+1], - target_lengths=key_lens[bid:bid+1]) - cost_total += ctc_cost - cost = cost_total/attn_logprob.shape[0] - return cost - - -class AttentionBinarizationLoss(torch.nn.Module): - def __init__(self): - super(AttentionBinarizationLoss, self).__init__() - - def forward(self, hard_attention, soft_attention, eps=1e-12): - log_sum = torch.log(torch.clamp(soft_attention[hard_attention == 1], - min=eps)).sum() - return -log_sum / hard_attention.sum() diff --git 
def check_durations(durs, mel_len, filepath):
    """Assert that the per-phone durations sum exactly to the mel length.

    Args:
        durs: iterable of integer frame counts, one per phone.
        mel_len: number of mel-spectrogram frames for the utterance.
        filepath: identifier used in the error message.
    """
    # Compute the sum once instead of twice (assert + message).
    total = sum(durs)
    assert total == mel_len, \
        f'Length mismatch: {filepath}, {total} durs != {mel_len} lens'


def parse_textgrid(tier, sampling_rate, hop_length):
    """Extract phones and frame durations from an MFA TextGrid phone tier.

    Args:
        tier: a tgt IntervalTier; each interval has start_time, end_time
            and text attributes.
        sampling_rate: audio sampling rate in Hz.
        hop_length: STFT hop length in samples (frames = time * sr / hop).

    Returns:
        (phones, durations, start_time, end_time) where silences are
        remapped to 'sil' (utterance-leading/trailing) or 'sp' (internal).
    """
    # From Dan Wells
    # Latest MFA replaces silence phones with "" in output TextGrids
    sil_phones = ['sil', 'sp', 'spn', '']
    start_time = tier[0].start_time
    end_time = tier[-1].end_time
    phones = []
    durations = []
    # NOTE(review): iterates the tier's private _objects list (tgt API) --
    # consider the public interval accessors if tgt guarantees them.
    for index, interval in enumerate(tier._objects):
        p_start, p_end, phone = (interval.start_time, interval.end_time,
                                 interval.text)
        if phone not in sil_phones:
            phones.append(phone)
        elif index == 0 or index == len(tier) - 1:
            # Leading or trailing silence.
            phones.append('sil')
        else:
            # Short pause between words.
            phones.append('sp')
        # Round both boundaries up to frame indices so adjacent durations
        # tile the utterance without gaps or overlaps.
        durations.append(int(np.ceil(p_end * sampling_rate / hop_length)
                             - np.ceil(p_start * sampling_rate / hop_length)))

    return phones, durations, start_time, end_time
Instead cache popular sizes - and use img interpolation to get priors faster. - """ - def __init__(self, round_mel_len_to=100, round_text_len_to=20): - self.round_mel_len_to = round_mel_len_to - self.round_text_len_to = round_text_len_to - self.bank = functools.lru_cache(beta_binomial_prior_distribution) - - def round(self, val, to): - return max(1, int(np.round((val + 1) / to))) * to - - def __call__(self, w, h): - bw = self.round(w, to=self.round_mel_len_to) - bh = self.round(h, to=self.round_text_len_to) - ret = ndimage.zoom(self.bank(bw, bh).T, zoom=(w / bw, h / bh), order=1) - assert ret.shape[0] == w, ret.shape - assert ret.shape[1] == h, ret.shape - return ret - - -def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0): - P = phoneme_count - M = mel_count - x = np.arange(0, P) - mel_text_probs = [] - for i in range(1, M+1): - a, b = scaling * i, scaling * (M + 1 - i) - rv = betabinom(P, a, b) - mel_i_prob = rv.pmf(x) - mel_text_probs.append(mel_i_prob) - return torch.tensor(np.array(mel_text_probs)) + return phones, durations, start_time, end_time def estimate_pitch(wav, mel_len, method='pyin', normalize_mean=None, @@ -128,38 +125,26 @@ class TTSDataset(torch.utils.data.Dataset): 2) normalizes text and converts them to sequences of one-hot vectors 3) computes mel-spectrograms from audio files. 
""" - def __init__(self, - dataset_path, - audiopaths_and_text, - text_cleaners, - n_mel_channels, - symbol_set='english_basic', - p_arpabet=1.0, - n_speakers=1, - load_mel_from_disk=True, - load_pitch_from_disk=True, - pitch_mean=214.72203, # LJSpeech defaults - pitch_std=65.72038, - max_wav_value=None, - sampling_rate=None, - filter_length=None, - hop_length=None, - win_length=None, - mel_fmin=None, - mel_fmax=None, - prepend_space_to_text=False, - append_space_to_text=False, - pitch_online_dir=None, - betabinomial_online_dir=None, - use_betabinomial_interpolator=True, - pitch_online_method='pyin', - **ignored): + def __init__(self, dataset_path, audiopaths_and_text, text_cleaners, + n_mel_channels, symbol_set='english_basic', p_arpabet=1.0, + cmu_dict='cmudict/cmudict-0.7b', + n_speakers=1, load_mel_from_disk=True, + load_pitch_from_disk=True, pitch_mean=214.72203, + pitch_std=65.72038, energy_mean=51.796032, energy_std=9.861213, + max_wav_value=None, sampling_rate=None, + filter_length=None, hop_length=None, win_length=None, + mel_fmin=None, mel_fmax=None, prepend_space_to_text=False, + append_space_to_text=False, load_durs_from_disk=False, + dur_online_dir=None, textgrid_path=None, + pitch_online_dir=None, pitch_online_method='pyin', **ignored): # Expect a list of filenames if type(audiopaths_and_text) is str: audiopaths_and_text = [audiopaths_and_text] + self.hop_length = hop_length self.dataset_path = dataset_path + self.textgrid_path = textgrid_path self.audiopaths_and_text = load_filepaths_and_text( audiopaths_and_text, dataset_path, has_speakers=(n_speakers > 1)) @@ -171,6 +156,7 @@ def __init__(self, filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax) self.load_pitch_from_disk = load_pitch_from_disk + self.load_durs_from_disk = load_durs_from_disk self.prepend_space_to_text = prepend_space_to_text self.append_space_to_text = append_space_to_text @@ -178,19 +164,16 @@ def __init__(self, assert p_arpabet == 0.0 or 
p_arpabet == 1.0, ( 'Only 0.0 and 1.0 p_arpabet is currently supported. ' 'Variable probability breaks caching of betabinomial matrices.') + if p_arpabet > 0.0: + cmudict.initialize(cmu_dict, keep_ambiguous=True) - self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet) + self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet, handle_arpabet='word', handle_arpabet_ambiguous='random') self.n_speakers = n_speakers self.pitch_tmp_dir = pitch_online_dir + self.dur_tmp_dir = dur_online_dir self.f0_method = pitch_online_method - self.betabinomial_tmp_dir = betabinomial_online_dir - self.use_betabinomial_interpolator = use_betabinomial_interpolator - - if use_betabinomial_interpolator: - self.betabinomial_interpolator = BetaBinomialInterpolator() - - expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1)) + expected_columns = (2 + int(load_durs_from_disk) + int(load_pitch_from_disk) + (n_speakers > 1)) assert not (load_pitch_from_disk and self.pitch_tmp_dir is not None) if len(self.audiopaths_and_text[0]) < expected_columns: @@ -203,6 +186,8 @@ def __init__(self, to_tensor = lambda x: torch.Tensor([x]) if type(x) is float else x self.pitch_mean = to_tensor(pitch_mean) self.pitch_std = to_tensor(pitch_std) + self.energy_mean = to_tensor(energy_mean) + self.energy_std = to_tensor(energy_std) def __getitem__(self, index): # Separate filename and text @@ -214,23 +199,30 @@ def __getitem__(self, index): speaker = None mel = self.get_mel(audiopath) - text = self.get_text(text) pitch = self.get_pitch(index, mel.size(-1)) energy = torch.norm(mel.float(), dim=0, p=2) - attn_prior = self.get_prior(index, mel.shape[1], text.shape[0]) + if self.energy_mean is not None: + assert self.energy_std is not None + norm_energy = normalize_pitch(energy.unsqueeze(dim=0), self.energy_mean, self.energy_std) + energy = norm_energy.squeeze() + dur, phones = self.get_dur(index) + text = phones assert pitch.size(-1) == mel.size(-1) # No higher 
formants? if len(pitch.size()) == 1: pitch = pitch[None, :] - return (text, mel, len(text), pitch, energy, speaker, attn_prior, - audiopath) + # this is a batch + # FastPitch 1.0: (text, mel, len_text, dur, pitch, speaker) + return (text, mel, len(text), pitch, energy, speaker, dur, + audiopath, phones) def __len__(self): return len(self.audiopaths_and_text) + @lru_cache() def get_mel(self, filename): if not self.load_mel_from_disk: audio, sampling_rate = load_wav_to_torch(filename) @@ -251,8 +243,9 @@ def get_mel(self, filename): return melspec + @lru_cache() def get_text(self, text): - text = self.tp.encode_text(text) + text, text_clean, text_arpabet = self.tp.encode_text(text, return_all=True) space = [self.tp.encode_text("A A")[1]] if self.prepend_space_to_text: @@ -261,31 +254,49 @@ def get_text(self, text): if self.append_space_to_text: text = text + space - return torch.LongTensor(text) - - def get_prior(self, index, mel_len, text_len): - - if self.use_betabinomial_interpolator: - return torch.from_numpy(self.betabinomial_interpolator(mel_len, - text_len)) - - if self.betabinomial_tmp_dir is not None: - audiopath, *_ = self.audiopaths_and_text[index] - fname = Path(audiopath).relative_to(self.dataset_path) if self.dataset_path else Path(audiopath) - fname = fname.with_suffix('.pt') - cached_fpath = Path(self.betabinomial_tmp_dir, fname) - - if cached_fpath.is_file(): - return torch.load(cached_fpath) - - attn_prior = beta_binomial_prior_distribution(text_len, mel_len) - - if self.betabinomial_tmp_dir is not None: - cached_fpath.parent.mkdir(parents=True, exist_ok=True) - torch.save(attn_prior, cached_fpath) - - return attn_prior + return torch.LongTensor(text), text_arpabet + @lru_cache() + def get_dur(self, index): + audiopath, *fields = self.audiopaths_and_text[index] + name = Path(audiopath).stem + + # TODO: check what happens here with absolute vs relative paths + path = Path(self.dataset_path, 'durations') if self.dataset_path else Path(audiopath) + 
fname = Path(path, name).with_suffix('.pt') + + if self.dur_tmp_dir is not None: + cached_durpath = Path(self.dur_tmp_dir, fname) + cached_phonepath = Path(self.dur_tmp_dir, name + '_phones').with_suffix('.pt') + if cached_durpath.is_file(): + # assume if one exists the other does too + return torch.load(cached_durpath), torch.load(cached_phonepath) + + if self.load_durs_from_disk: + duration_path = fields[1] # assume durations come after pitch + # assume phone_path is known from duration_path + phone_path = Path(Path(duration_path).parent, name + '_phones').with_suffix('.pt') + return torch.load(duration_path), torch.load(phone_path) + + tgt_path = Path(self.textgrid_path, f'{name}.TextGrid') + try: + textgrid = read_textgrid(tgt_path, include_empty_intervals=True) + except FileNotFoundError: + print(f'{name}.wav TextGrid missing: {tgt_path}') + raise + phones, durs, _, _ = parse_textgrid(textgrid.get_tier_by_name('phones'), + self.sampling_rate, + self.hop_length) + phones = torch.Tensor(self.tp.arpabet_list_to_sequence(phones)) + check_durations(durs, self.get_mel(audiopath).size(1), name) + durs = torch.Tensor(durs) + + if self.dur_tmp_dir is not None and not cached_durpath.is_file() and not cached_phonepath.is_file(): + return torch.save(durs, cached_durpath), torch.save(phones, cached_phonepath) + + return durs, phones + + @lru_cache() def get_pitch(self, index, mel_len=None): audiopath, *fields = self.audiopaths_and_text[index] @@ -327,13 +338,23 @@ def get_pitch(self, index, mel_len=None): class TTSCollate: """Zero-pads model inputs and targets based on number of frames per step""" - + # (text, mel, len(text), pitch, energy, speaker, dur, audiopath, phones) = batch + # 0: text + # 1: mel + # 2: len_text + # 3: pitch + # 4: energy + # 5: speaker + # 6: dur + # 7: audiopath + # 8: phones def __call__(self, batch): """Collate training batch from normalized text and mel-spec""" # Right zero-pad all one-hot text sequences to max input length input_lengths, 
ids_sorted_decreasing = torch.sort( torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True) + max_input_len = input_lengths[0] text_padded = torch.LongTensor(len(batch), max_input_len) @@ -342,6 +363,19 @@ def __call__(self, batch): text = batch[ids_sorted_decreasing[i]][0] text_padded[i, :text.size(0)] = text + dur_padded = torch.zeros_like(text_padded, dtype=torch.int32) + + dur_lens = torch.zeros(dur_padded.size(0), dtype=torch.int32) + for i in range(len(ids_sorted_decreasing)): + dur = batch[ids_sorted_decreasing[i]][6] + # With MFA durations: + # some mismatch between phones in transcript vs phones from text preprocessing + # for now using phones from texgrid as input + # PREP DATASET: DUR = LIST, TRAIN: DUR = TENSOR + dur_padded[i, :len(dur)] = dur + dur_lens[i] = len(dur) + assert dur_lens[i] == input_lengths[i] + # Right zero-pad mel-spec num_mels = batch[0][1].size(0) max_target_len = max([x[1].size(1) for x in batch]) @@ -359,12 +393,14 @@ def __call__(self, batch): pitch_padded = torch.zeros(mel_padded.size(0), n_formants, mel_padded.size(2), dtype=batch[0][3].dtype) energy_padded = torch.zeros_like(pitch_padded[:, 0, :]) - + phones_padded = torch.zeros_like(text_padded, dtype=int) for i in range(len(ids_sorted_decreasing)): pitch = batch[ids_sorted_decreasing[i]][3] energy = batch[ids_sorted_decreasing[i]][4] + phones = batch[ids_sorted_decreasing[i]][8] pitch_padded[i, :, :pitch.shape[1]] = pitch energy_padded[i, :energy.shape[0]] = energy + phones_padded[i, :phones.shape[0]] = phones if batch[0][5] is not None: speaker = torch.zeros_like(input_lengths) @@ -373,41 +409,35 @@ def __call__(self, batch): else: speaker = None - attn_prior_padded = torch.zeros(len(batch), max_target_len, - max_input_len) - attn_prior_padded.zero_() - for i in range(len(ids_sorted_decreasing)): - prior = batch[ids_sorted_decreasing[i]][6] - attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior - # Count number of items - characters in text len_x = [x[2] 
def batch_to_gpu(batch):
    """Transfer a collated TTS batch to the GPU and split it into model
    inputs ``x``, training targets ``y`` and the total frame count ``len_x``.
    """
    (text_padded, durs_padded, input_lengths, mel_padded, output_lengths,
     len_x, pitch_padded, energy_padded, dur_lens, speaker, audiopaths,
     phones_padded) = batch

    # Small helpers so each transfer+cast reads as one step.
    as_long = lambda t: to_gpu(t).long()
    as_float = lambda t: to_gpu(t).float()

    text_padded = as_long(text_padded)
    durs_padded = as_long(durs_padded)
    dur_lens = as_long(dur_lens)
    input_lengths = as_long(input_lengths)
    mel_padded = as_float(mel_padded)
    output_lengths = as_long(output_lengths)
    pitch_padded = as_float(pitch_padded)
    energy_padded = as_float(energy_padded)
    phones_padded = as_long(phones_padded)
    # Speaker ids are absent for single-speaker datasets.
    if speaker is not None:
        speaker = as_long(speaker)

    # Alignments act as both inputs and targets - pass shallow copies
    x = [text_padded, input_lengths, mel_padded, output_lengths,
         pitch_padded, energy_padded, speaker, durs_padded, audiopaths,
         phones_padded]
    y = [mel_padded, durs_padded, dur_lens, output_lengths]
    len_x = torch.sum(output_lengths)
    return (x, y, len_x)
import nn from common.utils import mask_from_lens -from fastpitch.attn_loss_function import AttentionCTCLoss class FastPitchLoss(nn.Module): def __init__(self, dur_predictor_loss_scale=1.0, - pitch_predictor_loss_scale=1.0, attn_loss_scale=1.0, + pitch_predictor_loss_scale=1.0, energy_predictor_loss_scale=0.1): super(FastPitchLoss, self).__init__() self.dur_predictor_loss_scale = dur_predictor_loss_scale self.pitch_predictor_loss_scale = pitch_predictor_loss_scale self.energy_predictor_loss_scale = energy_predictor_loss_scale - self.attn_loss_scale = attn_loss_scale - self.attn_ctc_loss = AttentionCTCLoss() def forward(self, model_out, targets, is_training=True, meta_agg='mean'): - (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt, - energy_pred, energy_tgt, attn_soft, attn_hard, attn_dur, - attn_logprob) = model_out - - (mel_tgt, in_lens, out_lens) = targets - - dur_tgt = attn_dur - dur_lens = in_lens + (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt, energy_pred, energy_tgt) = model_out + # model_out = (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt, energy_pred, energy_tgt) + mel_tgt, dur_tgt, dur_lens, output_lengths = targets mel_tgt.requires_grad = False # (B,H,T) => (B,T,H) @@ -70,7 +62,6 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'): loss_fn = F.mse_loss mel_loss = loss_fn(mel_out, mel_tgt, reduction='none') mel_loss = (mel_loss * mel_mask).sum() / mel_mask.sum() - ldiff = pitch_tgt.size(2) - pitch_pred.size(2) pitch_pred = F.pad(pitch_pred, (0, ldiff, 0, 0, 0, 0), value=0.0) pitch_loss = F.mse_loss(pitch_tgt, pitch_pred, reduction='none') @@ -83,21 +74,16 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'): else: energy_loss = 0 - # Attention loss - attn_loss = self.attn_ctc_loss(attn_logprob, in_lens, out_lens) - loss = (mel_loss + dur_pred_loss * self.dur_predictor_loss_scale + pitch_loss * self.pitch_predictor_loss_scale - + energy_loss * 
self.energy_predictor_loss_scale - + attn_loss * self.attn_loss_scale) + + energy_loss * self.energy_predictor_loss_scale) meta = { 'loss': loss.clone().detach(), 'mel_loss': mel_loss.clone().detach(), 'duration_predictor_loss': dur_pred_loss.clone().detach(), 'pitch_loss': pitch_loss.clone().detach(), - 'attn_loss': attn_loss.clone().detach(), 'dur_error': (torch.abs(dur_pred - dur_tgt).sum() / dur_mask.sum()).detach(), } diff --git a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py index 34fca4dff..b8f02300b 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py +++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py @@ -34,7 +34,6 @@ from common.layers import ConvReLUNorm from common.utils import mask_from_lens from fastpitch.alignment import b_mas, mas_width1 -from fastpitch.attention import ConvAttention from fastpitch.transformer import FFTransformer @@ -126,7 +125,7 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, energy_predictor_kernel_size, energy_predictor_filter_size, p_energy_predictor_dropout, energy_predictor_n_layers, energy_embedding_kernel_size, - n_speakers, speaker_emb_weight, pitch_conditioning_formants=1): + n_speakers, speaker_emb_weight, pitch_conditioning_formants=1, norm_energy=True): super(FastPitch, self).__init__() self.encoder = FFTransformer( @@ -187,6 +186,7 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, self.register_buffer('pitch_std', torch.zeros(1)) self.energy_conditioning = energy_conditioning + self.norm_energy = norm_energy if energy_conditioning: self.energy_predictor = TemporalPredictor( in_fft_output_size, @@ -204,45 +204,11 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True) - self.attention = ConvAttention( - n_mel_channels, 0, symbols_embedding_dim, - use_query_proj=True, align_query_enc_type='3xconv') - - def binarize_attention(self, attn, 
in_lens, out_lens): - """For training purposes only. Binarizes attention with MAS. - These will no longer recieve a gradient. - - Args: - attn: B x 1 x max_mel_len x max_text_len - """ - b_size = attn.shape[0] - with torch.no_grad(): - attn_cpu = attn.data.cpu().numpy() - attn_out = torch.zeros_like(attn) - for ind in range(b_size): - hard_attn = mas_width1( - attn_cpu[ind, 0, :out_lens[ind], :in_lens[ind]]) - attn_out[ind, 0, :out_lens[ind], :in_lens[ind]] = torch.tensor( - hard_attn, device=attn.get_device()) - return attn_out - - def binarize_attention_parallel(self, attn, in_lens, out_lens): - """For training purposes only. Binarizes attention with MAS. - These will no longer recieve a gradient. - - Args: - attn: B x 1 x max_mel_len x max_text_len - """ - with torch.no_grad(): - attn_cpu = attn.data.cpu().numpy() - attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(), - out_lens.cpu().numpy(), width=1) - return torch.from_numpy(attn_out).to(attn.get_device()) - - def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): - + def forward(self, inputs, use_gt_pitch=True, use_gt_durations=True, pace=1.0, max_duration=75): + # was FP1.0 : inputs, _, mel_tgt, _, DUR_TGT, _, pitch_tgt, speaker = inputs + # will be: inputs, input_lens, mel_tgt, mel_lens, DUR_TGT, pitch_dense, energy_dense, speaker, audiopaths = inputs (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, energy_dense, - speaker, attn_prior, audiopaths) = inputs + speaker, dur_tgt, audiopaths, phones_padded) = inputs mel_max_len = mel_tgt.size(2) @@ -254,27 +220,7 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): spk_emb.mul_(self.speaker_emb_weight) # Input FFT - enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb) - - # Alignment - text_emb = self.encoder.word_emb(inputs) - - # make sure to do the alignments before folding - attn_mask = mask_from_lens(input_lens)[..., None] == 0 - # attn_mask should be 1 for unused timesteps in the text_enc_w_spkvec tensor 
- - attn_soft, attn_logprob = self.attention( - mel_tgt, text_emb.permute(0, 2, 1), mel_lens, attn_mask, - key_lens=input_lens, keys_encoded=enc_out, attn_prior=attn_prior) - - attn_hard = self.binarize_attention_parallel( - attn_soft, input_lens, mel_lens) - - # Viterbi --> durations - attn_hard_dur = attn_hard.sum(2)[:, 0, :] - dur_tgt = attn_hard_dur - - assert torch.all(torch.eq(dur_tgt.sum(dim=1), mel_lens)) + enc_out, enc_mask = self.encoder(phones_padded, conditioning=spk_emb) # Predict durations log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1) @@ -298,7 +244,8 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): # Average energy over characters energy_tgt = average_pitch(energy_dense.unsqueeze(1), dur_tgt) - energy_tgt = torch.log(1.0 + energy_tgt) + if not self.norm_energy: + energy_tgt = torch.log(1.0 + energy_tgt) energy_emb = self.energy_emb(energy_tgt) energy_tgt = energy_tgt.squeeze(1) @@ -308,14 +255,14 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): energy_tgt = None len_regulated, dec_lens = regulate_len( - dur_tgt, enc_out, pace, mel_max_len) + dur_tgt if use_gt_durations else dur_pred, + enc_out, pace, mel_max_len) # Output FFT dec_out, dec_mask = self.decoder(len_regulated, dec_lens) mel_out = self.proj(dec_out) return (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, - pitch_tgt, energy_pred, energy_tgt, attn_soft, attn_hard, - attn_hard_dur, attn_logprob) + pitch_tgt, energy_pred, energy_tgt) def infer(self, inputs, pace=1.0, dur_tgt=None, pitch_tgt=None, energy_tgt=None, pitch_transform=None, max_duration=75, diff --git a/PyTorch/SpeechSynthesis/FastPitch/filelists/mini_ljs_audio_pitch_durs_text_train_v3.txt b/PyTorch/SpeechSynthesis/FastPitch/filelists/mini_ljs_audio_pitch_durs_text_train_v3.txt new file mode 100644 index 000000000..db6e29b47 --- /dev/null +++ b/PyTorch/SpeechSynthesis/FastPitch/filelists/mini_ljs_audio_pitch_durs_text_train_v3.txt @@ -0,0 +1,20 @@ 
+wavs/LJ050-0234.wav|pitch/LJ050-0234.pt|durations/LJ050-0234.pt|It has used other Treasury law enforcement agents on special experiments in building and route surveys in places to which the President frequently travels. +wavs/LJ019-0373.wav|pitch/LJ019-0373.pt|durations/LJ019-0373.pt|to avail himself of his powers, as it was difficult to bring home the derelictions of duties and evasion of the acts. Too much was left to the inspectors. +wavs/LJ050-0207.wav|pitch/LJ050-0207.pt|durations/LJ050-0207.pt|Although Chief Rowley does not complain about the pay scale for Secret Service agents, +wavs/LJ048-0203.wav|pitch/LJ048-0203.pt|durations/LJ048-0203.pt|The three officers confirm that their primary concern was crowd and traffic control, +wavs/LJ003-0182.wav|pitch/LJ003-0182.pt|durations/LJ003-0182.pt|The tried and the untried, young and old, were herded together +wavs/LJ044-0166.wav|pitch/LJ044-0166.pt|durations/LJ044-0166.pt|According to Marina Oswald, he thought that would help him when he got to Cuba. +wavs/LJ019-0208.wav|pitch/LJ019-0208.pt|durations/LJ019-0208.pt|The proposal made was to purchase some fifty thousand square feet between Newgate, Warwick Lane, and the Sessions House, +wavs/LJ021-0146.wav|pitch/LJ021-0146.pt|durations/LJ021-0146.pt|I shall seek assurances of the making and maintenance of agreements, which can be mutually relied upon, +wavs/LJ013-0214.wav|pitch/LJ013-0214.pt|durations/LJ013-0214.pt|who took a carving-knife from the sideboard in the dining-room, went upstairs to Lord William's bedroom, and drew the knife across his throat. +wavs/LJ011-0256.wav|pitch/LJ011-0256.pt|durations/LJ011-0256.pt|By this time the neighbors were aroused, and several people came to the scene of the affray. +wavs/LJ014-0083.wav|pitch/LJ014-0083.pt|durations/LJ014-0083.pt|which, having possessed herself of the murdered man's keys, she rifled from end to end. 
+wavs/LJ035-0121.wav|pitch/LJ035-0121.pt|durations/LJ035-0121.pt|This is the period during which Oswald would have descended the stairs. In all likelihood +wavs/LJ049-0118.wav|pitch/LJ049-0118.pt|durations/LJ049-0118.pt|Enactment of this statute would mean that the investigation of any of the acts covered and of the possibility of a further attempt +wavs/LJ006-0132.wav|pitch/LJ006-0132.pt|durations/LJ006-0132.pt|All the wardsmen alike were more or less irresponsible. +wavs/LJ049-0084.wav|pitch/LJ049-0084.pt|durations/LJ049-0084.pt|Murder of the President has never been covered by Federal law, however, so that once it became reasonably clear that the killing was the act of a single person, +wavs/LJ012-0052.wav|pitch/LJ012-0052.pt|durations/LJ012-0052.pt|He claimed to be admitted to bail, and was taken from Newgate on a writ of habeas before one of the judges sitting at Westminster. +wavs/LJ011-0203.wav|pitch/LJ011-0203.pt|durations/LJ011-0203.pt|Monsieur le Maire was appealed to, and decided to leave it to the young lady, who at once abandoned Wakefield. +wavs/LJ019-0141.wav|pitch/LJ019-0141.pt|durations/LJ019-0141.pt|The old wards, day rooms and sleeping rooms combined, of which the reader has already heard so much, +wavs/LJ003-0322.wav|pitch/LJ003-0322.pt|durations/LJ003-0322.pt|except for the use of the debtors, or as medical comforts for the infirmary. +wavs/LJ027-0028.wav|pitch/LJ027-0028.pt|durations/LJ027-0028.pt|Such structures or organs are most often found internally. 
diff --git a/PyTorch/SpeechSynthesis/FastPitch/filelists/mini_ljs_audio_pitch_durs_text_val.txt b/PyTorch/SpeechSynthesis/FastPitch/filelists/mini_ljs_audio_pitch_durs_text_val.txt new file mode 100644 index 000000000..eda515c7e --- /dev/null +++ b/PyTorch/SpeechSynthesis/FastPitch/filelists/mini_ljs_audio_pitch_durs_text_val.txt @@ -0,0 +1,16 @@ +wavs/LJ016-0288.wav|pitch/LJ016-0288.pt|durations/LJ016-0288.pt|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells. +wavs/LJ028-0275.wav|pitch/LJ028-0275.pt|durations/LJ028-0275.pt|At last, in the twentieth month, +wavs/LJ019-0273.wav|pitch/LJ019-0273.pt|durations/LJ019-0273.pt|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline. +wavs/LJ021-0145.wav|pitch/LJ021-0145.pt|durations/LJ021-0145.pt|From those willing to join in establishing this hoped-for period of peace, +wavs/LJ009-0076.wav|pitch/LJ009-0076.pt|durations/LJ009-0076.pt|We come to the sermon. +wavs/LJ048-0194.wav|pitch/LJ048-0194.pt|durations/LJ048-0194.pt|during the morning of November twenty-two prior to the motorcade. +wavs/LJ049-0050.wav|pitch/LJ049-0050.pt|durations/LJ049-0050.pt|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy. +wavs/LJ022-0023.wav|pitch/LJ022-0023.pt|durations/LJ022-0023.pt|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read. +wavs/LJ034-0053.wav|pitch/LJ034-0053.pt|durations/LJ034-0053.pt|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald. +wavs/LJ035-0129.wav|pitch/LJ035-0129.pt|durations/LJ035-0129.pt|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him. 
+wavs/LJ039-0075.wav|pitch/LJ039-0075.pt|durations/LJ039-0075.pt|once you know that you must put the crosshairs on the target and that is all that is necessary. +wavs/LJ046-0184.wav|pitch/LJ046-0184.pt|durations/LJ046-0184.pt|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes. +wavs/LJ003-0111.wav|pitch/LJ003-0111.pt|durations/LJ003-0111.pt|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity. +wavs/LJ037-0234.wav|pitch/LJ037-0234.pt|durations/LJ037-0234.pt|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male, +wavs/LJ047-0044.wav|pitch/LJ047-0044.pt|durations/LJ047-0044.pt|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies +wavs/LJ028-0081.wav|pitch/LJ028-0081.pt|durations/LJ028-0081.pt|Years later, when the archaeologists could readily distinguish the false from the true, diff --git a/PyTorch/SpeechSynthesis/FastPitch/install.sh b/PyTorch/SpeechSynthesis/FastPitch/install.sh index b788fdcc3..2e6b138b7 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/install.sh +++ b/PyTorch/SpeechSynthesis/FastPitch/install.sh @@ -42,6 +42,7 @@ conda uninstall pytorch ## Then we reinstall and this for some reason downgrades the gcc to 7 and then installing apex works/ conda install pytorch torchvision cudatoolkit=10.2 -c pytorch +conda install -c conda-forge montreal-forced-aligner ## Apex cd /disk/scratch1/${USER}/FastPitches/PyTorch/SpeechSynthesis/FastPitch/ @@ -58,6 +59,7 @@ pip install wandb pip install llvmlite==0.35.0 ## Ignore warning around here pip install numba==0.49.1 +pip install tgt ## for logging ## if needed, create a free account here: https://app.wandb.ai/login?signup=true diff --git a/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py 
b/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py index d93065b42..cb17fcd14 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py +++ b/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py @@ -26,6 +26,8 @@ # ***************************************************************************** import argparse +import os +import sys import time from pathlib import Path @@ -44,14 +46,18 @@ def parse_args(parser): """ parser.add_argument('-d', '--dataset-path', type=str, default='./', help='Path to dataset') + parser.add_argument('--textgrid-path', type=str, + help='Path to TextGrids') parser.add_argument('--wav-text-filelists', required=True, nargs='+', type=str, help='Files with audio paths and text') parser.add_argument('--extract-mels', action='store_true', help='Calculate spectrograms from .wav files') parser.add_argument('--extract-pitch', action='store_true', help='Extract pitch') - parser.add_argument('--save-alignment-priors', action='store_true', - help='Pre-calculate diagonal matrices of alignment of text to audio') + parser.add_argument('--extract-durations', action='store_true', + help='Extract durations (from alignment dir)') + parser.add_argument('--durs-online-dir', type=str, + help='Durations tmp dir') parser.add_argument('--log-file', type=str, default='preproc_log.json', help='Filename for logging') parser.add_argument('--n-speakers', type=int, default=1) @@ -85,6 +91,7 @@ def main(): parser = parse_args(parser) args, unk_args = parser.parse_known_args() if len(unk_args) > 0: + print(unk_args) raise ValueError(f'Invalid options {unk_args}') DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, Path(args.dataset_path, args.log_file)), @@ -99,34 +106,34 @@ def main(): if args.extract_pitch: Path(args.dataset_path, 'pitch').mkdir(parents=False, exist_ok=True) - if args.save_alignment_priors: - Path(args.dataset_path, 'alignment_priors').mkdir(parents=False, exist_ok=True) + if args.extract_durations: + if not args.textgrid_path: + 
args.textgrid_path = os.path.join(args.dataset_path, 'TextGrid') + durs_path = Path(args.dataset_path, 'durations') + durs_path.mkdir(parents=False, exist_ok=True) + if args.durs_online_dir: + Path(args.durs_online_dir, durs_path).mkdir(parents=True, exist_ok=True) for filelist in args.wav_text_filelists: print(f'Processing {filelist}...') - dataset = TTSDataset( - args.dataset_path, - filelist, - text_cleaners=['english_cleaners_v2'], - n_mel_channels=args.n_mel_channels, - p_arpabet=0.0, - n_speakers=args.n_speakers, - load_mel_from_disk=False, - load_pitch_from_disk=False, - pitch_mean=None, - pitch_std=None, - max_wav_value=args.max_wav_value, - sampling_rate=args.sampling_rate, - filter_length=args.filter_length, - hop_length=args.hop_length, - win_length=args.win_length, - mel_fmin=args.mel_fmin, - mel_fmax=args.mel_fmax, - betabinomial_online_dir=None, - pitch_online_dir=None, - pitch_online_method=args.f0_method) + dataset = TTSDataset(args.dataset_path, filelist, + text_cleaners=['english_cleaners_v2'], + n_mel_channels=args.n_mel_channels, p_arpabet=1.0, + n_speakers=args.n_speakers, + load_mel_from_disk=False, + load_pitch_from_disk=False, pitch_mean=None, + pitch_std=None, max_wav_value=args.max_wav_value, + sampling_rate=args.sampling_rate, + filter_length=args.filter_length, + hop_length=args.hop_length, + win_length=args.win_length, mel_fmin=args.mel_fmin, + mel_fmax=args.mel_fmax, + pitch_online_dir=None, + dur_online_dir=None, + textgrid_path=args.textgrid_path, + pitch_online_method=args.f0_method) data_loader = DataLoader( dataset, @@ -142,8 +149,13 @@ def main(): for i, batch in enumerate(tqdm.tqdm(data_loader)): tik = time.time() - _, input_lens, mels, mel_lens, _, pitch, _, _, attn_prior, fpaths = batch - + # DATASET GETITEM + # (text, mel, len(text), pitch, energy, speaker, dur, audiopath, phones) + # TTSCOLLATE CALL + # (text_padded, dur_padded, input_lengths, mel_padded, + # output_lengths, len_x, pitch_padded, energy_padded, speaker, + # 
audiopaths, phones_padded) + text, durs, input_lens, mels, mel_lens, _, pitch, _, _, _, fpaths, phones = batch # Ensure filenames are unique for p in fpaths: fname = Path(p).name @@ -163,11 +175,20 @@ fpath = Path(args.dataset_path, 'pitch', fname) torch.save(p[:mel_lens[j]], fpath) - if args.save_alignment_priors: - for j, prior in enumerate(attn_prior): - fname = Path(fpaths[j]).with_suffix('.pt').name - fpath = Path(args.dataset_path, 'alignment_priors', fname) - torch.save(prior[:mel_lens[j], :input_lens[j]], fpath) + if args.extract_durations: + # From Dan Wells + for j, d in enumerate(durs): + filename = Path(fpaths[j]).stem + # TODO remove hardcoding dataset path? + dur_path = Path(args.dataset_path, + 'durations', f'{filename}.pt') + torch.save(d, dur_path) + for j, p in enumerate(phones): + filename = Path(fpaths[j]).stem + # save phones too + phones_path = Path(args.dataset_path, + 'durations', f'{filename}_phones.pt') + torch.save(p, phones_path) if __name__ == '__main__': diff --git a/PyTorch/SpeechSynthesis/FastPitch/requirements.txt b/PyTorch/SpeechSynthesis/FastPitch/requirements.txt index e6d7b1751..33b7548c1 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/requirements.txt +++ b/PyTorch/SpeechSynthesis/FastPitch/requirements.txt @@ -3,5 +3,6 @@ numpy inflect librosa==0.8.0 scipy +tgt tensorboardX==2.0 git+git://github.com/NVIDIA/dllogger.git@26a0f8f1958de2c0c460925ff6102a4d2486d6cc#egg=dllogger diff --git a/PyTorch/SpeechSynthesis/FastPitch/scripts/prepare_dataset.sh b/PyTorch/SpeechSynthesis/FastPitch/scripts/prepare_dataset.sh index 43525ef48..c7f9f5846 100755 --- a/PyTorch/SpeechSynthesis/FastPitch/scripts/prepare_dataset.sh +++ b/PyTorch/SpeechSynthesis/FastPitch/scripts/prepare_dataset.sh @@ -2,14 +2,41 @@ set -e +while getopts "ln:" opt; do + case $opt in + l ) LABELS="true";; + n ) NSPEAKERS=$OPTARG;; + \?) 
echo "Invalid option: -"$OPTARG"" >&2 + exit 1;; + esac + done + +: ${NSPEAKERS:=1} # default value : ${DATA_DIR:=LJSpeech-1.1} +: ${WAV_DIR:=${DATA_DIR}/wavs} # should already exist +: ${FILELIST:=filelists/ljs_audio_text.txt} +: ${ALIGNMENT_DIR:=${DATA_DIR}/mfa_alignments} : ${ARGS="--extract-mels"} +if [ "$LABELS" = "true" ] +then + python ./create_lab_files.py --dataset ${WAV_DIR} --filelist ${FILELIST} --n-speakers ${NSPEAKERS} +fi + +#mfa model download acoustic english --temp_directory /disk/scratch1/evdv/tmp/MFA +#mfa model download dictionary english --temp_directory /disk/scratch1/evdv/tmp/MFA +#mfa validate ${WAV_DIR} english english --temp_directory /disk/scratch1/evdv/tmp/MFA +#mfa align ${WAV_DIR} english english ${ALIGNMENT_DIR} --temp_directory /disk/scratch1/evdv/tmp/MFA + +# don't change batch size python prepare_dataset.py \ - --wav-text-filelists filelists/ljs_audio_text.txt \ - --n-workers 16 \ + --wav-text-filelists ${FILELIST} \ + --n-workers 4 \ --batch-size 1 \ --dataset-path $DATA_DIR \ + --textgrid-path $ALIGNMENT_DIR \ --extract-pitch \ + --extract-durations\ + --durs-online-dir "/tmp/" \ --f0-method pyin \ $ARGS diff --git a/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh b/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh index ba041a33f..1b2f0ce4d 100755 --- a/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh +++ b/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh @@ -1,24 +1,25 @@ #!/usr/bin/env bash export OMP_NUM_THREADS=1 +#export MPLCONFIGDIR=/disk/scratch1/evdv/tmp/ +#export WANDB_CONFIG_DIR=/disk/scratch1/evdv/tmp/.config/wandb -: ${NUM_GPUS:=8} -: ${BATCH_SIZE:=16} +: ${NUM_GPUS:=1} +: ${BATCH_SIZE:=2} : ${GRAD_ACCUMULATION:=2} -: ${OUTPUT_DIR:="./output"} +: ${OUTPUT_DIR:="./output_mfa/norm"} : ${DATASET_PATH:=LJSpeech-1.1} -: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_text_train_v3.txt} -: ${VAL_FILELIST:=filelists/ljs_audio_pitch_text_val.txt} +: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_durs_text_train_v3.txt} +: 
${VAL_FILELIST:=filelists/mini_ljs_audio_pitch_durs_text_val.txt} : ${AMP:=false} : ${SEED:=""} : ${LEARNING_RATE:=0.1} # Adjust these when the amount of data changes -: ${EPOCHS:=1000} -: ${EPOCHS_PER_CHECKPOINT:=100} -: ${WARMUP_STEPS:=1000} -: ${KL_LOSS_WARMUP:=100} +: ${EPOCHS:=50} +: ${EPOCHS_PER_CHECKPOINT:=10} +: ${WARMUP_STEPS:=10} # Train a mixed phoneme/grapheme model : ${PHONE:=true} @@ -28,8 +29,9 @@ export OMP_NUM_THREADS=1 # Add dummy space prefix/suffix is audio is not precisely trimmed : ${APPEND_SPACES:=false} -: ${LOAD_PITCH_FROM_DISK:=true} -: ${LOAD_MEL_FROM_DISK:=false} +: ${LOAD_PITCH_FROM_DISK:=TRUE} +: ${LOAD_DURS_FROM_DISK:=TRUE} +: ${LOAD_MEL_FROM_DISK:=FALSE} # For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column : ${NSPEAKERS:=1} @@ -60,9 +62,6 @@ ARGS+=" --grad-clip-thresh 1000.0" ARGS+=" --dur-predictor-loss-scale 0.1" ARGS+=" --pitch-predictor-loss-scale 0.1" -# Autoalign & new features -ARGS+=" --kl-loss-start-epoch 0" -ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP" ARGS+=" --text-cleaners $TEXT_CLEANERS" ARGS+=" --n-speakers $NSPEAKERS" @@ -72,9 +71,11 @@ ARGS+=" --n-speakers $NSPEAKERS" [ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0" [ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning" [ "$SEED" != "" ] && ARGS+=" --seed $SEED" -[ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk" -[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk" +[ "$LOAD_MEL_FROM_DISK" = TRUE ] && ARGS+=" --load-mel-from-disk" +[ "$LOAD_DURS_FROM_DISK" = TRUE ] && ARGS+=" --load-durs-from-disk" +[ "$LOAD_PITCH_FROM_DISK" = TRUE ] && ARGS+=" --load-pitch-from-disk" [ "$PITCH_ONLINE_DIR" != "" ] && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR" # e.g., /dev/shm/pitch +[ "$DUR_ONLINE_DIR" != "" ] && ARGS+=" --dur-online-dir $DUR_ONLINE_DIR" # e.g., /dev/shm/dur [ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD" [ "$APPEND_SPACES" = true ] && ARGS+=" 
--prepend-space-to-text" [ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text" diff --git a/PyTorch/SpeechSynthesis/FastPitch/train.py b/PyTorch/SpeechSynthesis/FastPitch/train.py index 90cfb4443..3b682713d 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/train.py +++ b/PyTorch/SpeechSynthesis/FastPitch/train.py @@ -50,7 +50,6 @@ import models from common.text import cmudict from common.utils import BenchmarkStats, prepare_tmp -from fastpitch.attn_loss_function import AttentionBinarizationLoss from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset from fastpitch.loss_function import FastPitchLoss from fastpitch.model import regulate_len @@ -90,12 +89,6 @@ def parse_args(parser): help='Discounting factor for training weights EMA') train.add_argument('--grad-accumulation', type=int, default=1, help='Training steps to accumulate gradients for') - train.add_argument('--kl-loss-start-epoch', type=int, default=250, - help='Start adding the hard attention loss term') - train.add_argument('--kl-loss-warmup-epochs', type=int, default=100, - help='Gradually increase the hard attention loss term') - train.add_argument('--kl-loss-weight', type=float, default=1.0, - help='Gradually increase the hard attention loss term') train.add_argument('--benchmark-epochs-num', type=int, default=20, help='Number of epochs for calculating final stats') @@ -129,7 +122,8 @@ def parse_args(parser): help='Type of text cleaners for input text') data.add_argument('--symbol-set', type=str, default='english_basic', help='Define symbol set for input text') - data.add_argument('--p-arpabet', type=float, default=0.0, + # should be 1.0 to work with MFA textgrids, which contain only phones + data.add_argument('--p-arpabet', type=float, default=1.0, help='Probability of using arpabets instead of graphemes ' 'for each word; set 0 for pure grapheme training') data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms', @@ -145,6 +139,8 @@ def parse_args(parser): 
cond.add_argument('--n-speakers', type=int, default=1, help='Number of speakers in the dataset. ' 'n_speakers > 1 enables speaker embeddings') + cond.add_argument('--load-durs-from-disk', action='store_true', + help='Use durations cached on disk with prepare_dataset.py') cond.add_argument('--load-pitch-from-disk', action='store_true', help='Use pitch cached on disk with prepare_dataset.py') cond.add_argument('--pitch-online-method', default='pyin', @@ -152,6 +148,8 @@ def parse_args(parser): help='Calculate pitch on the fly during trainig') cond.add_argument('--pitch-online-dir', type=str, default=None, help='A directory for storing pitch calculated on-line') + cond.add_argument('--dur-online-dir', type=str, default=None, + help='A directory for storing durations calculated on-line') cond.add_argument('--pitch-mean', type=float, default=214.72203, help='Normalization value for pitch') cond.add_argument('--pitch-std', type=float, default=65.72038, @@ -336,19 +334,21 @@ def plot_batch_mels(pred_tgt_lists, rank): regulated_features = [] # prediction: mel, pitch, energy # target: mel, pitch, energy - for mel_pitch_energy in pred_tgt_lists: + for i, mel_pitch_energy in enumerate(pred_tgt_lists): mels = mel_pitch_energy[0] if mels.size(dim=2) == 80: # tgt and pred mel have diff dimension order mels = mels.permute(0, 2, 1) - mel_lens = mel_pitch_energy[-1] + mel_lens = mel_pitch_energy[-1].squeeze() + pitch = mel_pitch_energy[1].squeeze().unsqueeze(dim=-1) + energy = mel_pitch_energy[2].squeeze().unsqueeze(dim=-1) # reverse regulation for plotting: for every mel frame get pitch+energy - new_pitch = regulate_len(mel_lens, - mel_pitch_energy[1].permute(0, 2, 1))[0] - new_energy = regulate_len(mel_lens, - mel_pitch_energy[2].unsqueeze(dim=-1))[0] + if i == 0: + energy = regulate_len(mel_lens, energy)[0] + pitch = regulate_len(mel_lens, pitch)[0] + regulated_features.append([mels, - new_pitch.squeeze(axis=2), - new_energy.squeeze(axis=2)]) + pitch.squeeze(axis=2), + 
energy.squeeze(axis=2)]) batch_sizes = [feature.size(dim=0) for pred_tgt in regulated_features @@ -366,20 +366,20 @@ def plot_batch_mels(pred_tgt_lists, rank): def log_validation_batch(x, y_pred, rank): + # x = [text_padded, input_lengths, mel_padded, output_lengths, + # pitch_padded, energy_padded, speaker, durs_padded, audiopaths, phones_padded] + # y_pred = mel_out, dec_lens, dur_pred, pitch_pred, energy_pred x_fields = ['text_padded', 'input_lengths', 'mel_padded', 'output_lengths', 'pitch_padded', 'energy_padded', - 'speaker', 'attn_prior', 'audiopaths'] - y_pred_fields = ['mel_out', 'dec_mask', 'dur_pred', 'log_dur_pred', - 'pitch_pred', 'pitch_tgt', 'energy_pred', - 'energy_tgt', 'attn_soft', 'attn_hard', - 'attn_hard_dur', 'attn_logprob'] + 'speaker', 'durs_padded', 'audiopaths', 'phones_padded'] + y_pred_fields = ['mel_out', 'dec_mask', 'dur_pred', 'pitch_pred', 'energy_pred'] validation_dict = dict(zip(x_fields + y_pred_fields, list(x) + list(y_pred))) log(validation_dict, rank) # something in here returns a warning - pred_specs_keys = ['mel_out', 'pitch_pred', 'energy_pred', 'attn_hard_dur'] - tgt_specs_keys = ['mel_padded', 'pitch_tgt', 'energy_tgt', 'attn_hard_dur'] + pred_specs_keys = ['mel_out', 'pitch_pred', 'energy_pred', 'durs_padded'] + tgt_specs_keys = ['mel_padded', 'pitch_padded', 'energy_padded', 'durs_padded'] plot_batch_mels([[validation_dict[key] for key in pred_specs_keys], [validation_dict[key] for key in tgt_specs_keys]], rank) @@ -400,7 +400,12 @@ def validate(model, criterion, valset, batch_size, collate_fn, distributed_run, val_meta = defaultdict(float) val_num_frames = 0 for i, batch in enumerate(val_loader): + # x = [text_padded, input_lengths, mel_padded, output_lengths, + # pitch_padded, energy_padded, speaker, durs_padded, audiopaths, phones_padded] + # y = [mel_padded, durs_padded, dur_lens, output_lengths] + # len_x = torch.sum(output_lengths) x, y, num_frames = batch_to_gpu(batch) + # y_pred = mel_out, dec_lens, dur_pred, 
pitch_pred, energy_pred y_pred = model(x) if i % 5 == 0: @@ -425,6 +430,9 @@ def validate(model, criterion, valset, batch_size, collate_fn, distributed_run, log({ 'loss/validation-loss': val_meta['loss'].item(), 'mel-loss/validation-mel-loss': val_meta['mel_loss'].item(), + 'pitch-loss/validation-pitch-loss': val_meta['pitch_loss'].item(), + 'energy-loss/validation-energy-loss': val_meta['energy_loss'].item(), + 'dur-loss/validation-dur-error': val_meta['duration_predictor_loss'].item(), 'validation-frames per s': num_frames.item() / val_meta['took'], 'validation-took': val_meta['took'], }, rank) @@ -514,7 +522,6 @@ def main(): model_config = models.get_model_config('FastPitch', args) model = models.get_model('FastPitch', model_config, device) - attention_kl_loss = AttentionBinarizationLoss() if args.local_rank == 0: wandb.init(project=args.project, @@ -574,8 +581,7 @@ def main(): criterion = FastPitchLoss( dur_predictor_loss_scale=args.dur_predictor_loss_scale, - pitch_predictor_loss_scale=args.pitch_predictor_loss_scale, - attn_loss_scale=args.attn_loss_scale) + pitch_predictor_loss_scale=args.pitch_predictor_loss_scale) collate_fn = TTSCollate() @@ -610,6 +616,9 @@ def main(): epoch_loss = 0.0 epoch_mel_loss = 0.0 + epoch_pitch_loss = 0.0 + epoch_energy_loss = 0.0 + epoch_dur_loss = 0.0 epoch_num_frames = 0 epoch_frames_per_sec = 0.0 @@ -625,134 +634,132 @@ def main(): epoch_iter = 0 num_iters = len(train_loader) // args.grad_accumulation for batch in train_loader: - - if accumulated_steps == 0: - if epoch_iter == num_iters: - break - total_iter += 1 - epoch_iter += 1 - - adjust_learning_rate(total_iter, optimizer, args.learning_rate, - args.warmup_steps) - - model.zero_grad(set_to_none=True) - - x, y, num_frames = batch_to_gpu(batch) - - with torch.cuda.amp.autocast(enabled=args.amp): - y_pred = model(x) - loss, meta = criterion(y_pred, y) - - if (args.kl_loss_start_epoch is not None - and epoch >= args.kl_loss_start_epoch): - - if args.kl_loss_start_epoch == 
epoch and epoch_iter == 1: - print('Begin hard_attn loss') - - _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred - binarization_loss = attention_kl_loss(attn_hard, attn_soft) - kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight - meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight - loss += kl_weight * binarization_loss - - else: - meta['kl_loss'] = torch.zeros_like(loss) - kl_weight = 0 - binarization_loss = 0 - - loss /= args.grad_accumulation - - meta = {k: v / args.grad_accumulation - for k, v in meta.items()} - - if args.amp: - scaler.scale(loss).backward() - else: - loss.backward() - - if distributed_run: - reduced_loss = reduce_tensor(loss.data, args.world_size).item() - reduced_num_frames = reduce_tensor(num_frames.data, 1).item() - meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()} - else: - reduced_loss = loss.item() - reduced_num_frames = num_frames.item() - if np.isnan(reduced_loss): - raise Exception("loss is NaN") - - accumulated_steps += 1 - iter_loss += reduced_loss - iter_num_frames += reduced_num_frames - iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta} - - if accumulated_steps % args.grad_accumulation == 0: - - if args.amp: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - model.parameters(), args.grad_clip_thresh) - scaler.step(optimizer) - scaler.update() - else: - torch.nn.utils.clip_grad_norm_( - model.parameters(), args.grad_clip_thresh) - optimizer.step() - - if args.ema_decay > 0.0: - apply_multi_tensor_ema(args.ema_decay, *mt_ema_params) - - iter_mel_loss = iter_meta['mel_loss'].item() - iter_kl_loss = iter_meta['kl_loss'].item() - iter_time = time.perf_counter() - iter_start_time - epoch_frames_per_sec += iter_num_frames / iter_time - epoch_loss += iter_loss - epoch_num_frames += iter_num_frames - epoch_mel_loss += iter_mel_loss - if epoch_iter % 5 == 0: - log({ - 'epoch': epoch, - 'epoch_iter': 
epoch_iter, - 'num_iters': num_iters, - 'total_steps': total_iter, - 'loss/loss': iter_loss, - 'mel-loss/mel_loss': iter_mel_loss, - 'kl_loss': iter_kl_loss, - 'kl_weight': kl_weight, - 'frames per s': iter_num_frames / iter_time, - 'took': iter_time, - 'lrate': optimizer.param_groups[0]['lr'], - }, args.local_rank) - - accumulated_steps = 0 - iter_loss = 0 - iter_num_frames = 0 - iter_meta = {} - iter_start_time = time.perf_counter() - - # Finished epoch - epoch_loss /= epoch_iter - epoch_mel_loss /= epoch_iter - epoch_time = time.perf_counter() - epoch_start_time - - log({ - 'epoch': epoch, - 'loss/epoch_loss': epoch_loss, - 'mel-loss/epoch_mel_loss': epoch_mel_loss, - 'epoch_frames per s': epoch_num_frames / epoch_time, - 'epoch_took': epoch_time, - }, args.local_rank) - bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss, - epoch_time) - - validate(model, criterion, valset, args.batch_size, collate_fn, - distributed_run, batch_to_gpu, args.local_rank) - - if args.ema_decay > 0: - validate(ema_model, criterion, valset, args.batch_size, collate_fn, - distributed_run, batch_to_gpu, args.local_rank) - - maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch, - total_iter, model_config) + print(batch[-1]) + # + # if accumulated_steps == 0: + # if epoch_iter == num_iters: + # break + # total_iter += 1 + # epoch_iter += 1 + # + # adjust_learning_rate(total_iter, optimizer, args.learning_rate, + # args.warmup_steps) + # + # model.zero_grad(set_to_none=True) + # + # x, y, num_frames = batch_to_gpu(batch) + # + # with torch.cuda.amp.autocast(enabled=args.amp): + # # (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt, energy_pred, energy_tgt) + # y_pred = model(x, use_gt_durations=True) + # # y = mel_padded, input_lengths, output_lengths + # loss, meta = criterion(y_pred, y) + # loss /= args.grad_accumulation + # + # meta = {k: v / args.grad_accumulation + # for k, v in meta.items()} + # + # if args.amp: + # 
scaler.scale(loss).backward() + # else: + # loss.backward() + # + # if distributed_run: + # reduced_loss = reduce_tensor(loss.data, args.world_size).item() + # reduced_num_frames = reduce_tensor(num_frames.data, 1).item() + # meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()} + # else: + # reduced_loss = loss.item() + # reduced_num_frames = num_frames.item() + # if np.isnan(reduced_loss): + # raise Exception("loss is NaN") + # + # accumulated_steps += 1 + # iter_loss += reduced_loss + # iter_num_frames += reduced_num_frames + # iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta} + # + # if accumulated_steps % args.grad_accumulation == 0: + # + # if args.amp: + # scaler.unscale_(optimizer) + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), args.grad_clip_thresh) + # scaler.step(optimizer) + # scaler.update() + # else: + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), args.grad_clip_thresh) + # optimizer.step() + # + # if args.ema_decay > 0.0: + # apply_multi_tensor_ema(args.ema_decay, *mt_ema_params) + # + # iter_mel_loss = iter_meta['mel_loss'].item() + # iter_pitch_loss = iter_meta['pitch_loss'].item() + # iter_energy_loss = iter_meta['energy_loss'].item() + # iter_dur_loss = iter_meta['duration_predictor_loss'].item() + # iter_time = time.perf_counter() - iter_start_time + # epoch_frames_per_sec += iter_num_frames / iter_time + # epoch_loss += iter_loss + # epoch_num_frames += iter_num_frames + # epoch_mel_loss += iter_mel_loss + # epoch_pitch_loss += iter_pitch_loss + # epoch_energy_loss += iter_energy_loss + # epoch_dur_loss += iter_dur_loss + # + # if epoch_iter % 5 == 0: + # log({ + # 'epoch': epoch, + # 'epoch_iter': epoch_iter, + # 'num_iters': num_iters, + # 'total_steps': total_iter, + # 'loss/loss': iter_loss, + # 'mel-loss/mel_loss': iter_mel_loss, + # 'pitch-loss/pitch_loss': iter_pitch_loss, + # 'energy-loss/energy_loss': iter_energy_loss, + # 'dur-loss/dur_loss': iter_dur_loss, + # 'frames per 
s': iter_num_frames / iter_time, + # 'took': iter_time, + # 'lrate': optimizer.param_groups[0]['lr'], + # }, args.local_rank) + # + # accumulated_steps = 0 + # iter_loss = 0 + # iter_num_frames = 0 + # iter_meta = {} + # iter_start_time = time.perf_counter() + # # for debugging only + # # validate(model, criterion, valset, args.batch_size, collate_fn, + # # distributed_run, batch_to_gpu, args.local_rank) + # + # # Finished epoch + # epoch_loss /= epoch_iter + # epoch_mel_loss /= epoch_iter + # epoch_time = time.perf_counter() - epoch_start_time + # + # log({ + # 'epoch': epoch, + # 'loss/epoch_loss': epoch_loss, + # 'mel-loss/epoch_mel_loss': epoch_mel_loss, + # 'pitch-loss/epoch_pitch_loss': epoch_pitch_loss, + # 'energy-loss/epoch_energy_loss': epoch_energy_loss, + # 'dur-loss/epoch_dur_loss': epoch_dur_loss, + # 'epoch_frames per s': epoch_num_frames / epoch_time, + # 'epoch_took': epoch_time, + # }, args.local_rank) + # bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss, + # epoch_time) + # + # validate(model, criterion, valset, args.batch_size, collate_fn, + # distributed_run, batch_to_gpu, args.local_rank) + # + # if args.ema_decay > 0: + # validate(ema_model, criterion, valset, args.batch_size, collate_fn, + # distributed_run, batch_to_gpu, args.local_rank) + # + # maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch, + # total_iter, model_config) # Finished training if len(bmark_stats) > 0: