
Does not work on macOS with device="mps": "Can't infer missing attention mask on mps device" #148

Open · ChristianWeyer opened this issue Oct 14, 2024 · 25 comments

@ChristianWeyer

This is my simple test script:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

torch_device = "mps:0"
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

attn_implementation = "eager" # "sdpa" or "flash_attention_2"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, how are you doing today?"
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(torch_device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()

sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

I get this error:

ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

Any idea what could be wrong?
Thanks!
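For what it's worth, the error message itself suggests the workaround: pass an attention_mask explicitly instead of letting generate() infer one on-device. A minimal sketch against the script above (untested here; the mask comes straight from the tokenizer output):

# The tokenizer already returns an attention_mask alongside input_ids;
# passing it to generate() skips the on-device inference that fails on MPS.
desc_inputs = tokenizer(description, return_tensors="pt")
input_ids = desc_inputs.input_ids.to(torch_device)
attention_mask = desc_inputs.attention_mask.to(torch_device)

generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
    attention_mask=attention_mask,
)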

@tulas75

tulas75 commented Oct 16, 2024

Same problem for me.
I tried transformers version 4.44.2 (not supported by parler-tts); it seems to use the GPU, but at the end, when saving the wav file, I get an error.

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
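If you want to try the CPU-fallback workaround the message mentions, the variable has to be set before torch is imported; a sketch (this runs the unsupported op on CPU, it does not fix it):

import os

# Set the fallback flag before `import torch` so the MPS backend sees it.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # imported only after the flag is set

Alternatively, from the shell: PYTORCH_ENABLE_MPS_FALLBACK=1 python main.py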

@ylacombe
Collaborator

In the last commit pushed, I've bumped the transformers version. Could you try again, after installing again from scratch?

@tulas75

tulas75 commented Nov 2, 2024

Same error even with transformers 4.46.1
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS

There's one weird thing: I set torch_dtype = torch.bfloat16, but when I run the code (the same code from @ChristianWeyer) I get the following logs. It seems to use float32.

Flash attention 2 is not installed
/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
WeightNorm.apply(module, name, dim)
Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
"_name_or_path": "google/flan-t5-large",
"architectures": [
"T5ForConditionalGeneration"
],
"classifier_dropout": 0.0,
"d_ff": 2816,
"d_kv": 64,
"d_model": 1024,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"n_positions": 512,
"num_decoder_layers": 24,
"num_heads": 16,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"transformers_version": "4.46.1",
"use_cache": true,
"vocab_size": 32128
}

Config of the audio_encoder: <class 'parler_tts.dac_wrapper.modeling_dac.DACModel'> is overwritten by shared audio_encoder config: DACConfig {
"_name_or_path": "parler-tts/dac_44khZ_8kbps",
"architectures": [
"DACModel"
],
"codebook_size": 1024,
"frame_rate": 86,
"latent_dim": 1024,
"model_bitrate": 8,
"model_type": "dac_on_the_hub",
"num_codebooks": 9,
"sampling_rate": 44100,
"torch_dtype": "float32",
"transformers_version": "4.46.1"
}

Config of the decoder: <class 'parler_tts.modeling_parler_tts.ParlerTTSForCausalLM'> is overwritten by shared decoder config: ParlerTTSDecoderConfig {
"_name_or_path": "/fsx/yoach/tmp/artefacts/parler-tts-mini/decoder",
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"architectures": [
"ParlerTTSForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1025,
"codebook_weights": null,
"cross_attention_implementation_strategy": null,
"dropout": 0.1,
"eos_token_id": 1024,
"ffn_dim": 4096,
"hidden_size": 1024,
"initializer_factor": 0.02,
"is_decoder": true,
"layerdrop": 0.0,
"max_position_embeddings": 4096,
"model_type": "parler_tts_decoder",
"num_attention_heads": 16,
"num_codebooks": 9,
"num_cross_attention_key_value_heads": 16,
"num_hidden_layers": 24,
"num_key_value_heads": 16,
"pad_token_id": 1024,
"rope_embeddings": false,
"rope_theta": 10000.0,
"scale_embedding": false,
"tie_word_embeddings": false,
"torch_dtype": "float32",
"transformers_version": "4.46.1",
"use_cache": true,
"use_fused_lm_heads": false,
"vocab_size": 1088
}

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Traceback (most recent call last):
File "/Users/tulas/Projects/parler-tts/main.py", line 39, in
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/parler_tts/modeling_parler_tts.py", line 3633, in generate
sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **single_audio_decode_kwargs).audio_values
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/parler_tts/dac_wrapper/modeling_dac.py", line 139, in decode
audio_values = self.model.decode(audio_values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 266, in decode
return self.decoder(z)
^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 144, in forward
return self.model(x)
^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
input = module(input)
^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 112, in forward
return self.block(x)
^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
input = module(input)
^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 36, in forward
y = self.block(x)
^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
input = module(input)
^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 375, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 370, in _conv_forward
return F.conv1d(
^^^^^^^^^
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
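On the dtype question: the "torch_dtype": "float32" entries above come from the saved sub-model configs, so they don't necessarily reflect the runtime dtype after the .to(torch_device, dtype=torch_dtype) cast. A quick way to check what the weights actually are (a diagnostic sketch, assuming model is the instance from the script above):

# Both report the dtype of the loaded weights, not the config value.
print(model.dtype)
print(next(model.parameters()).dtype)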

@chigkim

chigkim commented Nov 11, 2024

It looks like a problem with PyTorch on MPS. :(

pytorch/pytorch#134416

They just changed the error message to "Output channels > 65536 not supported at the MPS device.", removing the part "As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1".

// TODO: MPS convolution kernel currently does not support output channels > 2^16

pytorch/pytorch@aa3ae50
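The limit is easy to reproduce in isolation on an affected PyTorch/macOS combination (a minimal repro sketch, assuming an Apple Silicon machine; not taken from the linked issue):

import torch

# Any Conv1d whose out_channels exceeds 2**16 should trip the same check
# that DAC's decoder hits through F.conv1d.
conv = torch.nn.Conv1d(in_channels=8, out_channels=2**16 + 1, kernel_size=1).to("mps")
x = torch.randn(1, 8, 4, device="mps")
y = conv(x)  # NotImplementedError on affected versions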

@hvaara

hvaara commented Nov 14, 2024

Feel free to follow pytorch/pytorch#140722 for updates on a fix in PyTorch. Tentative fix in pytorch/pytorch#140726.

@hvaara

hvaara commented Nov 14, 2024

The channel size issue has been fixed in PyTorch on macOS 15.1. It should be available in PyTorch nightly in < 24h.

While testing the fix I discovered that descript-audiotools, a transitive dependency of parler-tts, requires torch.distributed for types. I don't know why, but unfortunately torch.distributed is disabled by default in PyTorch on macOS. This should be the last remaining step to get parler-tts working on macOS with PyTorch/MPS.

The most straightforward approach is probably to handle unavailability gracefully in descript-audiotools. I'm quite curious why support was removed for macOS though, since this was definitely supported in the past (ref pytorch/pytorch#20380 (comment)).
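The graceful-handling idea could look roughly like this (a sketch of the pattern, not the actual descript-audiotools code):

import torch

# torch.distributed may be compiled out (e.g. some macOS builds), so guard
# the import and fall back to single-process defaults.
if torch.distributed.is_available():
    import torch.distributed as dist
else:
    dist = None

def world_size() -> int:
    # Treat "unavailable or uninitialized" as a single process.
    if dist is None or not dist.is_initialized():
        return 1
    return dist.get_world_size()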

@QueryType

Does it work now? Or do we still wait for the final fix?

@hvaara

hvaara commented Dec 7, 2024

Yes, I think this is fixed, and no further code changes are needed.

I think I misread the docs. From what I can tell, torch.distributed is available in the nightlies, but not by default when compiling from source.

Install nightlies with

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

then

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

device = "mps"
model_name = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(model_name, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, how are you doing today?"
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

should provide good output.

This issue can be closed.

cc @ylacombe

@tulas75

tulas75 commented Dec 7, 2024

Hi @hvaara

Thank you for your clarification and support, but it doesn't resolve the issue for me.
After changing the code according to your post, I got the following error:

NotImplementedError: Output channels > 65536 not supported at the MPS device.

tnx
.t

@hvaara

hvaara commented Dec 7, 2024

@tulas75 you need a pretty recent version of PyTorch nightly. What's the output of torch.__version__?
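(A quick way to check from the shell, without editing the script: python -c "import torch; print(torch.__version__)")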

@tulas75

tulas75 commented Dec 7, 2024

@hvaara 2.6.0.dev20241206

@hvaara

hvaara commented Dec 7, 2024

@tulas75 interesting. And what's the code/command you ran?

@tulas75

tulas75 commented Dec 7, 2024

@hvaara

I installed everything from scratch. These are the steps I followed:

  • I created a Python 3.12 virtualenv.
  • I installed torch nightly and parler-tts.
  • I created a main.py with your code.
  • I ran python main.py
  • I got the same error: NotImplementedError: Output channels > 65536 not supported at the MPS device.

This is the pip list output:
(env) tulas@andreas-Mac-Studio par-tts % pip list
Package Version


absl-py 2.1.0
accelerate 1.2.0
argbind 0.3.9
asttokens 3.0.0
audioread 3.0.1
certifi 2024.8.30
cffi 1.17.1
charset-normalizer 3.4.0
contourpy 1.3.1
cycler 0.12.1
decorator 5.1.1
descript-audio-codec 1.0.0
descript-audiotools 0.7.4
docstring_parser 0.16
einops 0.8.0
executing 2.1.0
ffmpy 0.4.0
filelock 3.16.1
fire 0.7.0
flatten-dict 0.4.2
fonttools 4.55.2
fsspec 2024.10.0
future 1.0.0
grpcio 1.68.1
huggingface-hub 0.26.5
idna 3.10
importlib_resources 6.4.5
ipython 8.30.0
jedi 0.19.2
Jinja2 3.1.4
joblib 1.4.2
julius 0.2.7
kiwisolver 1.4.7
lazy_loader 0.4
librosa 0.10.2.post1
llvmlite 0.43.0
Markdown 3.7
markdown-it-py 3.0.0
markdown2 2.5.1
MarkupSafe 2.1.5
matplotlib 3.9.3
matplotlib-inline 0.1.7
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.1.0
networkx 3.4.2
numba 0.60.0
numpy 2.0.2
packaging 24.2
parler_tts 0.2.2
parso 0.8.4
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
platformdirs 4.3.6
pooch 1.8.2
prompt_toolkit 3.0.48
protobuf 4.25.5
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
pycparser 2.22
Pygments 2.18.0
pyloudnorm 0.1.1
pyparsing 3.2.0
pystoi 0.4.1
python-dateutil 2.9.0.post0
PyYAML 6.0.2
randomname 0.2.1
regex 2024.11.6
requests 2.32.3
rich 13.9.4
safetensors 0.4.5
scikit-learn 1.5.2
scipy 1.14.1
sentencepiece 0.2.0
setuptools 70.2.0
six 1.17.0
soundfile 0.12.1
soxr 0.5.0.post1
stack-data 0.6.3
sympy 1.13.1
tensorboard 2.18.0
tensorboard-data-server 0.7.2
termcolor 2.5.0
threadpoolctl 3.5.0
tokenizers 0.20.3
torch 2.6.0.dev20241206
torch-stoi 0.2.3
torchaudio 2.5.0.dev20241206
torchvision 0.20.0.dev20241206
tqdm 4.67.1
traitlets 5.14.3
transformers 4.46.1
typing_extensions 4.12.2
urllib3 2.2.3
wcwidth 0.2.13
Werkzeug 3.1.3

@hvaara

hvaara commented Dec 7, 2024

@tulas75 thanks for the info! My working hypothesis is that the fix from the main branch isn't in nightly for some reason. IIRC I haven't actually tested with nightly (only main).

I'm not at my computer right now, so I'll try to repro when I'm back.

@hvaara

hvaara commented Dec 8, 2024

@tulas75 I'm not able to repro using nightly versions (same as you):

torch                     2.6.0.dev20241206
torchaudio                2.5.0.dev20241206

In the script that produces the error, can you add a print(torch.__version__) immediately before the statement that causes the error, and post the entire script and its output to https://gist.github.com/?

@QueryType

QueryType commented Dec 9, 2024

Hello, I still get the attention mask error! I followed all the steps and used the code above. Can you tell me what I missed?

python main.py

Flash attention 2 is not installed
model.safetensors: 100%|██████████████████████████████████████████████████████████████████████| 3.51G/3.51G [02:51<00:00, 13.1MB/s]
/opt/homebrew/Caskroom/miniconda/base/envs/parler-tts/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
WeightNorm.apply(module, name, dim)
generation_config.json: 100%|██████████████████████████████████████████████████████████████████████| 265/265 [00:00<00:00, 770kB/s]
tokenizer_config.json: 100%|██████████████████████████████████████████████████████████████| 20.8k/20.8k [00:00<00:00, 25.6MB/s]
spiece.model: 100%|█████████████████████████████████████████████████████████████████████████████| 792k/792k [00:00<00:00, 1.56MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████████| 2.42M/2.42M [00:01<00:00, 2.24MB/s]
special_tokens_map.json: 100%|████████████████████████████████████████████████████████████████| 2.54k/2.54k [00:00<00:00, 11.8MB/s]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Traceback (most recent call last):
File "/Volumes/d/code/aiml/parler-tts/main.py", line 18, in <module>
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Caskroom/miniconda/base/envs/parler-tts/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Caskroom/miniconda/base/envs/parler-tts/lib/python3.11/site-packages/parler_tts/modeling_parler_tts.py", line 3292, in generate
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Caskroom/miniconda/base/envs/parler-tts/lib/python3.11/site-packages/transformers/generation/utils.py", line 493, in _prepare_attention_mask_for_generation
raise ValueError(
ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

(parler-tts) nireve@Niranjans-Mini parler-tts % pip freeze | grep torch
torch==2.6.0.dev20241208
torch-stoi==0.2.3
torchaudio==2.5.0.dev20241208
torchvision==0.20.0.dev20241208

@chigkim

chigkim commented Dec 9, 2024

I confirm it works. This is what I did from scratch.
Run line by line and make sure there are no obvious errors.

git clone https://github.com/huggingface/parler-tts
cd parler-tts
python3 -m venv .venv
source .venv/bin/activate
pip install -e .
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip install Accelerate

Then try running the testing script that @hvaara provided earlier.

@hvaara

hvaara commented Dec 9, 2024

@chigkim it works as in it's producing a good audio file in the end?

@QueryType

(quoting @chigkim's from-scratch steps above)

I had to make 2 modifications to the above steps to get it working on my Mac Mini M2.

  1. During pip install -e . there is a bug in the current version during compilation, so I switched to an earlier commit.
  2. I had to add the attention mask manually in the code and load it on the device (thanks to ChatGPT).
    The modified code is:
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

device = "mps"
model_name = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(model_name, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, how are you doing today?"
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# Create the attention mask (1s for real tokens, 0s for padding tokens)
attention_mask = input_ids.ne(tokenizer.pad_token_id).long().to(device)


generation = model.generate(input_ids=input_ids, 
                            prompt_input_ids=prompt_input_ids,
                            attention_mask=attention_mask)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

After this the code worked; I could see GPU usage, and it was much faster than on CPU. Thanks to all.

@ChristianWeyer
Author

(quoting @QueryType's modified steps and code above)

Will you create a PR?

@hvaara

hvaara commented Dec 10, 2024

@ChristianWeyer I didn't think a PR was needed. What do you think needs to be changed, and where?

@ChristianWeyer
Author

I mean, @QueryType had to make changes to get it working...

@hvaara

hvaara commented Dec 10, 2024

You mean the attention_mask? This is likely because pad_token_id and eos_token_id are the same value in

pad_token_id=2048,
bos_token_id=2049,
eos_token_id=2048,

It's just a warning. It's not specific to the current issue, and doesn't prevent you from generating good output with parler-tts. It can also be seen in the fine-tuning example notebook.

[screenshot: the same attention-mask warning appearing in the fine-tuning notebook]

I believe the warning is benign and essentially WAI.
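Concretely, transformers infers a missing mask by comparing input_ids against pad_token_id, which becomes ambiguous when pad and eos share an id; a simplified sketch of the problem (not the library's exact code):

import torch

pad_token_id = 2048
eos_token_id = 2048  # same id: padding is indistinguishable from a real eos

input_ids = torch.tensor([[12, 57, eos_token_id, pad_token_id]])

# Naive inference marks everything that isn't the pad id. With pad == eos,
# a genuine trailing eos would be masked out too, so transformers warns
# (or, on mps, refuses) instead of silently guessing.
inferred_mask = input_ids.ne(pad_token_id).long()
print(inferred_mask)  # tensor([[1, 1, 0, 0]])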

The original error

ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

would prevent you from generating output with parler-tts, and was likely fixed when @ylacombe bumped the transformers version.

The second error

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

would also prevent you from generating output with parler-tts due to pytorch/pytorch#134416, and was fixed in pytorch/pytorch#140726.

I don't think any further changes are needed in order to mark the current issue as fixed, but please let us know if you disagree.

@QueryType

(quoting @hvaara's explanation above)

Me too, I do not think we need a PR. The compilation issue is already being fixed, I think.

@hvaara

hvaara commented Dec 11, 2024

Heads up: pytorch/pytorch#140726 introduced a regression affecting at least the Conv2d op. I'm currently investigating the scope of the impact and potential mitigations. As of now, it appears that Conv1d is unaffected, and its fix will be included in PyTorch v2.6.0. For updates, please follow pytorch/pytorch#142836.
