From edd1241c0f5b5d10ad151057bd0975c405aeae78 Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99
Date: Sun, 8 Dec 2024 15:56:33 +0700
Subject: [PATCH] chore: refactor code

---
 .gitignore                                    |   4 +
 models/whispervq/app.py                       | 223 ++----------------
 models/whispervq/models/audio.py              |  22 ++
 .../whispervq/routes/AudioTokenizerRoute.py   |  19 ++
 .../services/AudioTokenizerService.py         | 167 +++++++++++++
 .../whispervq/{ => utils}/custom_component.py |  77 +++---
 6 files changed, 280 insertions(+), 232 deletions(-)
 create mode 100644 models/whispervq/models/audio.py
 create mode 100644 models/whispervq/routes/AudioTokenizerRoute.py
 create mode 100644 models/whispervq/services/AudioTokenizerService.py
 rename models/whispervq/{ => utils}/custom_component.py (79%)

diff --git a/.gitignore b/.gitignore
index 8b13789..b14879a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,5 @@
+__pycache__
+.cache
+*.pt
+*.model
\ No newline at end of file
diff --git a/models/whispervq/app.py b/models/whispervq/app.py
index 6598de7..57550be 100644
--- a/models/whispervq/app.py
+++ b/models/whispervq/app.py
@@ -1,4 +1,4 @@
-import argparse
+import argparse, os, sys
 parser = argparse.ArgumentParser(description="WhisperVQ Application")
 parser.add_argument('--log-path', type=str, default='whisper.log',
                     help='The log file path')
@@ -6,32 +6,22 @@
                     choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'TRACE'],
                     help='The log level')
 parser.add_argument('--port', type=int, default=3348,
                     help='The port to run the WhisperVQ app on')
+parser.add_argument('--device-id', type=str, default="0",
+                    help='The id of the CUDA device to run WhisperVQ on')
 parser.add_argument('--package-dir', type=str, default="",
                     help='The package-dir to be extended to sys.path')
 args = parser.parse_args()
-import sys
-sys.path.insert(0, args.environment)
-import tempfile
-from typing import Tuple
-from enum import Enum
-import io
+sys.path.insert(0, args.package_dir)
+os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id  # select the Nvidia GPU given by --device-id
+
 import logging
-from custom_component import CustomRQBottleneckTransformer
-from whisperspeech.vq_stoks import RQBottleneckTransformer
-from huggingface_hub import hf_hub_download
 import uvicorn
-from transformers import WhisperModel, WhisperProcessor
-from fastapi.responses import JSONResponse
-from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi import FastAPI
 from contextlib import asynccontextmanager
-import torchaudio
-import torch
 import os
 import time
 import psutil
 import threading
-
-
 logging.basicConfig(level=args.log_level,
                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                     handlers=[
                         logging.FileHandler(args.log_path),
@@ -39,200 +29,24 @@
                     ])
 logger = logging.getLogger(__name__)
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use the first GPU
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-if not os.path.exists(os.path.dirname(os.path.realpath(__file__))+"/whisper-vq-stoks-v3-7lang-fixed.model"):
-    hf_hub_download(
-        repo_id="jan-hq/WhisperVQ",
-        filename="whisper-vq-stoks-v3-7lang-fixed.model",
-        local_dir=".",
-    )
-vq_model = CustomRQBottleneckTransformer.load_vq_only(
-    os.path.dirname(os.path.realpath(__file__)) +
-    "/whisper-vq-stoks-v3-7lang-fixed.model"
-).to(device)
-vq_model.load_encoder(device)
-vq_model.eval()
+# after the logger is set up we can import and use the services
+from services.AudioTokenizerService import get_audio_tokenizer_service
+from routes.AudioTokenizerRoute import audio_tokenizer_router


 @asynccontextmanager
 async def lifespan(app: FastAPI):
-
+
+    # on startup
+    get_audio_tokenizer_service()
     yield
     # on shutdown
-
-# vq_model = torch.compile(vq_model)
-
-
-class AudioFormat(str, Enum):
-    WAV = "wav"    # Supported by both backends
-    MP3 = "mp3"    # Supported by ffmpeg
-    FLAC = "flac"  # Supported by both
-    AAC = "aac"    # Supported by ffmpeg
-    OGG = "ogg"    # Supported by ffmpeg
-    OPUS = "opus"  # Supported by ffmpeg
-    PCM = "pcm"    # Raw PCM data
-
-
-# Format to backend mapping
-FORMAT_BACKENDS = {
-    AudioFormat.WAV: ["soundfile", "ffmpeg"],
-    AudioFormat.MP3: ["ffmpeg"],
-    AudioFormat.FLAC: ["soundfile", "ffmpeg"],
-    AudioFormat.AAC: ["ffmpeg"],
-    AudioFormat.OGG: ["ffmpeg"],
-    AudioFormat.OPUS: ["ffmpeg"],
-    AudioFormat.PCM: ["soundfile"]
-}
-
-
-class AudioProcessor:
-    def __init__(self):
-        self.available_backends = torchaudio.list_audio_backends()
-        logger.info(f"Available backends: {self.available_backends}")
-
-        # Verify ffmpeg support
-        self.has_ffmpeg = "ffmpeg" in self.available_backends
-        if not self.has_ffmpeg:
-            logger.warning(
-                "FFMPEG backend not available. Some formats may not be supported")
-
-    def _get_best_backend(self, format: AudioFormat) -> str:
-        """Determine the best backend for the given format"""
-        supported_backends = FORMAT_BACKENDS[format]
-        for backend in supported_backends:
-            if backend in self.available_backends:
-                return backend
-        raise ValueError(f"No available backend supports format {format}")
-
-    async def load_audio(
-        self,
-        file_obj: bytes,
-        format: AudioFormat,
-        target_sr: int = 16000
-    ) -> Tuple[torch.Tensor, int]:
-        """
-        Load audio from bytes object with format handling
-
-        Args:
-            file_obj: Audio file bytes
-            format: Audio format enum
-            target_sr: Target sample rate (default: 16000)
-
-        Returns:
-            Tuple[torch.Tensor, int]: Audio tensor and sample rate
-        """
-        try:
-            # Get appropriate backend
-            backend = self._get_best_backend(format)
-            torchaudio.set_audio_backend(backend)
-            logger.info(f"Using {backend} backend for {format} format")
-
-            if format == AudioFormat.PCM:
-                # Handle raw PCM
-                wav = torch.frombuffer(file_obj, dtype=torch.int16)
-                wav = wav.float() / 32768.0  # Normalize to [-1, 1]
-                wav = wav.unsqueeze(0)  # Add channel dimension
-                sr = target_sr
-            else:
-                # For formats that might need ffmpeg processing
-                if os.name == "nt":  # for windows
-                    wav, sr = torchaudio.load(io.BytesIO(file_obj))
-                else:
-                    with tempfile.NamedTemporaryFile(suffix=f".{format}") as temp_file:
-                        # Write bytes to temporary file
-                        temp_file.write(file_obj)
-                        temp_file.flush()
-
-                        # Load audio
-                        wav, sr = torchaudio.load(temp_file.name)
-
-            # Convert to mono if stereo
-            if wav.shape[0] > 1:
-                wav = torch.mean(wav, dim=0, keepdim=True)
-
-            # Resample if needed
-            if sr != target_sr:
-                wav = torchaudio.functional.resample(wav, sr, target_sr)
-                sr = target_sr
-
-            return wav, sr
-
-        except Exception as e:
-            logger.error(f"Error loading audio: {e}")
-            raise HTTPException(
-                status_code=400,
-                detail=f"Error processing {format} audio: {str(e)}"
-            )
-
-    def get_format_info(self) -> dict:
-        """Get information about supported formats"""
-        supported_formats = {}
-        for format in AudioFormat:
-            try:
-                backend = self._get_best_backend(format)
-                supported_formats[format] = {
-                    "supported": True,
-                    "backend": backend
-                }
-            except ValueError:
-                supported_formats[format] = {
-                    "supported": False,
-                    "backend": None
-                }
-        return supported_formats
-
-
-audio_processor = AudioProcessor()
-
 app = FastAPI(lifespan=lifespan)
-
-@app.get("/supported_formats")
-async def get_supported_formats():
-    """Endpoint to check supported formats"""
-    return audio_processor.get_format_info()
-
-
-@app.post("/tokenize/{format}")
-async def tokenize_audio(format: AudioFormat = "wav", file: UploadFile = File(...)):
-    try:
-        # Read file
-        file_obj = await file.read()
-
-        # Load and process audio
-        wav, sr = await audio_processor.load_audio(file_obj, format)
-
-        # Ensure we're using CUDA if available
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        wav = wav.to(device)
-
-        # Generate tokens
-        with torch.no_grad():
-            codes = vq_model.encode_audio(wav)
-            codes = codes[0].cpu().tolist()
-
-        # Format result
-        result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
-
-        return JSONResponse(content={
-            "model_name": "whisper-vq-stoks-v3-7lang-fixed.model",
-            "tokens": f'<|sound_start|>{result}<|sound_end|>',
-            "format": format,
-            "sample_rate": sr,
-            "backend_used": audio_processor._get_best_backend(format)
-        })
-
-    except Exception as e:
-        logger.error(f"Error processing request: {e}")
-        raise HTTPException(
-            status_code=500,
-            detail=f"Error processing request: {str(e)}"
-        )
-
+# include the routes
+app.include_router(audio_tokenizer_router)


 def self_terminate():
     time.sleep(1)
@@ -240,8 +54,8 @@ def self_terminate():
     parent.kill()


-@app.post("/kill")
-async def kill():
+@app.delete("/destroy")
+async def destroy():
     threading.Thread(target=self_terminate, daemon=True).start()
     return {"success": True}

@@ -263,8 +77,7 @@ async def kill():
     LOGGING_CONFIG["loggers"]["uvicorn.access"]["level"] = args.log_level

     # Print supported formats at startup
-    processor = AudioProcessor()
-    format_info = processor.get_format_info()
+    format_info = get_audio_tokenizer_service().get_format_info()
     logger.info("Supported formats:")
     for format, info in format_info.items():
         logger.info(f"{format}: {info}")
diff --git a/models/whispervq/models/audio.py b/models/whispervq/models/audio.py
new file mode 100644
index 0000000..3a34073
--- /dev/null
+++ b/models/whispervq/models/audio.py
@@ -0,0 +1,22 @@
+from pydantic import BaseModel
+from enum import Enum
+
+class AudioFormat(str, Enum):
+    WAV = "wav"    # Supported by both backends
+    MP3 = "mp3"    # Supported by ffmpeg
+    FLAC = "flac"  # Supported by both
+    AAC = "aac"    # Supported by ffmpeg
+    OGG = "ogg"    # Supported by ffmpeg
+    OPUS = "opus"  # Supported by ffmpeg
+    PCM = "pcm"    # Raw PCM data
+
+# Format to backend mapping
+FORMAT_BACKENDS = {
+    AudioFormat.WAV: ["soundfile", "ffmpeg"],
+    AudioFormat.MP3: ["ffmpeg"],
+    AudioFormat.FLAC: ["soundfile", "ffmpeg"],
+    AudioFormat.AAC: ["ffmpeg"],
+    AudioFormat.OGG: ["ffmpeg"],
+    AudioFormat.OPUS: ["ffmpeg"],
+    AudioFormat.PCM: ["soundfile"]
+}
diff --git a/models/whispervq/routes/AudioTokenizerRoute.py b/models/whispervq/routes/AudioTokenizerRoute.py
new file mode 100644
index 0000000..cc552a1
--- /dev/null
+++ b/models/whispervq/routes/AudioTokenizerRoute.py
@@ -0,0 +1,19 @@
+from services.AudioTokenizerService import get_audio_tokenizer_service
+from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import File, UploadFile
+from models.audio import AudioFormat, FORMAT_BACKENDS
+
+audio_tokenizer_router = APIRouter(
+    prefix="/tokenize", tags=["audio"])
+
+
+@audio_tokenizer_router.post("/{format}")
+async def tokenize_audio(format: AudioFormat = "wav", file: UploadFile = File(...)):
+    file_obj = await file.read()
+    return get_audio_tokenizer_service().tokenize(file_obj, format)
+
+
+@audio_tokenizer_router.get("/supported_formats")
+async def get_supported_formats():
+    return get_audio_tokenizer_service().get_format_info()
diff --git a/models/whispervq/services/AudioTokenizerService.py b/models/whispervq/services/AudioTokenizerService.py
new file mode 100644
index 0000000..0fe6471
--- /dev/null
+++ b/models/whispervq/services/AudioTokenizerService.py
@@ -0,0 +1,167 @@
+import io
+import os
+from huggingface_hub import hf_hub_download
+from models.audio import AudioFormat, FORMAT_BACKENDS
+import tempfile
+import logging
+import torchaudio
+from fastapi import HTTPException
+from fastapi.responses import JSONResponse
+import torch
+from typing import Tuple
+from utils.custom_component import CustomRQBottleneckTransformer
+logger = logging.getLogger(__name__)
+
+
+class AudioTokenizerService:
+    def __init__(self):
+        self.available_backends = torchaudio.list_audio_backends()
+        logger.info(f"Available backends: {self.available_backends}")
+        main_directory = os.path.dirname(
+            os.path.dirname(os.path.realpath(__file__)))
+
+        # Verify ffmpeg support
+        self.has_ffmpeg = "ffmpeg" in self.available_backends
+        if not self.has_ffmpeg:
+            logger.warning(
+                "FFMPEG backend not available. Some formats may not be supported")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if not os.path.exists(main_directory+"/whisper-vq-stoks-v3-7lang-fixed.model"):
+            hf_hub_download(
+                repo_id="jan-hq/WhisperVQ",
+                filename="whisper-vq-stoks-v3-7lang-fixed.model",
+                local_dir=main_directory,
+            )
+        self.vq_model = CustomRQBottleneckTransformer.load_vq_only(
+            main_directory +
+            "/whisper-vq-stoks-v3-7lang-fixed.model"
+        ).to(device)
+        self.vq_model.load_encoder(device)
+        self.vq_model.eval()
+        # vq_model = torch.compile(vq_model)
+
+    def _get_best_backend(self, format: AudioFormat) -> str:
+        """Determine the best backend for the given format"""
+        supported_backends = FORMAT_BACKENDS[format]
+        for backend in supported_backends:
+            if backend in self.available_backends:
+                return backend
+        raise ValueError(f"No available backend supports format {format}")
+
+    def load_audio(
+        self,
+        file_obj: bytes,
+        format: AudioFormat,
+        target_sr: int = 16000
+    ) -> Tuple[torch.Tensor, int]:
+        """
+        Load audio from bytes object with format handling
+
+        Args:
+            file_obj: Audio file bytes
+            format: Audio format enum
+            target_sr: Target sample rate (default: 16000)
+
+        Returns:
+            Tuple[torch.Tensor, int]: Audio tensor and sample rate
+        """
+        try:
+            # Get appropriate backend
+            backend = self._get_best_backend(format)
+            torchaudio.set_audio_backend(backend)
+            logger.info(f"Using {backend} backend for {format} format")
+
+            if format == AudioFormat.PCM:
+                # Handle raw PCM
+                wav = torch.frombuffer(file_obj, dtype=torch.int16)
+                wav = wav.float() / 32768.0  # Normalize to [-1, 1]
+                wav = wav.unsqueeze(0)  # Add channel dimension
+                sr = target_sr
+            else:
+                # For formats that might need ffmpeg processing
+                if os.name == "nt":  # for windows
+                    wav, sr = torchaudio.load(io.BytesIO(file_obj))
+                else:
+                    with tempfile.NamedTemporaryFile(suffix=f".{format}") as temp_file:
+                        # Write bytes to temporary file
+                        temp_file.write(file_obj)
+                        temp_file.flush()
+
+                        # Load audio
+                        wav, sr = torchaudio.load(temp_file.name)
+
+            # Convert to mono if stereo
+            if wav.shape[0] > 1:
+                wav = torch.mean(wav, dim=0, keepdim=True)
+
+            # Resample if needed
+            if sr != target_sr:
+                wav = torchaudio.functional.resample(wav, sr, target_sr)
+                sr = target_sr
+
+            return wav, sr
+
+        except Exception as e:
+            logger.error(f"Error loading audio: {e}")
+            raise HTTPException(
+                status_code=400,
+                detail=f"Error processing {format} audio: {str(e)}"
+            )
+
+    def get_format_info(self) -> dict:
+        """Get information about supported formats"""
+        supported_formats = {}
+        for format in AudioFormat:
+            try:
+                backend = self._get_best_backend(format)
+                supported_formats[format] = {
+                    "supported": True,
+                    "backend": backend
+                }
+            except ValueError:
+                supported_formats[format] = {
+                    "supported": False,
+                    "backend": None
+                }
+        return supported_formats
+
+    def tokenize(self, audio_data: bytes, format: AudioFormat = "wav"):
+        try:
+            wav, sr = self.load_audio(audio_data, format)
+
+            # Ensure we're using CUDA if available
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            wav = wav.to(device)
+
+            # Generate tokens
+            with torch.no_grad():
+                codes = self.vq_model.encode_audio(wav)
+                codes = codes[0].cpu().tolist()
+
+            # Format result
+            result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
+
+            return JSONResponse(content={
+                "model_name": "whisper-vq-stoks-v3-7lang-fixed.model",
+                "tokens": f'<|sound_start|>{result}<|sound_end|>',
+                "format": format,
+                "sample_rate": sr,
+                "backend_used": self._get_best_backend(format)
+            })
+
+        except Exception as e:
+            logger.error(f"Error processing request: {e}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error processing request: {str(e)}"
+            )
+
+
+_audio_tokenizer_service = None
+
+
+def get_audio_tokenizer_service():
+    global _audio_tokenizer_service
+    if _audio_tokenizer_service is None:
+        _audio_tokenizer_service = AudioTokenizerService()
+    return _audio_tokenizer_service
diff --git a/models/whispervq/custom_component.py b/models/whispervq/utils/custom_component.py
similarity index 79%
rename from models/whispervq/custom_component.py
rename to models/whispervq/utils/custom_component.py
index ebd7d9a..ae6b28f 100644
--- a/models/whispervq/custom_component.py
+++ b/models/whispervq/utils/custom_component.py
@@ -12,12 +12,17 @@
 import urllib
 from tqdm import tqdm
 import torchaudio
-_HF_MODELS = {
+
+_HF_MODELS = {
     "medium": "https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt",
 }
+
+
 def available_models() -> List[str]:
     """Returns the names of available models"""
     return list(_HF_MODELS.keys())
+
+
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)

@@ -25,13 +30,15 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     download_target = os.path.join(root, os.path.basename(url))

     if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
+        raise RuntimeError(
+            f"{download_target} exists and is not a regular file")

     if os.path.isfile(download_target):
         with open(download_target, "rb") as f:
             model_bytes = f.read()
         return model_bytes if in_memory else download_target
-
+    import ssl
+    ssl._create_default_https_context = ssl._create_unverified_context
     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
         with tqdm(
             total=int(source.info().get("Content-Length")),
@@ -50,30 +57,36 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
         model_bytes = open(download_target, "rb").read()
     return model_bytes if in_memory else download_target
+
+
 class CustomWhisperEncoder(nn.Module):
     """
     Lightweight wrapper that only loads the AudioEncoder part of Whisper
     """
+
     def __init__(self, name: str, device: str = None, download_root: str = None, in_memory: bool = False,):
         super().__init__()
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         if download_root is None:
             default = os.path.join(os.path.expanduser("~"), ".cache")
-            download_root = os.path.dirname(os.path.realpath(__file__))  # os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
+            # os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
+            download_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

         if name in _HF_MODELS:
-            checkpoint_file = _download(_HF_MODELS[name], download_root, in_memory)
+            checkpoint_file = _download(
+                _HF_MODELS[name], download_root, in_memory)
         elif os.path.isfile(name):
             checkpoint_file = open(name, "rb").read() if in_memory else name
         else:
             raise RuntimeError(
                 f"Model {name} not found; available models = {available_models()}"
             )
-
+
         # Load weights
         with (
-            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+            io.BytesIO(checkpoint_file) if in_memory else open(
+                checkpoint_file, "rb")
         ) as fp:
             checkpoint = torch.load(fp, map_location=device)
         del checkpoint_file
@@ -85,37 +98,41 @@ def __init__(self, name: str, device: str = None, download_root: str = None, in_
             dims.n_audio_head,
             dims.n_audio_layer,
         )
-
+
         self.encoder.load_state_dict(checkpoint["model_state_dict"])
-
+
         if device:
             self.to(device)
-
+
         self.eval()

     def forward(self, mel: torch.Tensor):
         return self.encoder(mel)
-
+
+
 class CustomRQBottleneckTransformer(RQBottleneckTransformer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
     @classmethod
     def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
-                    repo_id=None, filename=None, local_filename=None):
+                     repo_id=None, filename=None, local_filename=None):
         if repo_id is None and filename is None and local_filename is None:
             if ":" in ref:
                 repo_id, filename = ref.split(":", 1)
             else:
                 local_filename = ref
         if not local_filename:
-            local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
-
+            local_filename = hf_hub_download(
+                repo_id=repo_id, filename=filename)
+
         # Load the spec
         spec = torch.load(local_filename)
-
+
         # Create instance with minimal required components
-        instance = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))
-
+        instance = cls(**spec['config'], tunables=Tunables(**
+                       Tunables.upgrade(spec.get('tunables', {}))))
+
         # Load only necessary state dict entries
         required_components = {
             'rq', 'mlp', 'mlp_ln'
@@ -124,49 +141,55 @@ def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-e
             k: v for k, v in spec['state_dict'].items()
             if any(k.startswith(comp) for comp in required_components)
         }
-
+
         instance.load_state_dict(filtered_state_dict, strict=False)
         instance.eval()
         return instance

     def load_encoder(self, device=None):
-        if self.whmodel is not None: return
+        if self.whmodel is not None:
+            return
         device = device or self.device
         # Use our custom encoder-only model
         if self.whmodel is None:
-            encoder = CustomWhisperEncoder(self.whisper_model_name, device=device)
-            self.whmodel = [encoder]
+            encoder = CustomWhisperEncoder(
+                self.whisper_model_name, device=device)
+            self.whmodel = encoder
         multilingual = not self.whisper_model_name.endswith('.en')
         self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

     def optimzed_encode_mel(self, mel):
-        assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
+        assert len(
+            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
         self.load_encoder()
         n = mel.shape[-1]
         if n > whisper.audio.N_FRAMES:
             padding = 0
-            padded = mel[:,:,:whisper.audio.N_FRAMES]
+            padded = mel[:, :, :whisper.audio.N_FRAMES]
         else:
             padding = -n % whisper.audio.N_FRAMES
             padded = F.pad(mel, (0, padding), value=-1.5)
-        embs = self.whmodel[0].encoder(padded)#.to(self.whmodel[0].device))#[:,:n//2]
+        # .to(self.whmodel[0].device))#[:,:n//2]
+        embs = self.whmodel.encoder(padded)
         stoks = self.quantize(embs)
         if self.tunables.mask_embs:
-            return stoks[:,:n//2//self.downsample]
+            return stoks[:, :n//2//self.downsample]
         else:
             return stoks
     # overide
+
     def encode_audio(self, audio):
         if isinstance(audio, str):
             x, sr = torchaudio.load(audio)
             x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
             audio = x.unsqueeze(0)
         return self.optimzed_encode_mel(self.log_mel_spectrogram(audio).to(self.device))
-
+
+
 if __name__ == "__main__":
     # Load the model
     vqmodel = CustomRQBottleneckTransformer.load_vq_only(
         "whisper-vq-stoks-v3-7lang-fixed.model"
     ).to("cuda")
     vqmodel.load_encoder('cuda')
-    vqmodel.eval()
\ No newline at end of file
+    vqmodel.eval()
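
Below is a minimal client sketch (not part of the patch) showing how the refactored endpoints can be exercised once the app is running. It assumes the server listens on the default port 3348, that the requests package is installed, and that "sample.wav" is a placeholder for a real audio file on disk.

# client_example.py - hedged sketch; host, port and "sample.wav" are assumptions
import requests

BASE_URL = "http://localhost:3348"  # default --port defined in app.py

# list which audio formats the running server can actually decode
print(requests.get(f"{BASE_URL}/tokenize/supported_formats").json())

# tokenize a local WAV file; the path parameter selects the AudioFormat
with open("sample.wav", "rb") as f:
    resp = requests.post(f"{BASE_URL}/tokenize/wav", files={"file": f})
resp.raise_for_status()
payload = resp.json()
print(payload["model_name"], payload["sample_rate"], payload["backend_used"])
print(payload["tokens"][:80], "...")  # '<|sound_start|>...' token string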