-
Notifications
You must be signed in to change notification settings - Fork 18
conditioning_not_working #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
07cae0f
89fbe48
40a93c2
878115a
5e54a81
67097fc
2d0a53f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -136,6 +136,7 @@ def __init__(self, | |
| symbol_set='english_basic', | ||
| p_arpabet=1.0, | ||
| n_speakers=1, | ||
| n_conditions=1, | ||
| load_mel_from_disk=True, | ||
| load_pitch_from_disk=True, | ||
| pitch_mean=214.72203, # LJSpeech defaults | ||
|
|
@@ -160,9 +161,10 @@ def __init__(self, | |
| audiopaths_and_text = [audiopaths_and_text] | ||
|
|
||
| self.dataset_path = dataset_path | ||
| # this now returns a list of dicts | ||
| self.audiopaths_and_text = load_filepaths_and_text( | ||
| audiopaths_and_text, dataset_path, | ||
| has_speakers=(n_speakers > 1)) | ||
| has_speakers=(n_speakers > 1), has_conditions=(n_conditions > 1)) | ||
| self.load_mel_from_disk = load_mel_from_disk | ||
| if not load_mel_from_disk: | ||
| self.max_wav_value = max_wav_value | ||
|
|
@@ -181,6 +183,7 @@ def __init__(self, | |
|
|
||
| self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet) | ||
| self.n_speakers = n_speakers | ||
| self.n_conditions = n_conditions | ||
| self.pitch_tmp_dir = pitch_online_dir | ||
| self.f0_method = pitch_online_method | ||
| self.betabinomial_tmp_dir = betabinomial_online_dir | ||
|
|
@@ -189,13 +192,13 @@ def __init__(self, | |
| if use_betabinomial_interpolator: | ||
| self.betabinomial_interpolator = BetaBinomialInterpolator() | ||
|
|
||
| expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1)) | ||
|
|
||
| expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1) + (n_conditions > 1)) | ||
| print('EXPECTED COLUMNS IS ' + str(expected_columns)) | ||
| assert not (load_pitch_from_disk and self.pitch_tmp_dir is not None) | ||
|
|
||
| if len(self.audiopaths_and_text[0]) < expected_columns: | ||
| raise ValueError(f'Expected {expected_columns} columns in audiopaths file. ' | ||
| 'The format is <mel_or_wav>|[<pitch>|]<text>[|<speaker_id>]') | ||
| 'The format is <mel_or_wav>|[<pitch>|]<text>[|<speaker_id>|<condition_id>]') | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. oh, I guess we're checking it here? |
||
|
|
||
| if len(self.audiopaths_and_text[0]) > expected_columns: | ||
| print('WARNING: Audiopaths file has more columns than expected') | ||
|
|
@@ -205,16 +208,26 @@ def __init__(self, | |
| self.pitch_std = to_tensor(pitch_std) | ||
|
|
||
| def __getitem__(self, index): | ||
| # Separate filename and text | ||
| # Indexing items using dictionary entries | ||
| audiopath = self.audiopaths_and_text[index]['mels'] | ||
| text = self.audiopaths_and_text[index]['text'] | ||
| speaker = None | ||
| condition = None | ||
| if self.n_speakers > 1: | ||
| audiopath, *extra, text, speaker = self.audiopaths_and_text[index] | ||
| speaker = int(speaker) | ||
| else: | ||
| audiopath, *extra, text = self.audiopaths_and_text[index] | ||
| speaker = None | ||
| speaker = int(self.audiopaths_and_text[index]['speaker']) | ||
| if self.n_conditions > 1: | ||
| cond = self.audiopaths_and_text[index]['condition'] | ||
| if cond is None or cond == 'None': | ||
| print(audiopath, text, self.audiopaths_and_text[index]) | ||
| condition = int(self.audiopaths_and_text[index]['condition']) | ||
|
|
||
| mel = self.get_mel(audiopath) | ||
| if mel.size(1) > 700: | ||
| print('MEL LEN: ', mel.size(), audiopath) | ||
| text = self.get_text(text) | ||
| length = len(text) | ||
| if length >= 130: | ||
| print('LENGTH: ', len(text), audiopath) | ||
| pitch = self.get_pitch(index, mel.size(-1)) | ||
| energy = torch.norm(mel.float(), dim=0, p=2) | ||
| attn_prior = self.get_prior(index, mel.shape[1], text.shape[0]) | ||
|
|
@@ -226,7 +239,7 @@ def __getitem__(self, index): | |
| pitch = pitch[None, :] | ||
|
|
||
| return (text, mel, len(text), pitch, energy, speaker, attn_prior, | ||
| audiopath) | ||
| audiopath, condition) | ||
|
|
||
| def __len__(self): | ||
| return len(self.audiopaths_and_text) | ||
|
|
@@ -287,15 +300,15 @@ def get_prior(self, index, mel_len, text_len): | |
| return attn_prior | ||
|
|
||
| def get_pitch(self, index, mel_len=None): | ||
| audiopath, *fields = self.audiopaths_and_text[index] | ||
| audiopath = self.audiopaths_and_text[index]['mels'] | ||
|
|
||
| # why do we need the speaker here? | ||
| spk = 0 | ||
| if self.n_speakers > 1: | ||
| spk = int(fields[-1]) | ||
| else: | ||
| spk = 0 | ||
| spk = int(self.audiopaths_and_text[index]['speaker']) | ||
|
|
||
| if self.load_pitch_from_disk: | ||
| pitchpath = fields[0] | ||
| pitchpath = self.audiopaths_and_text[index]['pitch'] | ||
| pitch = torch.load(pitchpath) | ||
| if self.pitch_mean is not None: | ||
| assert self.pitch_std is not None | ||
|
|
@@ -386,14 +399,21 @@ def __call__(self, batch): | |
|
|
||
| audiopaths = [batch[i][7] for i in ids_sorted_decreasing] | ||
|
|
||
| if batch[0][8] is not None: | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I imagine this is the bit that would need updating once the other code is merged? |
||
| condition = torch.zeros_like(input_lengths) | ||
| for i in range(len(ids_sorted_decreasing)): | ||
| condition[i] = batch[ids_sorted_decreasing[i]][8] | ||
| else: | ||
| condition = None | ||
|
|
||
| return (text_padded, input_lengths, mel_padded, output_lengths, len_x, | ||
| pitch_padded, energy_padded, speaker, attn_prior_padded, | ||
| audiopaths) | ||
| audiopaths, condition) | ||
|
|
||
|
|
||
| def batch_to_gpu(batch): | ||
| (text_padded, input_lengths, mel_padded, output_lengths, len_x, | ||
| pitch_padded, energy_padded, speaker, attn_prior, audiopaths) = batch | ||
| pitch_padded, energy_padded, speaker, attn_prior, audiopaths, condition) = batch | ||
|
|
||
| text_padded = to_gpu(text_padded).long() | ||
| input_lengths = to_gpu(input_lengths).long() | ||
|
|
@@ -404,10 +424,13 @@ def batch_to_gpu(batch): | |
| attn_prior = to_gpu(attn_prior).float() | ||
| if speaker is not None: | ||
| speaker = to_gpu(speaker).long() | ||
| if condition is not None: | ||
| condition = to_gpu(condition).long() | ||
|
|
||
| # Alignments act as both inputs and targets - pass shallow copies | ||
| x = [text_padded, input_lengths, mel_padded, output_lengths, | ||
| pitch_padded, energy_padded, speaker, attn_prior, audiopaths] | ||
| pitch_padded, energy_padded, speaker, attn_prior, audiopaths, condition] | ||
| y = [mel_padded, input_lengths, output_lengths] | ||
| len_x = torch.sum(output_lengths) | ||
| return (x, y, len_x) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -126,7 +126,7 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, | |
| energy_predictor_kernel_size, energy_predictor_filter_size, | ||
| p_energy_predictor_dropout, energy_predictor_n_layers, | ||
| energy_embedding_kernel_size, | ||
| n_speakers, speaker_emb_weight, pitch_conditioning_formants=1): | ||
| n_speakers, speaker_emb_weight, n_conditions, condition_emb_weight, pitch_conditioning_formants=1): | ||
| super(FastPitch, self).__init__() | ||
|
|
||
| self.encoder = FFTransformer( | ||
|
|
@@ -149,6 +149,14 @@ def __init__(self, n_mel_channels, n_symbols, padding_idx, | |
| self.speaker_emb = None | ||
| self.speaker_emb_weight = speaker_emb_weight | ||
|
|
||
| #Have to figure out what symbols_embedding_dim is | ||
| if n_conditions > 1: | ||
| self.condition_emb = nn.Embedding(n_conditions, symbols_embedding_dim) | ||
| else: | ||
| self.condition_emb = None | ||
| self.condition_emb_weight = condition_emb_weight | ||
|
|
||
|
|
||
| self.duration_predictor = TemporalPredictor( | ||
| in_fft_output_size, | ||
| filter_size=dur_predictor_filter_size, | ||
|
|
@@ -242,7 +250,7 @@ def binarize_attention_parallel(self, attn, in_lens, out_lens): | |
| def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): | ||
|
|
||
| (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, energy_dense, | ||
| speaker, attn_prior, audiopaths) = inputs | ||
| speaker, attn_prior, audiopaths, condition) = inputs | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As a side-note, I am wondering if we should make inputs/outputs some enum or datatype that also doesn't rely on indices to get different things out. |
||
|
|
||
| mel_max_len = mel_tgt.size(2) | ||
|
|
||
|
|
@@ -253,8 +261,15 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): | |
| spk_emb = self.speaker_emb(speaker).unsqueeze(1) | ||
| spk_emb.mul_(self.speaker_emb_weight) | ||
|
|
||
| # Calculate discrete condition embedding | ||
| if self.condition_emb is None: | ||
| cond_emb = 0 | ||
| else: | ||
| cond_emb = self.condition_emb(condition).unsqueeze(1) | ||
| cond_emb.mul_(self.condition_emb_weight) | ||
|
|
||
| # Input FFT | ||
| enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb) | ||
| enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb, conditioning_2=cond_emb) #need to add condition conditioning here | ||
|
|
||
| # Alignment | ||
| text_emb = self.encoder.word_emb(inputs) | ||
|
|
@@ -281,7 +296,7 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): | |
| dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration) | ||
|
|
||
| # Predict pitch | ||
| pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1) | ||
| pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1) #maybe we want to condition pitch prediction on the conditioning parameter. | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Cool idea — we should make a ticket for this. |
||
|
|
||
| # Average pitch over characters | ||
| pitch_tgt = average_pitch(pitch_dense, dur_tgt) | ||
|
|
@@ -290,7 +305,7 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): | |
| pitch_emb = self.pitch_emb(pitch_tgt) | ||
| else: | ||
| pitch_emb = self.pitch_emb(pitch_pred) | ||
| enc_out = enc_out + pitch_emb.transpose(1, 2) | ||
| enc_out = enc_out + pitch_emb.transpose(1, 2) #Adding with encoder output | ||
|
|
||
| # Predict energy | ||
| if self.energy_conditioning: | ||
|
|
@@ -302,13 +317,13 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): | |
|
|
||
| energy_emb = self.energy_emb(energy_tgt) | ||
| energy_tgt = energy_tgt.squeeze(1) | ||
| enc_out = enc_out + energy_emb.transpose(1, 2) | ||
| enc_out = enc_out + energy_emb.transpose(1, 2) #adding to encoder output | ||
| else: | ||
| energy_pred = None | ||
| energy_tgt = None | ||
|
|
||
| len_regulated, dec_lens = regulate_len( | ||
| dur_tgt, enc_out, pace, mel_max_len) | ||
| dur_tgt, enc_out, pace, mel_max_len) #upsampling | ||
|
|
||
| # Output FFT | ||
| dec_out, dec_mask = self.decoder(len_regulated, dec_lens) | ||
|
|
@@ -319,7 +334,7 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75): | |
|
|
||
| def infer(self, inputs, pace=1.0, dur_tgt=None, pitch_tgt=None, | ||
| energy_tgt=None, pitch_transform=None, max_duration=75, | ||
| speaker=0): | ||
| speaker=0, condition=0): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So, because this is the condition index, the default is 0 — even though "no conditions" corresponds to n_conditions = 1, and so far there can only be 1 condition? |
||
|
|
||
| if self.speaker_emb is None: | ||
| spk_emb = 0 | ||
|
|
@@ -329,8 +344,16 @@ def infer(self, inputs, pace=1.0, dur_tgt=None, pitch_tgt=None, | |
| spk_emb = self.speaker_emb(speaker).unsqueeze(1) | ||
| spk_emb.mul_(self.speaker_emb_weight) | ||
|
|
||
| if self.condition_emb is None: | ||
| cond_emb = 0 | ||
| else: | ||
| condition = (torch.ones(inputs.size(0)).long().to(inputs.device) | ||
| * condition) | ||
| cond_emb = self.condition_emb(condition).unsqueeze(1) | ||
| cond_emb.mul_(self.condition_emb_weight) | ||
|
|
||
| # Input FFT | ||
| enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb) | ||
| enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb, conditioning_2=cond_emb) #need to add conditioning here but will it take list? | ||
|
|
||
| # Predict durations | ||
| log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -192,7 +192,7 @@ def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size, | |
| dropatt=dropatt, pre_lnorm=pre_lnorm) | ||
| ) | ||
|
|
||
| def forward(self, dec_inp, seq_lens=None, conditioning=0): | ||
| def forward(self, dec_inp, seq_lens=None, conditioning=0, conditioning_2=0): #here when called we add speaker or other discrete condition | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You could make the conditioning a tuple, or rename the conditionings to conditioning_speaker and conditioning_other. |
||
| if self.word_emb is None: | ||
| inp = dec_inp | ||
| mask = mask_from_lens(seq_lens).unsqueeze(2) | ||
|
|
@@ -204,7 +204,7 @@ def forward(self, dec_inp, seq_lens=None, conditioning=0): | |
| pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype) | ||
| pos_emb = self.pos_emb(pos_seq) * mask | ||
|
|
||
| out = self.drop(inp + pos_emb + conditioning) | ||
| out = self.drop(inp + pos_emb + conditioning + conditioning_2) # so here we add more conditioning | ||
|
|
||
| for layer in self.layers: | ||
| out = layer(out, mask=mask) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The default is 1, but to have conditions there should be more than 1? If 1 is a way of saying there are no conditions, why not 0?