diff --git a/kaldi_active_grammar/compiler.py b/kaldi_active_grammar/compiler.py
index db1531e..47dd94f 100644
--- a/kaldi_active_grammar/compiler.py
+++ b/kaldi_active_grammar/compiler.py
@@ -16,6 +16,7 @@ from .wfst import WFST, NativeWFST, SymbolTable
 from .model import Model
 from .wrapper import KaldiAgfCompiler, KaldiAgfNNet3Decoder, KaldiLafNNet3Decoder
+import kaldi_active_grammar.whisper_dictation as whisper_dictation
 import kaldi_active_grammar.defaults as defaults
 
 _log = _log.getChild('compiler')
@@ -646,6 +647,7 @@ def parse_output_for_rule(self, kaldi_rule, output):
             self._log.error("parsed_output(%r).lower() != output(%r)" % (parsed_output, output))
         return words
 
+    plain_dictation_regex = re.compile(r'(?<=#nonterm:dictation )(.*?)(?= #nonterm:end)')  # lookbehind & lookahead assertions
     alternative_dictation_regex = re.compile(r'(?<=#nonterm:dictation_cloud )(.*?)(?= #nonterm:end)')  # lookbehind & lookahead assertions
 
     def parse_output(self, output, dictation_info_func=None):
@@ -659,10 +661,16 @@ def parse_output(self, output, dictation_info_func=None):
         kaldi_rule_id = int(nonterm_token[len('#nonterm:rule'):])
         kaldi_rule = self.kaldi_rule_by_id_dict[kaldi_rule_id]
 
-        if self.alternative_dictation and dictation_info_func and kaldi_rule.has_dictation and '#nonterm:dictation_cloud' in parsed_output:
+        # Debug dictation settings:
+        #print("DEBUG: ", self.alternative_dictation, "B", dictation_info_func, "C", kaldi_rule.has_dictation, "D", parsed_output)
+
+        #if self.alternative_dictation and dictation_info_func and kaldi_rule.has_dictation and '#nonterm:dictation_cloud' in parsed_output:
+        if self.alternative_dictation and dictation_info_func and kaldi_rule.has_dictation and '#nonterm:dictation' in parsed_output:
             try:
                 if callable(self.alternative_dictation):
                     alternative_text_func = self.alternative_dictation
+                elif self.alternative_dictation == 'whisper':
+                    alternative_text_func = whisper_dictation.Whisper.transcribe_data_sync
                 else:
                     raise TypeError("Invalid alternative_dictation value: %r" % self.alternative_dictation)
 
@@ -677,7 +685,8 @@ def parse_output(self, output, dictation_info_func=None):
                         'offset_end': times[words.index('#nonterm:end', index)],
                     }
                     for index, (word, time, length) in enumerate(word_align)
-                    if word.startswith('#nonterm:dictation_cloud')]
+                    if word.startswith('#nonterm:dictation')]
+                    #if word.startswith('#nonterm:dictation_cloud')]
 
                 # If last dictation is at end of utterance, include rest of audio_data; else, include half of audio_data between dictation end and start of next word
                 dictation_span = dictation_spans[-1]
@@ -688,9 +697,17 @@
                     dictation_span['offset_end'] = (dictation_span['offset_end'] + next_word_time) // 2
 
                 def replace_dictation(matchobj):
-                    orig_text = matchobj.group(1)
+                    orig_text = matchobj.group(1)  # "orig_text" holds the dictation result from Kaldi itself.
                     dictation_span = dictation_spans.pop(0)
                     dictation_audio = audio_data[dictation_span['offset_start'] : dictation_span['offset_end']]
+                    if self.alternative_dictation == 'whisper':
+                        self.cloud_dictation_lang = "en-US"  # FIXME: hardcoded language!
+                        # The Whisper dictation backend can also take its audio data from a wav file.
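+                        # (In the current implementation the raw audio bytes are sent directly over XML-RPC to the
+                        # whisper_server; the wav-file hand-off sketched below is left commented out as a fallback.)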
+                        # Store the wav file in the system temp folder (this should work on Linux and Windows, and probably OS X):
+                        #import tempfile
+                        #temp_dir = tempfile.TemporaryDirectory().name
+                        #audio_filename = os.path.join(temp_dir, "whisper.wav")
+                        #whisper_dictation.write_wav('/tmp/whisper.wav', dictation_audio)
                     kwargs = dict(language_code=self.cloud_dictation_lang)
                     with debug_timer(self._log.debug, 'alternative_dictation call'):
                         alternative_text = alternative_text_func(dictation_audio, **kwargs)
@@ -699,6 +716,7 @@ def replace_dictation(matchobj):
                     return (alternative_text or orig_text)
 
                 parsed_output = self.alternative_dictation_regex.sub(replace_dictation, parsed_output)
+                parsed_output = self.plain_dictation_regex.sub(replace_dictation, parsed_output)
 
             except Exception as e:
                 self._log.exception("Exception performing alternative dictation")
diff --git a/kaldi_active_grammar/whisper_dictation.py b/kaldi_active_grammar/whisper_dictation.py
new file mode 100644
index 0000000..0da3205
--- /dev/null
+++ b/kaldi_active_grammar/whisper_dictation.py
@@ -0,0 +1,108 @@
+# A crude way of using OpenAI Whisper for dictation in KaldiAG.
+# This is the RPC client that sends audio data to the local whisper RPC server process.
+# By Shervin Emami (www.shervinemami.com) 2022
+# Based on "alternative_dictation.py" from KaldiAG v1.8, when KaldiAG had some basic support for GCloud dictation.
+#
+# KaldiAG is (c) Copyright 2019 by David Zurow
+# Licensed under the AGPL-3.0; see LICENSE.txt file.
+#
+
+# Compatibility between Python 2 and Python 3:
+from __future__ import print_function   # print function with Python 2/3 compatibility
+from __future__ import division
+
+import sys
+if sys.version_info[0] == 3:
+    # Python3
+    from xmlrpc.client import ServerProxy
+else:
+    # Python2
+    from xmlrpclib import ServerProxy
+import wave
+
+verbose = False
+
+WHISPER_SERVER_ACCESS = "http://127.0.0.1:8002"   # Where to find our whisper server. Note that Shervin's KaldiAG setup already runs other RPC servers on ports 8000 and 8001.
+whisper_client = ServerProxy(WHISPER_SERVER_ACCESS, allow_none=True)
+
+# Choose what to do if whisper dictation fails (e.g. trouble connecting to our local whisper RPC server).
+# Some users will want None returned, so that their Kaldi or other dictation backend performs the dictation without interrupting the user.
+# But some users will want the entire speech engine to close, so that it's obvious when whisper didn't work.
+EXIT_IF_WHISPER_FAILED = True
+
+
+# Create a new process for whisper_server to run in the background (this happens whenever this module is imported).
+# It expects "whisper_server.py" to be in the same folder as this Python file.
+import subprocess
+import os
+pardir = os.path.abspath(os.path.join(__file__, os.pardir))
+whisper_server = os.path.abspath(os.path.join(pardir, "whisper_server.py"))
+subprocess.Popen([sys.executable, whisper_server])
+
+
+
+def write_wav(filename, audio_data, sample_rate=16000):
+    wf = wave.open(filename, 'wb')
+    wf.setnchannels(1)
+    wf.setsampwidth(2)
+    wf.setframerate(sample_rate)
+    wf.writeframes(audio_data)
+    wf.close()
+
+
+def testCUDA():
+    print("Test CUDA")
+
+    import torch
+    # Making the code device-agnostic
+    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if device_name == 'cuda':
+        print(f"CUDA version: {torch.version.cuda}")
+        print(f"Name of current CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
+    # Creating a test tensor
+    x = torch.randint(1, 100, (100, 1000))
+    # Checking the device name:
+    # Should return 'cpu' by default
+    print("Default pytorch device (should be 'cpu'): ", x.device)
+    # Transferring the tensor to the GPU
+    x = x.to(torch.device(device_name))
+    # Checking the device name again:
+    # Should return 'cuda:0'
+    print("CUDA pytorch device (should be 'cuda:0'): ", x.device)
+    # Applying a GPU-accelerated tensor operation
+    res_gpu = x ** 2
+    res_cpu = res_gpu.cpu()
+    print("result: ", res_cpu)
+
+
+
+class Whisper(object):
+
+    # Use Whisper to convert the audio data into a text string. If speech_data is not given, the server will load the audio from a wav file.
+    @staticmethod
+    def transcribe_data_sync(speech_data=None, model='default', language_code='en-US'):
+        # Calling a GPU-accelerated PyTorch function within the KaldiAG Dragonfly process can cause Dragonfly's
+        # calls to xdotool via Text() to have quite long latency on Linux (~200ms instead of ~50ms per call!). So
+        # here (within the Dragonfly process) we make an RPC interprocess call to our separate whisper process, which can be GPU-accelerated.
+
+        # For debugging latency of GPU-accelerated PyTorch:
+        #testCUDA()
+        #return "words"
+
+        try:
+            print("Calling the whisper_server RPC server.")
+            result = whisper_client.transcribe_using_whisper(speech_data, model, language_code)
+            if result:
+                return result
+        except Exception as e:
+            print("Warning: Exception ", e)
+            print("Couldn't access the whisper_server at", WHISPER_SERVER_ACCESS, ", is it running?")
+
+        # If we've reached this line, whisper dictation failed.
+        if EXIT_IF_WHISPER_FAILED:
+            print("Exiting the speech recognition engine, since whisper failed.")
+            os.kill(os.getpid(), 9)
+            sys.exit(1)
+
+        return None
+
diff --git a/kaldi_active_grammar/whisper_server.py b/kaldi_active_grammar/whisper_server.py
new file mode 100755
index 0000000..e413b17
--- /dev/null
+++ b/kaldi_active_grammar/whisper_server.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python
+
+# A crude way of using OpenAI Whisper for dictation in KaldiAG.
+# This is the Whisper RPC server process.
+# By Shervin Emami (www.shervinemami.com) 2022
+# Based on "alternative_dictation.py" from KaldiAG v1.8, when KaldiAG had some basic support for GCloud dictation.
+#
+# KaldiAG is (c) Copyright 2019 by David Zurow
+# Licensed under the AGPL-3.0; see LICENSE.txt file.
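+#
+# This server listens on 127.0.0.1:8002 and exposes two XML-RPC methods:
+#     transcribe_using_whisper(speech_data, model='default', language_code='en-US')  ->  transcript string, or None on failure
+#     kill()  ->  schedules the server process to shut itself down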
+#
+
+# Compatibility between Python 2 and Python 3:
+from __future__ import print_function   # print function with Python 2/3 compatibility
+from __future__ import division
+
+import sys
+if sys.version_info[0] == 3:
+    # Python3
+    from xmlrpc.server import SimpleXMLRPCServer
+    from xmlrpc.server import SimpleXMLRPCRequestHandler
+else:
+    # Python2
+    from SimpleXMLRPCServer import SimpleXMLRPCServer
+    from SimpleXMLRPCServer import SimpleXMLRPCRequestHandler
+
+from datetime import datetime
+import subprocess
+import time
+from threading import Timer
+import numpy as np
+import tempfile
+import os
+import io
+
+
+# Choose between Whisper models: "tiny.en", "base.en", "small.en", "medium.en" or "large".
+# If you have a powerful GPU, try "medium.en".
+# If you don't have a GPU, stick to "tiny.en" or "base.en".
+# TODO: Allow passing the model name from the user script.
+#model_filename = "tiny.en"
+model_filename = "medium.en"
+
+# If the audio data is being transferred from Kaldi/Dragonfly to Whisper using a wav file, look for it in the system temp folder.
+temp_dir = tempfile.TemporaryDirectory()   # Keep the TemporaryDirectory object alive, so the folder isn't deleted while we're running.
+audio_filename = os.path.join(temp_dir.name, "whisper.wav")
+
+# Whisper allows passing a "prompt" that is intended to be the previous sentence or some similar related text, to give it a hint
+# about what to expect. This includes formatting, so for example giving a hint of "40's" can push whisper closer to
+# decoding the phrase "forties" as "40's" instead of "40s".
+# Since I want a lot of commas but not fullstops or capitalising of phrases, I'm using a hint_prompt this way.
+# TODO: Allow passing whisper args from the user script.
+hint_prompt = "oh OK yeah sure, in my 40's I mostly benchmarked a profile of ARM CPU core optimisation"
+
+verbose = False
+
+WHISPER_SERVER_ADDRESS = ("127.0.0.1", 8002)   # Our server address. Note that Shervin's KaldiAG setup already runs other RPC servers on ports 8000 and 8001.
+
+
+try:
+    import whisper
+    whisper_imported = True
+except ImportError:
+    whisper_imported = False
+whisper_model = None
+whisper_model_started_loading = False
+
+
+# If you have a BlinkStick USB-controlled RGB LED, set this to True; otherwise set it to False.
+ENABLE_BLINKSTICK = True
+
+# BlinkStick USB LED
+if ENABLE_BLINKSTICK:
+    try:
+        from blinkstick import blinkstick
+        bstick = blinkstick.find_first()
+        if bstick:
+            print("Found BlinkStick USB LED", bstick.get_serial())
+        else:
+            print("Warning: Couldn't access the BlinkStick USB LED")
+    except Exception:
+        bstick = None
+
+# Show the current mode, using the USB LED.
+# args can be 'off', 'on', 'disabled' or 'sleeping'.
+def updateLED(args, grammarMode="Normal"):
+    if ENABLE_BLINKSTICK:
+        try:
+            #print("In updateLED ", args, grammarMode)
+            if bstick:
+                V = 5   # LED brightness, up to 255
+                if args == "on":
+                    # Set my BlinkStick LED to green (ON, Normal mode) or blue (ON, Command mode)
+                    if grammarMode == "Normal":
+                        bstick.set_color(red=0, green=V, blue=0)
+                    elif grammarMode == "Yellow":
+                        bstick.set_color(red=V, green=V, blue=0)
+                    elif grammarMode == "Pink":
+                        bstick.set_color(red=V*1.2, green=V/3, blue=V/2.5)
+                    elif grammarMode == "BlueGreen":
+                        bstick.set_color(red=1, green=9, blue=3)
+                    else:
+                        bstick.set_color(red=0, green=0, blue=V*1.2)
+                elif args == "disabled":
+                    # Set my BlinkStick LED to red (disabled)
+                    bstick.set_color(red=V*2, green=0, blue=0)
+                elif args == "sleeping":
+                    # Set my BlinkStick LED to purple (sleeping)
+                    bstick.set_color(red=1, green=0, blue=0)
+                elif args == "off":
+                    # Set my BlinkStick LED to black (off)
+                    bstick.set_color(red=0, green=0, blue=0)
+        except Exception:
+            print("Warning: Couldn't access the BlinkStick USB LED")
+            pass
+
+
+def load_whisper_model():
+    global model_filename
+    global whisper_model
+    global whisper_model_started_loading
+    if not whisper_imported:
+        print(datetime.now(), "[Warning: the 'whisper' package isn't installed, so whisper dictation isn't available]")
+        return
+    if not whisper_model_started_loading:
+        whisper_model_started_loading = True   # Block other RPC threads from loading the model too
+        print(datetime.now(), "[Loading whisper pytorch model '" + model_filename + "' during startup. This can take a long time!]")
+        updateLED("on", "Pink")
+        whisper_model = whisper.load_model(model_filename)
+        if whisper_model:
+            print(datetime.now(), "[Finished loading whisper model]")
+            updateLED("on", "Command")   # Assume that the user is going back to Command-mode after being in Dictation-mode.
+
+            # We should expect the first transcription to be slower than usual, since the GPU must load drivers and ramp up its clocks and perhaps other things.
+            # So let's perform a dummy transcription now, to preload everything needed for fast dictation.
+            #transcribe_using_whisper(None)
+    else:
+        # Stall here until the whisper model has been loaded (by another RPC thread running in parallel)
+        print(datetime.now(), "[whisper model is already being loaded. Waiting until it is ready]")
+        while not whisper_model:
+            print(".")
+            time.sleep(0.3)   # Sleep 0.3 seconds before trying again
+
+# Give the rest of the speech recognition system a few seconds to do its heavy initialisation steps before we load our heavy whisper model.
+Timer(3.0, load_whisper_model).start()
+
+
+
+# Decode the audio.
+# "decode" will use GreedyDecoder (fast) if beam_size=None, or it will use BeamSearch (slower but more reliable) if beam_size=5 or similar.
+# Set fp16=True for RTX GPUs, or False for GTX GPUs, because GTX & older GPUs are extremely fast at FP32 but terrible at FP16,
+# whereas newer GPUs such as RTX GPUs are extremely fast at both FP32 and FP16, and in fact slightly faster at FP16.
+# Set beam_size=5 if you can handle the slower speed, or use a beam_size of 1-3 if you want faster results, but with more chance that whisper
+# will get stuck in a repetitive mental loop for a while.
+# TODO: Allow passing whisper args from the user script.
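+# For example, on an RTX-class GPU where accuracy matters more than latency, settings roughly like the
+# following may be preferable (an untested sketch, per the notes above):
+#options = whisper.DecodingOptions(language="en", fp16=True, prompt=hint_prompt, beam_size=5, temperature=0.0)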
+options = whisper.DecodingOptions(language="en", fp16=False, prompt=hint_prompt, best_of=None, beam_size=3, temperature=0.0, patience=1.3) if whisper_imported else None
+
+
+def testCUDA():
+    print("Test CUDA")
+
+    import torch
+
+    # Making the code device-agnostic
+    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if device_name == 'cuda':
+        print(f"CUDA version: {torch.version.cuda}")
+        print(f"Name of current CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
+    # Creating a test tensor
+    x = torch.randint(1, 100, (100, 1000))
+    # Checking the device name:
+    # Should return 'cpu' by default
+    print("Default pytorch device (should be 'cpu'): ", x.device)
+    # Transferring the tensor to the GPU
+    x = x.to(torch.device(device_name))
+    # Checking the device name again:
+    # Should return 'cuda:0'
+    print("CUDA pytorch device (should be 'cuda:0'): ", x.device)
+    # Applying a GPU-accelerated tensor operation
+    res_gpu = x ** 2
+    res_cpu = res_gpu.cpu()
+    print("result: ", res_cpu)
+
+
+# Use Whisper to convert the audio data into a text string. If speech_data is not given, the audio will be loaded from a wav file.
+def transcribe_using_whisper(speech_data=None, model='default', language_code='en-US'):
+    transcript = "Unknown"
+    # Wrap our code in a big try/except block, to make RPC failures easier to debug.
+    try:
+
+        # For debugging latency of GPU-accelerated PyTorch:
+        #testCUDA()
+        #return "words"
+
+        if not whisper_imported:
+            print(datetime.now(), "[Error: Cannot find the whisper package! Whisper dictation is unavailable.]")
+            return None
+
+        # Lazily load the model upon first actual use, so that startup stays fast when we only want command-mode, not whisper dictation.
+        global whisper_model
+        if not whisper_model:
+            load_whisper_model()
+
+        updateLED("on", "Normal")
+        start_inference = time.perf_counter()
+
+        # Load the audio data.
+        # Whisper is much faster when using 'model.decode' instead of 'model.transcribe'. See "https://github.com/openai/whisper/discussions/391"
+        # So instead of simply calling transcribe(filename), we will load the audio, pad it to 30 seconds, generate the log-mel spectrogram, copy the data to the GPU, then decode the audio.
+        audio = None
+        try:
+            #result = whisper_model.transcribe(audio_filename, language='english')
+            if speech_data:
+                audio = np.frombuffer(speech_data.data, np.int16).flatten().astype(np.float32) / 32768.0
+            else:
+                audio = whisper.load_audio(audio_filename)
+        except Exception as e:
+            print(datetime.now(), "[Exception!:", e, "]")
+            audio = None
+
+        if not isinstance(audio, np.ndarray):
+            # We couldn't load the audio, so use an empty buffer of silence.
+            audio = np.zeros((10000), np.float32)
+
+        # Whisper only works on 30 second audio segments.
+        audio = whisper.pad_or_trim(audio)
+
+        # Make the log-mel spectrogram and move it to the same device as the model (the GPU).
+        mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
+
+        # Decode the audio.
+        result = whisper.decode(whisper_model, mel, options)
+        if result:
+            transcript = result.text
+
+        elapsed_inference = time.perf_counter() - start_inference
+
+        print(datetime.now(), u"[According to OpenAI Whisper after {}s, you said: {}]".format(elapsed_inference, transcript))
+        if verbose:
+            print(datetime.now(), "[", result, "]")
+        updateLED("on", "Command")   # Assume that the user is going back to Command-mode after being in Dictation-mode.
+    except Exception as e:
+        print(datetime.now(), "[Exception raised: ", e, "]")
+        return None
+
+    return transcript
+
+
+
+# Restrict the XMLRPC server to a particular path.
+class RequestHandler(SimpleXMLRPCRequestHandler):
+    rpc_paths = ('/RPC2',)
+
+def setup_xmlrpc_server():
+    server_quit = 0
+    print(datetime.now(), "[Setting up the whisper_server XMLRPC server at", WHISPER_SERVER_ADDRESS, "]")
+    whisperServer = SimpleXMLRPCServer(WHISPER_SERVER_ADDRESS, requestHandler=RequestHandler, allow_none=True)
+    whisperServer.register_function(xmlrpc_kill, "kill")
+    whisperServer.register_function(transcribe_using_whisper, "transcribe_using_whisper")
+    #TODO: Disable this for security when not debugging:
+    #whisperServer.register_introspection_functions()
+    return whisperServer
+
+
+def xmlrpc_kill():
+    print(datetime.now(), "[XMLRPC whisper_server received kill event]")
+    Timer(2.0, die).start()   # Give the XMLRPC response a moment to be sent before we exit.
+
+def die():
+    print(datetime.now(), "[Closing the whisper_server]")
+    server_quit = 1
+    temp_dir.cleanup()   # Remove the temp folder (and wav file) that we created.
+    os.kill(os.getpid(), 9)
+    sys.exit()
+
+
+def whisper_server_main(args):
+    print(datetime.now(), "[whisper_server process has started]")
+
+    whisperServer = setup_xmlrpc_server()
+
+    # Start handling XMLRPC requests in a background thread shortly.
+    server_quit = 0
+    def start_server():
+        while not server_quit:
+            whisperServer._handle_request_noblock()
+            #print(".")   # Show that another request has been handled.
+
+    Timer(0.3, start_server).start()
+
+    print(datetime.now(), "[whisper_server is ready]")
+
+    # The background server thread keeps this process alive from here on.
+
+
+if __name__ == "__main__":
+    #print(datetime.now(), "[whisper_server is in __main__()]")
+    whisper_server_main(sys.argv[1:])
+    #print(datetime.now(), "[whisper_server is ending __main__()]")
+
+#print(datetime.now(), "[whisper_server is at end of script.]")
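+
+# A quick manual test of this server from a separate Python shell (a hypothetical sketch; it assumes the
+# server is already running, the model has finished loading, and "audio.raw" holds 16kHz 16-bit mono PCM):
+#
+#   from xmlrpc.client import ServerProxy
+#   client = ServerProxy("http://127.0.0.1:8002", allow_none=True)
+#   with open("audio.raw", "rb") as f:
+#       print(client.transcribe_using_whisper(f.read(), 'default', 'en-US'))
+#   client.kill()   # shut the server down when finished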