Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions PyTorch/SpeechSynthesis/FastPitch/experiments/train_template.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# Launch FastPitch training (with optional energy / spectral-tilt conditioning).
# Every setting below can be overridden from the environment, e.g.:
#   NUM_GPUS=2 BATCH_SIZE=32 bash train_template.sh

USER=$(whoami)
export OMP_NUM_THREADS=1
export WANDB_CONFIG_DIR=/disk/scratch1/${USER}/tmp/.config/wandb

: ${NUM_GPUS:=1}
: ${BATCH_SIZE:=16}
: ${PROJECT="wandb_project"}
# Fix: the default was assigned to an unused PROJECT_DESC variable, while the
# script only ever reads EXPERIMENT_DESC below.
: ${EXPERIMENT_DESC="test description"}
: ${GRAD_ACCUMULATION:=2}
: ${OUTPUT_DIR:="./output/"}
: ${DATASET_PATH:=LJSpeech-1.1}
: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_durs_text_train_v3.txt}
: ${VAL_FILELIST:=filelists/ljs_audio_pitch_durs_text_val.txt}
: ${AMP:=false}
: ${SEED:=""}

: ${LEARNING_RATE:=0.1}

# Adjust these when the amount of data changes
: ${EPOCHS:=1000}
: ${EPOCHS_PER_CHECKPOINT:=100}
: ${WARMUP_STEPS:=1000}
: ${KL_LOSS_WARMUP:=100}

# Train a mixed phoneme/grapheme model
: ${PHONE:=true}
# Enable energy conditioning
: ${ENERGY:=true}
# Enable spectral tilt conditioning
: ${SPECTRAL_TILT:=true}
# Options for spectral tilt: surface, source, both
: ${WHICH_TILT:=both}
: ${TEXT_CLEANERS:=english_cleaners_v2}
# Add dummy space prefix/suffix if audio is not precisely trimmed
: ${APPEND_SPACES:=false}

: ${LOAD_PITCH_FROM_DISK:=true}
: ${LOAD_MEL_FROM_DISK:=false}

# For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
: ${NSPEAKERS:=1}
: ${SAMPLING_RATE:=22050}

