
Does not work on macOS with device="mps": "Can't infer missing attention mask on mps device" #148

ChristianWeyer opened this issue Oct 14, 2024 · 1 comment

This is my simple test script:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

torch_device = "mps:0"
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

attn_implementation = "eager"  # alternatives: "sdpa", "flash_attention_2"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, how are you doing today?"
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

# Tokenize the voice description and the spoken prompt separately
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(torch_device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)

# Generation raises the attention-mask error here
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()

sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

I get this error:

ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

Any idea what could be wrong?
Thanks!
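
The error message itself points at a possible workaround: pass the attention masks explicitly instead of letting generate() infer them. A minimal sketch, assuming generate() accepts attention_mask and prompt_attention_mask as in the project's batched-generation examples (untested on mps):

# Sketch: keep the full tokenizer output so the attention masks are available,
# then hand them to generate() explicitly so nothing is inferred on mps.
desc = tokenizer(description, return_tensors="pt").to(torch_device)
prompt_tok = tokenizer(prompt, return_tensors="pt").to(torch_device)

generation = model.generate(
    input_ids=desc.input_ids,
    attention_mask=desc.attention_mask,
    prompt_input_ids=prompt_tok.input_ids,
    prompt_attention_mask=prompt_tok.attention_mask,
)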

tulas75 commented Oct 16, 2024

Same problem for me.
I tried transformers version 4.44.2 (not officially supported by parler-tts); it seems to use the GPU, but at the end, when saving the wav file, I get this error:

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
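
If going the fallback route, note that the variable has to be set before torch is imported (or exported in the shell before launching the script). A minimal sketch:

# Sketch: enable the CPU fallback for ops the MPS backend doesn't support.
# Must run before `import torch` anywhere in the process; equivalently:
#   PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # unsupported MPS ops now fall back to the CPU (slower)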
