diff --git a/inference_tts_scale.py b/inference_tts_scale.py index 2ebb78c..b79ee51 100644 --- a/inference_tts_scale.py +++ b/inference_tts_scale.py @@ -4,6 +4,7 @@ import numpy as np import torch import torchaudio +import psutil from data.tokenizer import ( AudioTokenizer, @@ -40,7 +41,7 @@ def get_args(): @torch.no_grad() -def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame): +def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame, half=False): # phonemize text_tokens = [phn2num[phn] for phn in tokenize_text( @@ -49,6 +50,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token ] text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) + print("finished phonemize") # encode audio encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame) @@ -56,12 +58,19 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.") + process = psutil.Process() + print(f"finished encode; memory usage: {process.memory_info().rss}") + + text_tokens = text_tokens.to(device) + if half: + text_tokens = text_tokens.half() + # forward stime = time.time() if decode_config['sample_batch_size'] <= 1: logging.info(f"running inference with batch size 1") concat_frames, gen_frames = model.inference_tts( - text_tokens.to(device), + text_tokens, text_tokens_lens.to(device), original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] top_k=decode_config['top_k'], @@ -74,7 +83,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token else: logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.") concat_frames, gen_frames = model.inference_tts_batch( - text_tokens.to(device), + text_tokens, text_tokens_lens.to(device), original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] top_k=decode_config['top_k'], @@ -85,6 +94,9 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token batch_size = decode_config['sample_batch_size'], silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens'] ) # output is [1,K,T] + + print("finished forward pass") + logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.") logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.") diff --git a/memtest.py b/memtest.py new file mode 100644 index 0000000..3eb2f04 --- /dev/null +++ b/memtest.py @@ -0,0 +1,171 @@ +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["USER"] = "neow" # TODO change this to your username + +import torch +import torchaudio +import numpy as np +import random + +from data.tokenizer import ( + AudioTokenizer, + TextTokenizer, +) + +import subprocess as sp +import os + +def get_gpu_memory(): + command = "nvidia-smi --query-gpu=memory.free --format=csv" + memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:] + memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] + print(memory_free_values) + +get_gpu_memory() + +if __name__ == "__main__": + # load model, encodec, and phn2num + # # load model, tokenizer, and other necessary files + device = "cuda" if torch.cuda.is_available() else "cpu" + from models import voicecraft + + # import models.voicecraft as voicecraft + voicecraft_name = "giga330M.pth" # or giga330M.pth + ckpt_fn = f"./pretrained_models/{voicecraft_name}" + encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" + if not os.path.exists(ckpt_fn): + os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true") + os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}") + if not os.path.exists(encodec_fn): + os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th") + os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th") + + ckpt = torch.load(ckpt_fn, map_location="cpu") + model = voicecraft.VoiceCraft(ckpt["config"]) + model.load_state_dict(ckpt["model"]) + model.to(device) + # model.half() + model.eval() + + print("loaded model") + get_gpu_memory() + + phn2num = ckpt['phn2num'] + + text_tokenizer = TextTokenizer(backend="espeak") + audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu + + # %% + + # Prepare your audio + # point to the original audio whose speech you want to clone + # write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file + orig_audio = "./demo/84_121550_000074_000000.wav" + orig_transcript = "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks," + + # move the audio and transcript to temp folder + temp_folder = "./demo/temp" + os.makedirs(temp_folder, exist_ok=True) + os.system(f"cp {orig_audio} {temp_folder}") + filename = os.path.splitext(orig_audio.split("/")[-1])[0] + with open(f"{temp_folder}/{filename}.txt", "w") as f: + f.write(orig_transcript) + # run MFA to get the alignment + align_temp = f"{temp_folder}/mfa_alignments" + + # # if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue + # !source ~/.bashrc && \ + # conda activate voicecraft && \ + # mfa align -v --clean -j 1 --output_format csv {temp_folder} \ + # english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000 + + + # take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt + cut_off_sec = 7.0 # NOTE: according to forced-alignment file demo/temp/mfa_alignments/84_121550_000074_000000.csv, the word "common" stop as 3.01 sec, this should be different for different audio + target_transcript = "But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well! I love shuffle 512 and janise" + # NOTE: 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec. + audio_fn = f"{temp_folder}/{filename}.wav" + info = torchaudio.info(audio_fn) + audio_dur = info.num_frames / info.sample_rate + + assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" + prompt_end_frame = int(cut_off_sec * info.sample_rate) + + # run the model to get the output + # hyperparameters for inference + codec_audio_sr = 16000 + codec_sr = 50 + top_k = 0 + top_p = 0.8 + temperature = 1 + silence_tokens = [1388, 1898, 131] + kvcache = 0 # NOTE if OOM, change this to 0, or try the 330M model + + # NOTE adjust the below three arguments if the generation is not as good + stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1 + sample_batch_size = 1 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number. + seed = 1 # change seed if you are still unhappy with the result + + + def seed_everything(seed): + os.environ['PYTHONHASHSEED'] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + + seed_everything(seed) + + decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, + 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, + "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size} + from inference_tts_scale import inference_one_sample + + print("before inference") + get_gpu_memory() + + concated_audio, gen_audio = inference_one_sample(model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer, + audio_fn, target_transcript, device, decode_config, prompt_end_frame, False) + print("after inference") + get_gpu_memory() + + # save segments for comparison + concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() + # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}") + + + # display the audio + # from IPython.display import Audio + # + # print("concatenate prompt and generated:") + # display(Audio(concated_audio, rate=codec_audio_sr)) + # + # print("generated:") + # display(Audio(gen_audio, rate=codec_audio_sr)) + + # # save the audio + # # output_dir + output_dir = "/home/pyp/VoiceCraft/demo/generated_tts" + os.makedirs(output_dir, exist_ok=True) + seg_save_fn_gen = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav" + seg_save_fn_concat = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav" + + torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr) + torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr) + + print("finished running") + + # if you get error importing T5 in transformers + # try + # pip uninstall Pillow + # pip install Pillow + # you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored + + # %% + + diff --git a/models/voicecraft.py b/models/voicecraft.py index 8d83729..ee94c80 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -109,15 +109,15 @@ def __init__(self, args): self.text_embedding = TokenEmbedding( dim_model=self.args.d_model, - vocab_size=self.n_text_tokens, + vocab_size=self.n_text_tokens, dropout=self.args.text_embedding_dropout ) self.audio_embedding = nn.ModuleList( [ TokenEmbedding( - dim_model=self.args.audio_embedding_dim, - vocab_size=self.n_audio_tokens[k], + dim_model=self.args.audio_embedding_dim, + vocab_size=self.n_audio_tokens[k], dropout=self.args.audio_embedding_dropout ) for k in range(self.args.n_codebooks) ] @@ -150,13 +150,13 @@ def __init__(self, args): num_layers=self.args.num_decoder_layers, norm=LayerNorm(self.args.d_model), ) - + self.predict_layer = nn.ModuleList( [ nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks) ] ) - + self.accuracy_metrics = nn.ModuleList( [MulticlassAccuracy( self.n_audio_tokens[k], @@ -167,7 +167,7 @@ def __init__(self, args): ) for k in range(self.args.n_codebooks)] ) - + def prepare_mask_intervals(self, y_lens): mask_intervals = [] non_mask_intervals = [] @@ -203,12 +203,12 @@ def prepare_mask_intervals(self, y_lens): temp_mask_end = gap - 1 mask_len = random.randint(temp_mask_start, temp_mask_end) ends.append(start + mask_len) - + mask_intervals.append([(s,e) for s,e in zip(starts, ends)]) non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])]) return mask_intervals, non_mask_intervals - + def rearrange(self, y, non_mask_intervals, mask_intervals): reduced_eog = getattr(self.args, "reduced_eog", 0) rearranged_y = [] @@ -223,7 +223,7 @@ def rearrange(self, y, non_mask_intervals, mask_intervals): cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment) rearranged_y.append(cur_y) return rearranged_y - + def shift(self, rearranged_y): shifted_y = [] patterns = [] @@ -233,7 +233,7 @@ def shift(self, rearranged_y): shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask patterns.append(cur_patterns) return shifted_y, patterns - + def insert_mask(self, shifted_y): inserted_y = [] mask_position = [] @@ -259,7 +259,7 @@ def insert_mask(self, shifted_y): inserted_y.append(cur_inserted_y) mask_position.append(cur_mask_position) return inserted_y, mask_position, mask_value - + def cat_y(self, inserted_y, mask_position, y_lens): reduced_eog = getattr(self.args, "reduced_eog", 0) cated_y = [] @@ -289,9 +289,9 @@ def embed_y(self, cated_y, mask_position, mask_value): embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D] for i in range(len(embedded_y)): if len(mask_position[i]) > 0: - embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]] + embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]] return embedded_y - + def prepare_input_target(self, y, y_lens): # rearrange y # assume y shape: [B T K], K is n_codebooks @@ -328,16 +328,16 @@ def prepare_input_target(self, y, y_lens): inserted_y, mask_position, mask_value = self.insert_mask(shifted_y) assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0] assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}" - + # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y))) - + # embed remember to separately embed the mask tokens embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape - + # positional embedding y_input = self.audio_positional_embedding(embedded_y) @@ -354,9 +354,9 @@ def remove_mask(self, logits, mask_position, new_y_lens): non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)] cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals] logits_use.append(cur_logits_use) - + return logits_use - + def revert_pattern(self, patterns, logits_use): logits_final = [] logit_masks = [] @@ -376,9 +376,10 @@ def revert_pattern(self, patterns, logits_use): return logits_final, logit_masks + @torch.autocast(device_type="cuda", dtype=torch.float16) def dec_forward( - self, - x_input, + self, + x_input, x_lens, x_attention_mask, x_padding_mask, @@ -418,7 +419,8 @@ def dec_forward( xy_input = torch.cat([x_input, y_input], dim=1) if past == None: # do not use kvcache - out, _ = self.decoder((xy_input, None), mask=xy_attn_mask) + out, _ = self.decoder((xy_input, None), mask=xy_attn_mask) + # out = out.half() # TODO: make this an option => only on if dtype = float16 return out[:, x_lens.max():], None else: # use kvcache if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet @@ -438,6 +440,7 @@ def dec_forward( else: # used kvcache return out, present + @torch.autocast(device_type="cuda", dtype=torch.float16) def forward(self, batch): """ Args: @@ -467,7 +470,7 @@ def forward(self, batch): x_input = self.text_positional_embedding(x_input) y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens) y_out = self.dec_forward( - x_input, + x_input, x_lens, x_attention_mask, x_padding_mask, @@ -478,13 +481,13 @@ def forward(self, batch): ) y_out = y_out[0] # no kv-caching during training assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D] - + logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card] # take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern) assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape logits_use = self.remove_mask(logits, mask_position, new_y_lens) - + # revert the pattern shift for each logits section in each sample logits_final, logit_masks = self.revert_pattern(patterns, logits_use) assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]" @@ -507,7 +510,7 @@ def forward(self, batch): loss.append(F.cross_entropy(logit, target, reduction='mean')) top10acc.append(self.accuracy_metrics[k](logit.detach(), target)) ntokens.append(len(logit)) - + all_ntokens = sum(ntokens) if self.args.codebook_weight != None: codebook_weight = eval(self.args.codebook_weight) @@ -524,7 +527,7 @@ def forward(self, batch): "top10acc_by_codebook": top10acc_by_codebook, "effective_ntoken": ntokens, } - + def inference( self, x: torch.Tensor, @@ -593,7 +596,7 @@ def inference( non_mask_intervals = [[ (ns, ne) for ns, ne in zip(ends, starts) ]] - + # rearrange y # will add have EOG in each section (SOG will be generated by the pattern class) # but mask can be inserted later after we have shifted the input @@ -625,7 +628,7 @@ def inference( inserted_y, mask_position, mask_value = self.insert_mask(shifted_y) assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0] assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}" - + # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y))) @@ -678,10 +681,10 @@ def inference( ##################### silence repetition handling ##################### # prepare the cache placeholder # n_layers, 2, bsz, num_heads, src_len, head_dim - past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None + past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None # handle multi-span kv-cache new_masked_span = False - + def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen): if n_eog == 0: logits_adjust = logits @@ -755,7 +758,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t while True: y_out, present = self.dec_forward( - x_input, + x_input, x_lens, x_attention_mask, x_padding_mask, @@ -835,7 +838,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) - + assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}" # # combine non_masked_span with generated spans @@ -866,7 +869,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen]) assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}" - + if self.args.special_first: res = res - int(self.args.n_special) @@ -947,7 +950,7 @@ def inference_tts( assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D] embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D] - + # positional embedding y_input = self.audio_positional_embedding(embedded_y) @@ -978,7 +981,7 @@ def inference_tts( # prepare the cache placeholder # n_layers, 2, bsz, num_heads, src_len, head_dim - past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None + past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") @@ -1034,7 +1037,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t return samples, codebook_eog, prev_token, consec_silence_count while True: y_out, present = self.dec_forward( - x_input, + x_input, x_lens, x_attention_mask, x_padding_mask, @@ -1058,9 +1061,9 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans for jj in range(self.args.n_codebooks): logits[jj][self.args.eog] = -10000. - + samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen) - + cur_num_gen += 1 cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K] @@ -1078,14 +1081,14 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t break else: assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" - + embedded_y = torch.cat([embedded_y, samples_emb], dim=1) y_input = self.audio_positional_embedding(embedded_y) # [B T D] # make attention mask and padding mask y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) - + assert len(generated) == 1, f"len(generated): {len(generated)}" # revert the pattern @@ -1105,7 +1108,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t flatten_gen.append(unshifted_span) assert len(flatten_gen) == 1, len(flatten_gen) - + # combine res = [y[0], flatten_gen[0]] res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T] @@ -1197,7 +1200,7 @@ def inference_tts_batch( assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D] embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D] - + # positional embedding y_input = self.audio_positional_embedding(embedded_y) @@ -1228,7 +1231,7 @@ def inference_tts_batch( # prepare the cache placeholder # n_layers, 2, bsz, num_heads, src_len, head_dim - past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None + past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") @@ -1311,7 +1314,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t else: assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}" y_out, present = self.dec_forward( - x_input, + x_input, x_lens, x_attention_mask, x_padding_mask, @@ -1337,7 +1340,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t for jj in range(self.args.n_codebooks): logits[:,jj,self.args.eog] = -10000. samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep) - + cur_num_gen += 1 if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples assert keep == None @@ -1347,7 +1350,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t assert keep != None cur_generated = cur_generated[keep] cur_generated.append(samples[keep].squeeze(-1)) - else: # we are generating the rest eogs for the 'keep' sample + else: # we are generating the rest eogs for the 'keep' sample cur_generated.append(samples[keep].squeeze(-1)) # samples.shape is [K,1] @@ -1364,14 +1367,14 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t break else: assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" - + embedded_y = torch.cat([embedded_y, samples_emb], dim=1) y_input = self.audio_positional_embedding(embedded_y) # [B T D] # make attention mask and padding mask y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size) y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device) - + assert len(generated) == 1, f"len(generated): {len(generated)}" # revert the pattern @@ -1391,7 +1394,7 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t flatten_gen.append(unshifted_span) assert len(flatten_gen) == 1, len(flatten_gen) - + # combine res = [y[0], flatten_gen[0]] res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]