# Warn when the global batch size drifts from the reference value:
# NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
GBS=$((NUM_GPUS * BATCH_SIZE * GRAD_ACCUMULATION))
[ "$GBS" -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
        "(global batch size ${GBS})\n"

# Build the argument list as a bash array so values containing spaces
# (e.g. EXPERIMENT_DESC="my test run") survive as single arguments.
# The previous string-based ARGS with escaped quotes (--project \"...\")
# passed literal '"' characters to train.py and word-split multi-word values.
ARGS=(--cuda)
ARGS+=(-o "$OUTPUT_DIR")
ARGS+=(--dataset-path "$DATASET_PATH")
ARGS+=(--training-files "$TRAIN_FILELIST")
ARGS+=(--validation-files "$VAL_FILELIST")
ARGS+=(-bs "$BATCH_SIZE")
ARGS+=(--grad-accumulation "$GRAD_ACCUMULATION")
ARGS+=(--optimizer lamb)
ARGS+=(--epochs "$EPOCHS")
ARGS+=(--epochs-per-checkpoint "$EPOCHS_PER_CHECKPOINT")
ARGS+=(--resume)
ARGS+=(--warmup-steps "$WARMUP_STEPS")
ARGS+=(-lr "$LEARNING_RATE")
ARGS+=(--weight-decay 1e-6)
ARGS+=(--grad-clip-thresh 1000.0)
ARGS+=(--dur-predictor-loss-scale 0.1)
ARGS+=(--pitch-predictor-loss-scale 0.1)

# Autoalign & new features
ARGS+=(--kl-loss-start-epoch 0)
ARGS+=(--kl-loss-warmup-epochs "$KL_LOSS_WARMUP")
ARGS+=(--text-cleaners "$TEXT_CLEANERS")
ARGS+=(--n-speakers "$NSPEAKERS")

[ -n "$PROJECT" ]             && ARGS+=(--project "$PROJECT")
[ -n "$EXPERIMENT_DESC" ]     && ARGS+=(--experiment-desc "$EXPERIMENT_DESC")
[ "$AMP" = "true" ]           && ARGS+=(--amp)
[ "$PHONE" = "true" ]         && ARGS+=(--p-arpabet 1.0)
[ "$ENERGY" = "true" ]        && ARGS+=(--energy-conditioning)
[ "$SPECTRAL_TILT" = "true" ] && ARGS+=(--spectral-tilt-conditioning)
[ -n "$WHICH_TILT" ]          && ARGS+=(--include-tilt "$WHICH_TILT")
# 6 polynomial coefficients per tilt source; 12 when both surface and source
[ "$WHICH_TILT" = "both" ]    && ARGS+=(--no-spectral-predictors 12)
[ -n "$SEED" ]                && ARGS+=(--seed "$SEED")
[ "$LOAD_MEL_FROM_DISK" = "true" ]   && ARGS+=(--load-mel-from-disk)
[ "$LOAD_PITCH_FROM_DISK" = "true" ] && ARGS+=(--load-pitch-from-disk)
[ -n "$PITCH_ONLINE_DIR" ]    && ARGS+=(--pitch-online-dir "$PITCH_ONLINE_DIR")  # e.g., /dev/shm/pitch
[ -n "$PITCH_ONLINE_METHOD" ] && ARGS+=(--pitch-online-method "$PITCH_ONLINE_METHOD")
[ "$APPEND_SPACES" = "true" ] && ARGS+=(--prepend-space-to-text --append-space-to-text)

if [ "$SAMPLING_RATE" = "44100" ]; then
    ARGS+=(--sampling-rate 44100)
    ARGS+=(--filter-length 2048)
    ARGS+=(--hop-length 512)
    ARGS+=(--win-length 2048)
    ARGS+=(--mel-fmin 0.0)
    ARGS+=(--mel-fmax 22050.0)
elif [ "$SAMPLING_RATE" != "22050" ]; then
    echo "Unknown sampling rate $SAMPLING_RATE"
    exit 1
fi

mkdir -p "$OUTPUT_DIR"

: ${DISTRIBUTED:="-m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS"}
# NOTE(review): CUDA_DEVICES is not a variable CUDA/PyTorch reads; if the
# intent was to pin GPUs this should be CUDA_VISIBLE_DEVICES (and would need to
# list NUM_GPUS devices). Kept as-is to avoid changing behavior — confirm.
CUDA_DEVICES=0 python $DISTRIBUTED train.py "${ARGS[@]}" "$@"
17 changes: 16 additions & 1 deletion PyTorch/SpeechSynthesis/FastPitch/fastpitch/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,26 @@ def parse_fastpitch_args(parent, add_help=False):
energy_pred.add_argument('--energy-predictor-n-layers', default=2, type=int,
help='Number of conv-1D layers')

spectral_pred = parser.add_argument_group('spectral tilt predictor parameters')
spectral_pred.add_argument('--spectral-tilt-conditioning', action='store_true')
spectral_pred.add_argument('--spectral-tilt-predictor-kernel-size', default=3, type=int,
help='Pitch predictor conv-1D kernel size')
spectral_pred.add_argument('--spectral-tilt-predictor-filter-size', default=256, type=int,
help='Pitch predictor conv-1D filter size')
spectral_pred.add_argument('--p-spectral-tilt-predictor-dropout', default=0.1, type=float,
help='Pitch probability for energy predictor')
spectral_pred.add_argument('--spectral-tilt-predictor-n-layers', default=2, type=int,
help='Number of conv-1D layers')
spectral_pred.add_argument('--no-spectral-predictors', default=6, type=int,
help='6 if only one of surface or source, 12 if both')

cond = parser.add_argument_group('conditioning parameters')
cond.add_argument('--pitch-embedding-kernel-size', default=3, type=int,
help='Pitch embedding conv-1D kernel size')
cond.add_argument('--energy-embedding-kernel-size', default=3, type=int,
help='Pitch embedding conv-1D kernel size')
help='Energy embedding conv-1D kernel size')
cond.add_argument('--spectral-tilt-embedding-kernel-size', default=3, type=int,
help='Spectral tilt embedding conv-1D kernel size')
cond.add_argument('--speaker-emb-weight', type=float, default=1.0,
help='Scale speaker embedding')

