diff --git a/PyTorch/SpeechSynthesis/FastPitch/experiments/train_template.sh b/PyTorch/SpeechSynthesis/FastPitch/experiments/train_template.sh new file mode 100755 index 000000000..af17f1939 --- /dev/null +++ b/PyTorch/SpeechSynthesis/FastPitch/experiments/train_template.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +USER=`whoami` +export OMP_NUM_THREADS=1 +export WANDB_CONFIG_DIR=/disk/scratch1/${USER}/tmp/.config/wandb + +: ${NUM_GPUS:=1} +: ${BATCH_SIZE:=16} +: ${PROJECT="wandb_project"} +: ${PROJECT_DESC="test description"} +: ${GRAD_ACCUMULATION:=2} +: ${OUTPUT_DIR:="./output/"} +: ${DATASET_PATH:=LJSpeech-1.1} +: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_durs_text_train_v3.txt} +: ${VAL_FILELIST:=filelists/ljs_audio_pitch_durs_text_val.txt} +: ${AMP:=false} +: ${SEED:=""} + +: ${LEARNING_RATE:=0.1} + +# Adjust these when the amount of data changes +: ${EPOCHS:=1000} +: ${EPOCHS_PER_CHECKPOINT:=100} +: ${WARMUP_STEPS:=1000} +: ${KL_LOSS_WARMUP:=100} + +# Train a mixed phoneme/grapheme model +: ${PHONE:=true} +# Enable energy conditioning +: ${ENERGY:=true} +# Enable spectral tilt conditioning +: ${SPECTRAL_TILT:=true} +# options for spectral tilt: surface, source, both +: ${WHICH_TILT:=both} +: ${TEXT_CLEANERS:=english_cleaners_v2} +# Add dummy space prefix/suffix if audio is not precisely trimmed +: ${APPEND_SPACES:=false} + +: ${LOAD_PITCH_FROM_DISK:=true} +: ${LOAD_MEL_FROM_DISK:=false} + +# For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column +: ${NSPEAKERS:=1} +: ${SAMPLING_RATE:=22050} + +# Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256. +GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION)) +[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}." 
+echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \ + "(global batch size ${GBS})\n" + +ARGS="" +ARGS+=" --cuda" +ARGS+=" -o $OUTPUT_DIR" +ARGS+=" --dataset-path $DATASET_PATH" +ARGS+=" --training-files $TRAIN_FILELIST" +ARGS+=" --validation-files $VAL_FILELIST" +ARGS+=" -bs $BATCH_SIZE" +ARGS+=" --grad-accumulation $GRAD_ACCUMULATION" +ARGS+=" --optimizer lamb" +ARGS+=" --epochs $EPOCHS" +ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT" +ARGS+=" --resume" +ARGS+=" --warmup-steps $WARMUP_STEPS" +ARGS+=" -lr $LEARNING_RATE" +ARGS+=" --weight-decay 1e-6" +ARGS+=" --grad-clip-thresh 1000.0" +ARGS+=" --dur-predictor-loss-scale 0.1" +ARGS+=" --pitch-predictor-loss-scale 0.1" + +# Autoalign & new features +ARGS+=" --kl-loss-start-epoch 0" +ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP" +ARGS+=" --text-cleaners $TEXT_CLEANERS" +ARGS+=" --n-speakers $NSPEAKERS" + +[ "$PROJECT" != "" ] && ARGS+=" --project \"${PROJECT}\"" +[ "$EXPERIMENT_DESC" != "" ] && ARGS+=" --experiment-desc \"${EXPERIMENT_DESC}\"" +[ "$AMP" = "true" ] && ARGS+=" --amp" +[ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0" +[ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning" +[ "$SPECTRAL_TILT" = "true" ] && ARGS+=" --spectral-tilt-conditioning" +[ "$WHICH_TILT" != "" ] && ARGS+=" --include-tilt ${WHICH_TILT}" +[ "$WHICH_TILT" = "both" ] && ARGS+=" --no-spectral-predictors 12" +[ "$SEED" != "" ] && ARGS+=" --seed $SEED" +[ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk" +[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk" +[ "$PITCH_ONLINE_DIR" != "" ] && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR" # e.g., /dev/shm/pitch +[ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD" +[ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text" +[ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text" + +if [ "$SAMPLING_RATE" == "44100" ]; then + ARGS+=" --sampling-rate 44100" + ARGS+=" --filter-length 
2048" + ARGS+=" --hop-length 512" + ARGS+=" --win-length 2048" + ARGS+=" --mel-fmin 0.0" + ARGS+=" --mel-fmax 22050.0" + +elif [ "$SAMPLING_RATE" != "22050" ]; then + echo "Unknown sampling rate $SAMPLING_RATE" + exit 1 +fi + +mkdir -p "$OUTPUT_DIR" + +: ${DISTRIBUTED:="-m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS"} +CUDA_DEVICES=0 python $DISTRIBUTED train.py $ARGS "$@" diff --git a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/arg_parser.py b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/arg_parser.py index 4e5b13764..f391d7dfd 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/arg_parser.py +++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/arg_parser.py @@ -119,11 +119,26 @@ def parse_fastpitch_args(parent, add_help=False): energy_pred.add_argument('--energy-predictor-n-layers', default=2, type=int, help='Number of conv-1D layers') + spectral_pred = parser.add_argument_group('spectral tilt predictor parameters') + spectral_pred.add_argument('--spectral-tilt-conditioning', action='store_true') + spectral_pred.add_argument('--spectral-tilt-predictor-kernel-size', default=3, type=int, + help='Spectral tilt predictor conv-1D kernel size') + spectral_pred.add_argument('--spectral-tilt-predictor-filter-size', default=256, type=int, + help='Spectral tilt predictor conv-1D filter size') + spectral_pred.add_argument('--p-spectral-tilt-predictor-dropout', default=0.1, type=float, + help='Dropout probability for spectral tilt predictor') + spectral_pred.add_argument('--spectral-tilt-predictor-n-layers', default=2, type=int, + help='Number of conv-1D layers') + spectral_pred.add_argument('--no-spectral-predictors', default=6, type=int, + help='6 if only one of surface or source, 12 if both') + cond = parser.add_argument_group('conditioning parameters') cond.add_argument('--pitch-embedding-kernel-size', default=3, type=int, help='Pitch embedding conv-1D kernel size') cond.add_argument('--energy-embedding-kernel-size', default=3, type=int, - help='Pitch embedding 
conv-1D kernel size') + help='Energy embedding conv-1D kernel size') + cond.add_argument('--spectral-tilt-embedding-kernel-size', default=3, type=int, + help='Spectral tilt embedding conv-1D kernel size') cond.add_argument('--speaker-emb-weight', type=float, default=1.0, help='Scale speaker embedding') diff --git a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py index a007db86f..145b1bc01 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py +++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py @@ -26,7 +26,6 @@ # ***************************************************************************** import functools -import json import re from pathlib import Path @@ -34,6 +33,7 @@ import numpy as np import torch import torch.nn.functional as F +from matplotlib import pyplot as plt from scipy import ndimage from scipy.stats import betabinom @@ -153,6 +153,7 @@ def __init__(self, betabinomial_online_dir=None, use_betabinomial_interpolator=True, pitch_online_method='pyin', + include_tilt=None, **ignored): # Expect a list of filenames @@ -204,6 +205,8 @@ def __init__(self, self.pitch_mean = to_tensor(pitch_mean) self.pitch_std = to_tensor(pitch_std) + self.spectral_tilt_features = include_tilt + def __getitem__(self, index): # Separate filename and text if self.n_speakers > 1: @@ -217,6 +220,10 @@ def __getitem__(self, index): text = self.get_text(text) pitch = self.get_pitch(index, mel.size(-1)) energy = torch.norm(mel.float(), dim=0, p=2) + if self.spectral_tilt_features: + spectral_tilt = self.get_spectral_tilt(mel, audiopath, self.spectral_tilt_features) + else: + spectral_tilt = None attn_prior = self.get_prior(index, mel.shape[1], text.shape[0]) assert pitch.size(-1) == mel.size(-1) @@ -225,7 +232,7 @@ def __getitem__(self, index): if len(pitch.size()) == 1: pitch = pitch[None, :] - return (text, mel, len(text), pitch, energy, speaker, attn_prior, + return (text, 
mel, len(text), pitch, energy, spectral_tilt, speaker, attn_prior, audiopath) def __len__(self): @@ -324,6 +331,45 @@ def get_pitch(self, index, mel_len=None): return pitch_mel + def get_spectral_tilt(self, mels, audio_path, tilt_features='both'): + # # plot these during development + # fig, axes = plt.subplots(2, 1, squeeze=False) + # titles = ["Mel Spectrogram", "One Slice"] + # axes[0][0].imshow(mels, origin="lower") + # axes[0][0].set_aspect(2.5, adjustable="box") + # axes[0][0].set_ylim(0, mels.shape[0]) + # axes[0][0].set_title(titles[0], fontsize="medium") + # axes[0][0].tick_params(labelsize="x-small", left=False, + # labelleft=False) + # axes[0][0].set_anchor("W") + # one_slice = mels[:, 200] + # axes[1][0].plot(one_slice) + # axes[1][0].set_ylim(min(one_slice) - 0.5, max(one_slice) + 0.5) + # axes[1][0].set_title(titles[0], fontsize="medium") + # axes[1][0].tick_params(labelsize="x-small", left=False, + # labelleft=False) + # axes[1][0].set_anchor("W") + # + # plt.show() + n_mels = mels.size(0) + + # surface tilt + # shape 6 x input_frames + poly_coefficients = np.polynomial.polynomial.polyfit(np.arange(1, n_mels + 1), mels, 5) + if tilt_features == 'both' or tilt_features == 'source': + # TODO: remove hardcoded path + audio_path = re.sub('/wavs/', '/iaif_gci_wavs/', audio_path) + iaif_mels = self.get_mel(audio_path) + iaif_poly_coefficients = np.polynomial.polynomial.polyfit( + np.arange(1, n_mels + 1), iaif_mels, 5) + if tilt_features == 'source': + return torch.FloatTensor(iaif_poly_coefficients) + # shape 12 x input_frames + return torch.cat([torch.FloatTensor(poly_coefficients), + torch.FloatTensor(iaif_poly_coefficients)]) + else: + return torch.FloatTensor(poly_coefficients) + class TTSCollate: """Zero-pads model inputs and targets based on number of frames per step""" @@ -355,7 +401,7 @@ def __call__(self, batch): mel_padded[i, :, :mel.size(1)] = mel output_lengths[i] = mel.size(1) - n_formants = batch[0][3].shape[0] + n_formants = 
batch[0][3].shape[0] # default 1 pitch_padded = torch.zeros(mel_padded.size(0), n_formants, mel_padded.size(2), dtype=batch[0][3].dtype) energy_padded = torch.zeros_like(pitch_padded[:, 0, :]) @@ -363,13 +409,22 @@ def __call__(self, batch): for i in range(len(ids_sorted_decreasing)): pitch = batch[ids_sorted_decreasing[i]][3] energy = batch[ids_sorted_decreasing[i]][4] + spectral_tilt = batch[ids_sorted_decreasing[i]][5] pitch_padded[i, :, :pitch.shape[1]] = pitch energy_padded[i, :energy.shape[0]] = energy - if batch[0][5] is not None: + if batch[0][5] is not None: # if None, there is no spectral tilt + num_coefficients = batch[0][5].size(0) + spectral_tilt_padded = torch.FloatTensor(len(batch), num_coefficients, max_target_len).zero_() + for i in range(len(ids_sorted_decreasing)): + spectral_tilt_padded[i, :, :spectral_tilt.shape[1]] = spectral_tilt + else: + spectral_tilt_padded = None + + if batch[0][6] is not None: speaker = torch.zeros_like(input_lengths) for i in range(len(ids_sorted_decreasing)): - speaker[i] = batch[ids_sorted_decreasing[i]][5] + speaker[i] = batch[ids_sorted_decreasing[i]][6] else: speaker = None @@ -377,23 +432,23 @@ def __call__(self, batch): max_input_len) attn_prior_padded.zero_() for i in range(len(ids_sorted_decreasing)): - prior = batch[ids_sorted_decreasing[i]][6] + prior = batch[ids_sorted_decreasing[i]][7] attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior # Count number of items - characters in text len_x = [x[2] for x in batch] len_x = torch.Tensor(len_x) - audiopaths = [batch[i][7] for i in ids_sorted_decreasing] + audiopaths = [batch[i][8] for i in ids_sorted_decreasing] return (text_padded, input_lengths, mel_padded, output_lengths, len_x, - pitch_padded, energy_padded, speaker, attn_prior_padded, + pitch_padded, energy_padded, spectral_tilt_padded, speaker, attn_prior_padded, audiopaths) def batch_to_gpu(batch): (text_padded, input_lengths, mel_padded, output_lengths, len_x, - pitch_padded, energy_padded, 
speaker, attn_prior, audiopaths) = batch + pitch_padded, energy_padded, spectral_tilt_padded, speaker, attn_prior, audiopaths) = batch text_padded = to_gpu(text_padded).long() input_lengths = to_gpu(input_lengths).long() @@ -402,12 +457,14 @@ def batch_to_gpu(batch): pitch_padded = to_gpu(pitch_padded).float() energy_padded = to_gpu(energy_padded).float() attn_prior = to_gpu(attn_prior).float() + if spectral_tilt_padded is not None: + spectral_tilt_padded = to_gpu(spectral_tilt_padded).float() if speaker is not None: speaker = to_gpu(speaker).long() # Alignments act as both inputs and targets - pass shallow copies x = [text_padded, input_lengths, mel_padded, output_lengths, - pitch_padded, energy_padded, speaker, attn_prior, audiopaths] + pitch_padded, energy_padded, spectral_tilt_padded, speaker, attn_prior, audiopaths] y = [mel_padded, input_lengths, output_lengths] len_x = torch.sum(output_lengths) return (x, y, len_x) diff --git a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/loss_function.py b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/loss_function.py index 0cd3775e5..476c75354 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/loss_function.py +++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/loss_function.py @@ -36,18 +36,24 @@ class FastPitchLoss(nn.Module): def __init__(self, dur_predictor_loss_scale=1.0, pitch_predictor_loss_scale=1.0, attn_loss_scale=1.0, - energy_predictor_loss_scale=0.1): + energy_predictor_loss_scale=0.1, + spectral_tilt_predictor_loss_scale=0.1): super(FastPitchLoss, self).__init__() self.dur_predictor_loss_scale = dur_predictor_loss_scale self.pitch_predictor_loss_scale = pitch_predictor_loss_scale self.energy_predictor_loss_scale = energy_predictor_loss_scale + self.spectral_tilt_predictor_loss_scale = spectral_tilt_predictor_loss_scale self.attn_loss_scale = attn_loss_scale self.attn_ctc_loss = AttentionCTCLoss() def forward(self, model_out, targets, is_training=True, meta_agg='mean'): + # (mel_out, dec_mask, dur_pred, 
log_dur_pred, + # pitch_pred, pitch_tgt, energy_pred, energy_tgt, + # spectral_tilt_pred, spectral_tilt_tgt, + # attn_soft, attn_hard, attn_hard_dur, attn_logprob) (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt, - energy_pred, energy_tgt, attn_soft, attn_hard, attn_dur, - attn_logprob) = model_out + energy_pred, energy_tgt, spectral_tilt_pred, spectral_tilt_tgt, + attn_soft, attn_hard, attn_dur, attn_logprob) = model_out (mel_tgt, in_lens, out_lens) = targets @@ -81,7 +87,16 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'): energy_loss = F.mse_loss(energy_tgt, energy_pred, reduction='none') energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum() else: - energy_loss = 0 + energy_loss = torch.tensor(0, dtype=torch.int8).to(device='cuda') + + if spectral_tilt_pred is not None: + spectral_tilt_pred = spectral_tilt_pred.permute(0, 2, 1) + spectral_tilt_pred = F.pad(spectral_tilt_pred, (0, ldiff, 0, 0), value=0.0) + spectral_tilt_loss = F.mse_loss(spectral_tilt_tgt, spectral_tilt_pred, reduction='none') + spectral_dur_mask = torch.repeat_interleave(dur_mask[:, None, :], spectral_tilt_loss.size(1), dim=1) + spectral_tilt_loss = (spectral_tilt_loss * spectral_dur_mask).sum() / spectral_dur_mask.sum() + else: + spectral_tilt_loss = torch.tensor(0, dtype=torch.int8).to(device='cuda') # Attention loss attn_loss = self.attn_ctc_loss(attn_logprob, in_lens, out_lens) @@ -90,6 +105,7 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'): + dur_pred_loss * self.dur_predictor_loss_scale + pitch_loss * self.pitch_predictor_loss_scale + energy_loss * self.energy_predictor_loss_scale + + spectral_tilt_loss * self.spectral_tilt_predictor_loss_scale + attn_loss * self.attn_loss_scale) meta = { @@ -97,14 +113,13 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'): 'mel_loss': mel_loss.clone().detach(), 'duration_predictor_loss': dur_pred_loss.clone().detach(), 'pitch_loss': 
pitch_loss.clone().detach(), + 'energy_loss': energy_loss.clone().detach(), + 'spectral_tilt_loss': spectral_tilt_loss.clone().detach(), 'attn_loss': attn_loss.clone().detach(), 'dur_error': (torch.abs(dur_pred - dur_tgt).sum() / dur_mask.sum()).detach(), } - if energy_pred is not None: - meta['energy_loss'] = energy_loss.clone().detach() - assert meta_agg in ('sum', 'mean') if meta_agg == 'sum': bsz = mel_out.size(0) diff --git a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py index 34fca4dff..367e1d04c 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py +++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py @@ -126,7 +126,12 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, energy_predictor_kernel_size, energy_predictor_filter_size, p_energy_predictor_dropout, energy_predictor_n_layers, energy_embedding_kernel_size, - n_speakers, speaker_emb_weight, pitch_conditioning_formants=1): + n_speakers, speaker_emb_weight, spectral_tilt_conditioning, + spectral_tilt_predictor_kernel_size, spectral_tilt_predictor_filter_size, + p_spectral_tilt_predictor_dropout, spectral_tilt_predictor_n_layers, + spectral_tilt_embedding_kernel_size, + pitch_conditioning_formants=1, + no_spectral_predictors=6): super(FastPitch, self).__init__() self.encoder = FFTransformer( @@ -202,6 +207,22 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, kernel_size=energy_embedding_kernel_size, padding=int((energy_embedding_kernel_size - 1) / 2)) + self.spectral_tilt_conditioning = spectral_tilt_conditioning + if spectral_tilt_conditioning: + self.spectral_tilt_predictor = TemporalPredictor( + in_fft_output_size, + filter_size=spectral_tilt_predictor_filter_size, + kernel_size=spectral_tilt_predictor_kernel_size, + dropout=p_spectral_tilt_predictor_dropout, + n_layers=spectral_tilt_predictor_n_layers, + n_predictions=no_spectral_predictors + ) + + self.spectral_tilt_emb = nn.Conv1d( + 
no_spectral_predictors, symbols_embedding_dim, + kernel_size=spectral_tilt_embedding_kernel_size, + padding=int((spectral_tilt_embedding_kernel_size - 1) / 2)) + self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True) self.attention = ConvAttention( @@ -242,7 +263,7 @@ def binarize_attention_parallel(self, attn, in_lens, out_lens): def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, energy_dense, - speaker, attn_prior, audiopaths) = inputs + spectral_tilt_dense, speaker, attn_prior, audiopaths) = inputs mel_max_len = mel_tgt.size(2) @@ -307,19 +328,36 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): energy_pred = None energy_tgt = None + # Predict spectral tilt + if self.spectral_tilt_conditioning: + spectral_tilt_pred = self.spectral_tilt_predictor(enc_out, enc_mask).squeeze(-1) + + # Average spectral tilt over characters + spectral_tilt_tgt = average_pitch(spectral_tilt_dense, dur_tgt) + eps = 1e-7 + spectral_tilt_tgt = F.relu(spectral_tilt_tgt) + spectral_tilt_tgt = torch.log(eps + spectral_tilt_tgt) + spectral_tilt_emb = self.spectral_tilt_emb(spectral_tilt_tgt) + spectral_tilt_tgt = spectral_tilt_tgt.squeeze(1) + enc_out = enc_out + spectral_tilt_emb.transpose(1, 2) + else: + spectral_tilt_pred = None + spectral_tilt_tgt = None + len_regulated, dec_lens = regulate_len( dur_tgt, enc_out, pace, mel_max_len) # Output FFT dec_out, dec_mask = self.decoder(len_regulated, dec_lens) mel_out = self.proj(dec_out) - return (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, - pitch_tgt, energy_pred, energy_tgt, attn_soft, attn_hard, - attn_hard_dur, attn_logprob) + return (mel_out, dec_mask, dur_pred, log_dur_pred, + pitch_pred, pitch_tgt, energy_pred, energy_tgt, + spectral_tilt_pred, spectral_tilt_tgt, + attn_soft, attn_hard, attn_hard_dur, attn_logprob) def infer(self, inputs, pace=1.0, dur_tgt=None, pitch_tgt=None, - energy_tgt=None, pitch_transform=None, 
max_duration=75, - speaker=0): + energy_tgt=None, spectral_tilt_tgt=None, + pitch_transform=None, max_duration=75, speaker=0): if self.speaker_emb is None: spk_emb = 0 @@ -367,6 +405,19 @@ def infer(self, inputs, pace=1.0, dur_tgt=None, pitch_tgt=None, else: energy_pred = None + # Predict spectral_tilt + if self.spectral_tilt_conditioning: + + if spectral_tilt_tgt is None: + spectral_tilt_pred = self.spectral_tilt_predictor(enc_out, enc_mask).squeeze(-1) + spectral_tilt_emb = self.spectral_tilt_emb(spectral_tilt_pred.permute(0, 2, 1)).transpose(1, 2) + else: + spectral_tilt_emb = self.spectral_tilt_emb(spectral_tilt_tgt).transpose(1, 2) + + enc_out = enc_out + spectral_tilt_emb + else: + spectral_tilt_pred = None + len_regulated, dec_lens = regulate_len( dur_pred if dur_tgt is None else dur_tgt, enc_out, pace, mel_max_len=None) @@ -375,4 +426,4 @@ def infer(self, inputs, pace=1.0, dur_tgt=None, pitch_tgt=None, mel_out = self.proj(dec_out) # mel_lens = dec_mask.squeeze(2).sum(axis=1).long() mel_out = mel_out.permute(0, 2, 1) # For inference.py - return mel_out, dec_lens, dur_pred, pitch_pred, energy_pred + return mel_out, dec_lens, dur_pred, pitch_pred, energy_pred, spectral_tilt_pred diff --git a/PyTorch/SpeechSynthesis/FastPitch/models.py b/PyTorch/SpeechSynthesis/FastPitch/models.py index ab3af17ed..793a6acd3 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/models.py +++ b/PyTorch/SpeechSynthesis/FastPitch/models.py @@ -144,6 +144,15 @@ def get_model_config(model_name, args): # energy conditioning energy_conditioning=args.energy_conditioning, energy_embedding_kernel_size=args.energy_embedding_kernel_size, + # spectral tilt predictor + spectral_tilt_predictor_kernel_size=args.spectral_tilt_predictor_kernel_size, + spectral_tilt_predictor_filter_size=args.spectral_tilt_predictor_filter_size, + p_spectral_tilt_predictor_dropout=args.p_spectral_tilt_predictor_dropout, + spectral_tilt_predictor_n_layers=args.spectral_tilt_predictor_n_layers, + # spectral tilt 
conditioning + spectral_tilt_conditioning=args.spectral_tilt_conditioning, + spectral_tilt_embedding_kernel_size=args.spectral_tilt_embedding_kernel_size, + no_spectral_predictors=args.no_spectral_predictors ) return model_config diff --git a/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh b/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh index ba041a33f..b1b75e89f 100755 --- a/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh +++ b/PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh @@ -1,14 +1,17 @@ #!/usr/bin/env bash - +USER=`whoami` export OMP_NUM_THREADS=1 +export WANDB_CONFIG_DIR=/disk/scratch1/${USER}/tmp/.config/wandb -: ${NUM_GPUS:=8} +: ${NUM_GPUS:=1} : ${BATCH_SIZE:=16} +: ${PROJECT="wandb_project"} +: ${PROJECT_DESC="test description"} : ${GRAD_ACCUMULATION:=2} -: ${OUTPUT_DIR:="./output"} +: ${OUTPUT_DIR:="./output/"} : ${DATASET_PATH:=LJSpeech-1.1} -: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_text_train_v3.txt} -: ${VAL_FILELIST:=filelists/ljs_audio_pitch_text_val.txt} +: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_durs_text_train_v3.txt} +: ${VAL_FILELIST:=filelists/ljs_audio_pitch_durs_text_val.txt} : ${AMP:=false} : ${SEED:=""} @@ -24,6 +27,10 @@ export OMP_NUM_THREADS=1 : ${PHONE:=true} # Enable energy conditioning : ${ENERGY:=true} +# Enable spectral tilt conditioning +: ${SPECTRAL_TILT:=true} +# options for spectral tilt: surface, source, both +: ${WHICH_TILT:=both} : ${TEXT_CLEANERS:=english_cleaners_v2} # Add dummy space prefix/suffix is audio is not precisely trimmed : ${APPEND_SPACES:=false} @@ -71,6 +78,8 @@ ARGS+=" --n-speakers $NSPEAKERS" [ "$AMP" = "true" ] && ARGS+=" --amp" [ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0" [ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning" +[ "$SPECTRAL_TILT" = "true" ] && ARGS+=" --spectral-tilt-conditioning" +[ "$WHICH_TILT" != "" ] && ARGS+=" --include-tilt ${WHICH_TILT}" [ "$SEED" != "" ] && ARGS+=" --seed $SEED" [ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk" [ 
"$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk" @@ -94,5 +103,5 @@ fi mkdir -p "$OUTPUT_DIR" -: ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"} -python $DISTRIBUTED train.py $ARGS "$@" +: ${DISTRIBUTED:="-m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS"} +CUDA_DEVICES=0 python $DISTRIBUTED train.py $ARGS "$@" diff --git a/PyTorch/SpeechSynthesis/FastPitch/train.py b/PyTorch/SpeechSynthesis/FastPitch/train.py index 873cfb9c4..a49a2e357 100644 --- a/PyTorch/SpeechSynthesis/FastPitch/train.py +++ b/PyTorch/SpeechSynthesis/FastPitch/train.py @@ -158,6 +158,9 @@ def parse_args(parser): help='Normalization value for pitch') cond.add_argument('--load-mel-from-disk', action='store_true', help='Use mel-spectrograms cache on the disk') # XXX + # for spectral tilt estimation + cond.add_argument('--include-tilt', default=None, type=str, + choices=['source', 'surface', 'both', None]) audio = parser.add_argument_group('audio parameters') audio.add_argument('--max-wav-value', default=32768.0, type=float, @@ -367,15 +370,17 @@ def plot_batch_mels(pred_tgt_lists, rank): def log_validation_batch(x, y_pred, rank): x_fields = ['text_padded', 'input_lengths', 'mel_padded', - 'output_lengths', 'pitch_padded', 'energy_padded', + 'output_lengths', 'pitch_padded', 'energy_padded', 'spectral_tilt_padded', 'speaker', 'attn_prior', 'audiopaths'] y_pred_fields = ['mel_out', 'dec_mask', 'dur_pred', 'log_dur_pred', 'pitch_pred', 'pitch_tgt', 'energy_pred', - 'energy_tgt', 'attn_soft', 'attn_hard', + 'energy_tgt', 'spectral_tilt_pred', + 'spectral_tilt_tgt', 'attn_soft', 'attn_hard', 'attn_hard_dur', 'attn_logprob'] - validation_dict = dict(zip(x_fields + y_pred_fields, list(x) + list(y_pred))) + # dec mask contains booleans, which to be logged need to be converted to integers + validation_dict.pop('dec_mask', None) log(validation_dict, rank) # something in here returns a warning pred_specs_keys = ['mel_out', 'pitch_pred', 
'energy_pred', 'attn_hard_dur'] @@ -400,14 +405,20 @@ def validate(model, criterion, valset, batch_size, collate_fn, distributed_run, val_meta = defaultdict(float) val_num_frames = 0 for i, batch in enumerate(val_loader): + # x = (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, + # energy_dense, spectral_tilt_dense, speaker, attn_prior, audiopaths) x, y, num_frames = batch_to_gpu(batch) + # (mel_out, dec_mask, dur_pred, log_dur_pred, + # pitch_pred, pitch_tgt, energy_pred, energy_tgt, + # spectral_tilt_pred, spectral_tilt_tgt, + # attn_soft, attn_hard, attn_hard_dur, attn_logprob) y_pred = model(x) - + loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum') if i % 5 == 0: + # dec_mask is index 1 + # y_pred = y_pred[0:1] + y_pred[2:] log_validation_batch(x, y_pred, rank) - loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum') - if distributed_run: for k, v in meta.items(): val_meta[k] += reduce_tensor(v, 1) @@ -420,13 +431,13 @@ def validate(model, criterion, valset, batch_size, collate_fn, distributed_run, val_meta = {k: v / len(valset) for k, v in val_meta.items()} val_meta['took'] = time.perf_counter() - tik - # log overall statistics of the validate step log({ 'loss/validation-loss': val_meta['loss'].item(), 'mel-loss/validation-mel-loss': val_meta['mel_loss'].item(), 'pitch-loss/validation-pitch-loss': val_meta['pitch_loss'].item(), 'energy-loss/validation-energy-loss': val_meta['energy_loss'].item(), + 'spectral-loss/validation-spectral-loss': val_meta['spectral_tilt_loss'].item(), 'dur-loss/validation-dur-error': val_meta['duration_predictor_loss'].item(), 'validation-frames per s': num_frames.item() / val_meta['took'], 'validation-took': val_meta['took'], @@ -615,6 +626,7 @@ def main(): epoch_mel_loss = 0.0 epoch_pitch_loss = 0.0 epoch_energy_loss = 0.0 + epoch_spectral_loss = 0.0 epoch_dur_loss = 0.0 epoch_num_frames = 0 epoch_frames_per_sec = 0.0 @@ -631,7 +643,6 @@ def main(): epoch_iter = 0 num_iters = len(train_loader) // 
args.grad_accumulation for batch in train_loader: - if accumulated_steps == 0: if epoch_iter == num_iters: break @@ -640,22 +651,19 @@ def main(): adjust_learning_rate(total_iter, optimizer, args.learning_rate, args.warmup_steps) - model.zero_grad(set_to_none=True) x, y, num_frames = batch_to_gpu(batch) - with torch.cuda.amp.autocast(enabled=args.amp): y_pred = model(x) loss, meta = criterion(y_pred, y) - if (args.kl_loss_start_epoch is not None and epoch >= args.kl_loss_start_epoch): if args.kl_loss_start_epoch == epoch and epoch_iter == 1: print('Begin hard_attn loss') - _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred + _, _, _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred binarization_loss = attention_kl_loss(attn_hard, attn_soft) kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight @@ -711,6 +719,7 @@ def main(): iter_kl_loss = iter_meta['kl_loss'].item() iter_pitch_loss = iter_meta['pitch_loss'].item() iter_energy_loss = iter_meta['energy_loss'].item() + iter_spectral_loss = iter_meta['spectral_tilt_loss'].item() iter_dur_loss = iter_meta['duration_predictor_loss'].item() iter_time = time.perf_counter() - iter_start_time epoch_frames_per_sec += iter_num_frames / iter_time @@ -719,6 +728,7 @@ def main(): epoch_mel_loss += iter_mel_loss epoch_pitch_loss += iter_pitch_loss epoch_energy_loss += iter_energy_loss + epoch_spectral_loss += iter_spectral_loss epoch_dur_loss += iter_dur_loss if epoch_iter % 5 == 0: @@ -733,6 +743,7 @@ def main(): 'kl_weight': kl_weight, 'pitch-loss/pitch_loss': iter_pitch_loss, 'energy-loss/energy_loss': iter_energy_loss, + 'spectral-loss/spectral_loss': iter_spectral_loss, 'dur-loss/dur_loss': iter_dur_loss, 'frames per s': iter_num_frames / iter_time, 'took': iter_time, @@ -756,20 +767,19 @@ def main(): 'mel-loss/epoch_mel_loss': epoch_mel_loss, 'pitch-loss/epoch_pitch_loss': epoch_pitch_loss, 
'energy-loss/epoch_energy_loss': epoch_energy_loss, + 'spectral-loss/epoch_spectral_loss': epoch_spectral_loss, 'dur-loss/epoch_dur_loss': epoch_dur_loss, 'epoch_frames per s': epoch_num_frames / epoch_time, 'epoch_took': epoch_time, }, args.local_rank) bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss, epoch_time) - validate(model, criterion, valset, args.batch_size, collate_fn, distributed_run, batch_to_gpu, args.local_rank) if args.ema_decay > 0: validate(ema_model, criterion, valset, args.batch_size, collate_fn, distributed_run, batch_to_gpu, args.local_rank) - maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch, total_iter, model_config)