Expand Down
77 changes: 67 additions & 10 deletions PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
# *****************************************************************************

import functools
import json
import re
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from scipy import ndimage
from scipy.stats import betabinom

Expand Down Expand Up @@ -153,6 +153,7 @@ def __init__(self,
betabinomial_online_dir=None,
use_betabinomial_interpolator=True,
pitch_online_method='pyin',
include_tilt=None,
**ignored):

# Expect a list of filenames
Expand Down Expand Up @@ -204,6 +205,8 @@ def __init__(self,
self.pitch_mean = to_tensor(pitch_mean)
self.pitch_std = to_tensor(pitch_std)

self.spectral_tilt_features = include_tilt

def __getitem__(self, index):
# Separate filename and text
if self.n_speakers > 1:
Expand All @@ -217,6 +220,10 @@ def __getitem__(self, index):
text = self.get_text(text)
pitch = self.get_pitch(index, mel.size(-1))
energy = torch.norm(mel.float(), dim=0, p=2)
if self.spectral_tilt_features:
spectral_tilt = self.get_spectral_tilt(mel, audiopath, self.spectral_tilt_features)
else:
spectral_tilt = None
attn_prior = self.get_prior(index, mel.shape[1], text.shape[0])

assert pitch.size(-1) == mel.size(-1)
Expand All @@ -225,7 +232,7 @@ def __getitem__(self, index):
if len(pitch.size()) == 1:
pitch = pitch[None, :]

return (text, mel, len(text), pitch, energy, speaker, attn_prior,
return (text, mel, len(text), pitch, energy, spectral_tilt, speaker, attn_prior,
audiopath)

def __len__(self):
Expand Down Expand Up @@ -324,6 +331,45 @@ def get_pitch(self, index, mel_len=None):

return pitch_mel

def get_spectral_tilt(self, mels, audio_path, tilt_features='both'):
    """Compute per-frame spectral-tilt features from a mel spectrogram.

    Each frame (column) of ``mels`` is approximated by a degree-5
    polynomial over the mel-bin axis; the 6 fitted coefficients per frame
    are the tilt features.

    Args:
        mels: mel spectrogram tensor of shape (n_mels, n_frames).
        audio_path: path of the source wav; used to locate the matching
            IAIF glottal-source wav when source tilt is requested.
        tilt_features: 'surface', 'source', or 'both'.

    Returns:
        FloatTensor of shape (6, n_frames) for 'surface' or 'source',
        or (12, n_frames) for 'both' (surface stacked above source).
    """
    n_mels = mels.size(0)

    # Surface tilt: polyfit with a 2-D y fits every frame at once and
    # returns one coefficient column per frame -> shape (6, n_frames).
    poly_coefficients = np.polynomial.polynomial.polyfit(
        np.arange(1, n_mels + 1), mels, 5)

    if tilt_features in ('both', 'source'):
        # Source tilt is computed on the IAIF glottal-source signal,
        # assumed stored alongside the originals under iaif_gci_wavs/.
        # TODO: remove hardcoded path
        iaif_path = re.sub('/wavs/', '/iaif_gci_wavs/', audio_path)
        iaif_mels = self.get_mel(iaif_path)
        iaif_poly_coefficients = np.polynomial.polynomial.polyfit(
            np.arange(1, n_mels + 1), iaif_mels, 5)
        if tilt_features == 'source':
            return torch.FloatTensor(iaif_poly_coefficients)
        # 'both': stack surface over source -> (12, n_frames)
        return torch.cat([torch.FloatTensor(poly_coefficients),
                          torch.FloatTensor(iaif_poly_coefficients)])
    return torch.FloatTensor(poly_coefficients)


class TTSCollate:
"""Zero-pads model inputs and targets based on number of frames per step"""
Expand Down Expand Up @@ -355,45 +401,54 @@ def __call__(self, batch):
mel_padded[i, :, :mel.size(1)] = mel
output_lengths[i] = mel.size(1)

n_formants = batch[0][3].shape[0]
n_formants = batch[0][3].shape[0] # default 1
pitch_padded = torch.zeros(mel_padded.size(0), n_formants,
mel_padded.size(2), dtype=batch[0][3].dtype)
energy_padded = torch.zeros_like(pitch_padded[:, 0, :])

for i in range(len(ids_sorted_decreasing)):
pitch = batch[ids_sorted_decreasing[i]][3]
energy = batch[ids_sorted_decreasing[i]][4]
spectral_tilt = batch[ids_sorted_decreasing[i]][5]
pitch_padded[i, :, :pitch.shape[1]] = pitch
energy_padded[i, :energy.shape[0]] = energy

if batch[0][5] is not None:
if batch[0][5] is not None: # if None, there is no spectral tilt
num_coefficients = batch[0][5].size(0)
spectral_tilt_padded = torch.FloatTensor(len(batch), num_coefficients, max_target_len).zero_()
for i in range(len(ids_sorted_decreasing)):
spectral_tilt_padded[i, :, :spectral_tilt.shape[1]] = spectral_tilt
else:
spectral_tilt_padded = None

if batch[0][6] is not None:
speaker = torch.zeros_like(input_lengths)
for i in range(len(ids_sorted_decreasing)):
speaker[i] = batch[ids_sorted_decreasing[i]][5]
speaker[i] = batch[ids_sorted_decreasing[i]][6]
else:
speaker = None

attn_prior_padded = torch.zeros(len(batch), max_target_len,
max_input_len)
attn_prior_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
prior = batch[ids_sorted_decreasing[i]][6]
prior = batch[ids_sorted_decreasing[i]][7]
attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior

# Count number of items - characters in text
len_x = [x[2] for x in batch]
len_x = torch.Tensor(len_x)

audiopaths = [batch[i][7] for i in ids_sorted_decreasing]
audiopaths = [batch[i][8] for i in ids_sorted_decreasing]

return (text_padded, input_lengths, mel_padded, output_lengths, len_x,
pitch_padded, energy_padded, speaker, attn_prior_padded,
pitch_padded, energy_padded, spectral_tilt_padded, speaker, attn_prior_padded,
audiopaths)


def batch_to_gpu(batch):
(text_padded, input_lengths, mel_padded, output_lengths, len_x,
pitch_padded, energy_padded, speaker, attn_prior, audiopaths) = batch
pitch_padded, energy_padded, spectral_tilt_padded, speaker, attn_prior, audiopaths) = batch

text_padded = to_gpu(text_padded).long()
input_lengths = to_gpu(input_lengths).long()
Expand All @@ -402,12 +457,14 @@ def batch_to_gpu(batch):
pitch_padded = to_gpu(pitch_padded).float()
energy_padded = to_gpu(energy_padded).float()
attn_prior = to_gpu(attn_prior).float()
if spectral_tilt_padded is not None:
spectral_tilt_padded = to_gpu(spectral_tilt_padded).float()
if speaker is not None:
speaker = to_gpu(speaker).long()

# Alignments act as both inputs and targets - pass shallow copies
x = [text_padded, input_lengths, mel_padded, output_lengths,
pitch_padded, energy_padded, speaker, attn_prior, audiopaths]
pitch_padded, energy_padded, spectral_tilt_padded, speaker, attn_prior, audiopaths]
y = [mel_padded, input_lengths, output_lengths]
len_x = torch.sum(output_lengths)
return (x, y, len_x)
29 changes: 22 additions & 7 deletions PyTorch/SpeechSynthesis/FastPitch/fastpitch/loss_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,24 @@
class FastPitchLoss(nn.Module):
def __init__(self, dur_predictor_loss_scale=1.0,
             pitch_predictor_loss_scale=1.0, attn_loss_scale=1.0,
             energy_predictor_loss_scale=0.1,
             spectral_tilt_predictor_loss_scale=0.1):
    """Combined FastPitch training loss.

    Args:
        dur_predictor_loss_scale: weight of the duration-predictor term.
        pitch_predictor_loss_scale: weight of the pitch-predictor term.
        attn_loss_scale: weight of the attention CTC alignment term.
        energy_predictor_loss_scale: weight of the energy-predictor term.
        spectral_tilt_predictor_loss_scale: weight of the spectral-tilt
            predictor term.
    """
    # Fix: the pasted diff kept both the pre- and post-change signature
    # lines (`...=0.1):` followed by `...=0.1,`), which is not valid
    # Python; only the updated signature is retained here.
    super(FastPitchLoss, self).__init__()
    self.dur_predictor_loss_scale = dur_predictor_loss_scale
    self.pitch_predictor_loss_scale = pitch_predictor_loss_scale
    self.energy_predictor_loss_scale = energy_predictor_loss_scale
    self.spectral_tilt_predictor_loss_scale = spectral_tilt_predictor_loss_scale
    self.attn_loss_scale = attn_loss_scale
    self.attn_ctc_loss = AttentionCTCLoss()

def forward(self, model_out, targets, is_training=True, meta_agg='mean'):
# (mel_out, dec_mask, dur_pred, log_dur_pred,
# pitch_pred, pitch_tgt, energy_pred, energy_tgt,
# spectral_tilt_pred, spectral_tilt_tgt,
# attn_soft, attn_hard, attn_hard_dur, attn_logprob)
(mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt,
energy_pred, energy_tgt, attn_soft, attn_hard, attn_dur,
attn_logprob) = model_out
energy_pred, energy_tgt, spectral_tilt_pred, spectral_tilt_tgt,
attn_soft, attn_hard, attn_dur, attn_logprob) = model_out

(mel_tgt, in_lens, out_lens) = targets

Expand Down Expand Up @@ -81,7 +87,16 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'):
energy_loss = F.mse_loss(energy_tgt, energy_pred, reduction='none')
energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum()
else:
energy_loss = 0
energy_loss = torch.tensor(0, dtype=torch.int8).to(device='cuda')

if spectral_tilt_pred is not None:
spectral_tilt_pred = spectral_tilt_pred.permute(0, 2, 1)
spectral_tilt_pred = F.pad(spectral_tilt_pred, (0, ldiff, 0, 0), value=0.0)
spectral_tilt_loss = F.mse_loss(spectral_tilt_tgt, spectral_tilt_pred, reduction='none')
spectral_dur_mask = torch.repeat_interleave(dur_mask[:, None, :], spectral_tilt_loss.size(1), dim=1)
spectral_tilt_loss = (spectral_tilt_loss * spectral_dur_mask).sum() / spectral_dur_mask.sum()
else:
spectral_tilt_loss = torch.tensor(0, dtype=torch.int8).to(device='cuda')

# Attention loss
attn_loss = self.attn_ctc_loss(attn_logprob, in_lens, out_lens)
Expand All @@ -90,21 +105,21 @@ def forward(self, model_out, targets, is_training=True, meta_agg='mean'):
+ dur_pred_loss * self.dur_predictor_loss_scale
+ pitch_loss * self.pitch_predictor_loss_scale
+ energy_loss * self.energy_predictor_loss_scale
+ spectral_tilt_loss * self.spectral_tilt_predictor_loss_scale
+ attn_loss * self.attn_loss_scale)

meta = {
'loss': loss.clone().detach(),
'mel_loss': mel_loss.clone().detach(),
'duration_predictor_loss': dur_pred_loss.clone().detach(),
'pitch_loss': pitch_loss.clone().detach(),
'energy_loss': energy_loss.clone().detach(),
'spectral_tilt_loss': spectral_tilt_loss.clone().detach(),
'attn_loss': attn_loss.clone().detach(),
'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
/ dur_mask.sum()).detach(),
}

if energy_pred is not None:
meta['energy_loss'] = energy_loss.clone().detach()

assert meta_agg in ('sum', 'mean')
if meta_agg == 'sum':
bsz = mel_out.size(0)
Expand Down
Loading