Skip to content

Commit 62ed04d

Browse files
Add support for OpenAI Whisper API (#371)
1 parent 3a59172 commit 62ed04d

File tree

7 files changed

+1259
-747
lines changed

7 files changed

+1259
-747
lines changed

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ OpenAI's [Whisper](https://github.com/openai/whisper).
1919
VTT ([Demo](https://www.loom.com/share/cf263b099ac3481082bb56d19b7c87fe))
2020
- Supports [Whisper](https://github.com/openai/whisper#available-models-and-languages),
2121
[Whisper.cpp](https://github.com/ggerganov/whisper.cpp),
22-
and [Whisper-compatible Hugging Face models](https://huggingface.co/models?other=whisper)
22+
[Whisper-compatible Hugging Face models](https://huggingface.co/models?other=whisper), and
23+
the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/introduction)
24+
- Available on Mac, Windows, and Linux
2325

2426
## Installation
2527

buzz/gui.py

+46-18
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
import logging
55
import os
66
import platform
7-
import random
87
import sys
98
from datetime import datetime
109
from enum import auto
11-
from typing import Dict, List, Optional, Tuple, Union
10+
from typing import Dict, List, Optional, Tuple
1211

1312
import humanize
1413
import sounddevice
@@ -33,7 +32,6 @@
3332
from .recording import RecordingAmplitudeListener
3433
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
3534
Task,
36-
WhisperCppFileTranscriber, WhisperFileTranscriber,
3735
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
3836
FileTranscriberQueueWorker, FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL)
3937

@@ -217,24 +215,23 @@ def show_model_download_error_dialog(parent: QWidget, error: str):
217215

218216
class FileTranscriberWidget(QWidget):
219217
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
220-
file_transcriber: Optional[Union[WhisperFileTranscriber,
221-
WhisperCppFileTranscriber]] = None
222218
model_loader: Optional[ModelLoader] = None
223219
transcriber_thread: Optional[QThread] = None
224220
file_transcription_options: FileTranscriptionOptions
225221
transcription_options: TranscriptionOptions
226222
is_transcribing = False
227223
# (TranscriptionOptions, FileTranscriptionOptions, str)
228224
triggered = pyqtSignal(tuple)
225+
openai_access_token_changed = pyqtSignal(str)
229226

230-
def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
231-
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
227+
def __init__(self, file_paths: List[str], openai_access_token: Optional[str] = None,
228+
parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
232229
super().__init__(parent, flags)
233230

234231
self.setWindowTitle(file_paths_as_title(file_paths))
235232

236233
self.file_paths = file_paths
237-
self.transcription_options = TranscriptionOptions()
234+
self.transcription_options = TranscriptionOptions(openai_access_token=openai_access_token)
238235
self.file_transcription_options = FileTranscriptionOptions(
239236
file_paths=self.file_paths)
240237

@@ -266,7 +263,9 @@ def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
266263
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
267264
self.transcription_options = transcription_options
268265
self.word_level_timings_checkbox.setDisabled(
269-
self.transcription_options.model.model_type == ModelType.HUGGING_FACE)
266+
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
267+
if self.transcription_options.openai_access_token is not None:
268+
self.openai_access_token_changed.emit(self.transcription_options.openai_access_token)
270269

271270
def on_click_run(self):
272271
self.run_button.setDisabled(True)
@@ -503,7 +502,10 @@ def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowTy
503502
self.text_box.setPlaceholderText(_('Click Record to begin...'))
504503

505504
transcription_options_group_box = TranscriptionOptionsGroupBox(
506-
default_transcription_options=self.transcription_options, parent=self)
505+
default_transcription_options=self.transcription_options,
506+
# Live transcription with OpenAI Whisper API not implemented
507+
model_types=[model_type for model_type in ModelType if model_type is not ModelType.OPEN_AI_WHISPER_API],
508+
parent=self)
507509
transcription_options_group_box.transcription_options_changed.connect(
508510
self.on_transcription_options_changed)
509511

@@ -820,7 +822,7 @@ def upsert_task(self, task: FileTranscriptionTask):
820822
elif task.status == FileTranscriptionTask.Status.COMPLETED:
821823
status_widget.setText(_('Completed'))
822824
elif task.status == FileTranscriptionTask.Status.FAILED:
823-
status_widget.setText(_('Failed'))
825+
status_widget.setText(f'{_("Failed")} ({task.error})')
824826
elif task.status == FileTranscriptionTask.Status.CANCELED:
825827
status_widget.setText(_('Canceled'))
826828

@@ -925,6 +927,7 @@ class MainWindow(QMainWindow):
925927
table_widget: TranscriptionTasksTableWidget
926928
tasks: Dict[int, 'FileTranscriptionTask']
927929
tasks_changed = pyqtSignal()
930+
openai_access_token: Optional[str] = None
928931

929932
def __init__(self, tasks_cache=TasksCache()):
930933
super().__init__(flags=Qt.WindowType.Window)
@@ -1026,11 +1029,17 @@ def on_new_transcription_action_triggered(self):
10261029
return
10271030

10281031
file_transcriber_window = FileTranscriberWidget(
1029-
file_paths, self, flags=Qt.WindowType.Window)
1032+
file_paths, self.openai_access_token, self, flags=Qt.WindowType.Window)
10301033
file_transcriber_window.triggered.connect(
10311034
self.on_file_transcriber_triggered)
1035+
file_transcriber_window.openai_access_token_changed.connect(self.on_openai_access_token_changed)
10321036
file_transcriber_window.show()
10331037

1038+
# Save the access token on the main window so the user doesn't need to re-enter it (at least, not while the app is
1039+
# still open)
1040+
def on_openai_access_token_changed(self, access_token: str):
1041+
self.openai_access_token = access_token
1042+
10341043
def on_open_transcript_action_triggered(self):
10351044
selected_rows = self.table_widget.selectionModel().selectedRows()
10361045
for selected_row in selected_rows:
@@ -1092,6 +1101,7 @@ def on_tasks_changed(self):
10921101
self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action())
10931102
self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action())
10941103
self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action())
1104+
self.save_tasks_to_cache()
10951105

10961106
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
10971107
self.transcriber_worker.stop()
@@ -1236,6 +1246,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
12361246
transcription_options_changed = pyqtSignal(TranscriptionOptions)
12371247

12381248
def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
1249+
model_types: Optional[List[ModelType]] = None,
12391250
parent: Optional[QWidget] = None):
12401251
super().__init__(title='', parent=parent)
12411252
self.transcription_options = default_transcription_options
@@ -1261,7 +1272,9 @@ def __init__(self, default_transcription_options: TranscriptionOptions = Transcr
12611272
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)
12621273

12631274
self.model_type_combo_box = QComboBox(self)
1264-
for model_type in ModelType:
1275+
if model_types is None:
1276+
model_types = [model_type for model_type in ModelType]
1277+
for model_type in model_types:
12651278
# Hide Whisper.cpp option if whisper.dll did not load correctly.
12661279
# See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
12671280
if model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False:
@@ -1277,18 +1290,28 @@ def __init__(self, default_transcription_options: TranscriptionOptions = Transcr
12771290
default_transcription_options.model.whisper_model_size.value.title())
12781291
self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed)
12791292

1280-
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
1281-
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
1293+
self.openai_access_token_edit = QLineEdit(self)
1294+
self.openai_access_token_edit.setText(default_transcription_options.openai_access_token)
1295+
self.openai_access_token_edit.setEchoMode(QLineEdit.EchoMode.Password)
1296+
self.openai_access_token_edit.textChanged.connect(self.on_openai_access_token_edit_changed)
1297+
12821298
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
12831299
self.form_layout.addRow('', self.whisper_model_size_combo_box)
12841300
self.form_layout.addRow('', self.hugging_face_search_line_edit)
1301+
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
1302+
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
1303+
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
12851304

1286-
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, False)
1305+
self.reset_visible_rows()
12871306

12881307
self.form_layout.addRow('', self.advanced_settings_button)
12891308

12901309
self.setLayout(self.form_layout)
12911310

1311+
def on_openai_access_token_edit_changed(self, access_token: str):
1312+
self.transcription_options.openai_access_token = access_token
1313+
self.transcription_options_changed.emit(self.transcription_options)
1314+
12921315
def on_language_changed(self, language: str):
12931316
self.transcription_options.language = language
12941317
self.transcription_options_changed.emit(self.transcription_options)
@@ -1316,12 +1339,17 @@ def on_transcription_options_changed(self, transcription_options: TranscriptionO
13161339
self.transcription_options = transcription_options
13171340
self.transcription_options_changed.emit(transcription_options)
13181341

1319-
def on_model_type_changed(self, text: str):
1320-
model_type = ModelType(text)
1342+
def reset_visible_rows(self):
1343+
model_type = self.transcription_options.model.model_type
13211344
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
13221345
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
13231346
(model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP))
1347+
self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API)
1348+
1349+
def on_model_type_changed(self, text: str):
1350+
model_type = ModelType(text)
13241351
self.transcription_options.model.model_type = model_type
1352+
self.reset_visible_rows()
13251353
self.transcription_options_changed.emit(self.transcription_options)
13261354

13271355
def on_whisper_model_size_changed(self, text: str):

buzz/model_loader.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class ModelType(enum.Enum):
2626
WHISPER = 'Whisper'
2727
WHISPER_CPP = 'Whisper.cpp'
2828
HUGGING_FACE = 'Hugging Face'
29+
OPEN_AI_WHISPER_API = 'OpenAI Whisper API'
2930

3031

3132
@dataclass()
@@ -82,7 +83,7 @@ def run(self):
8283
expected_sha256 = url.split('/')[-2]
8384
self.download_model(url, file_path, expected_sha256)
8485

85-
else: # ModelType.HUGGING_FACE:
86+
elif self.model_type == ModelType.HUGGING_FACE:
8687
self.progress.emit((0, 100))
8788

8889
try:
@@ -95,6 +96,12 @@ def run(self):
9596
self.progress.emit((100, 100))
9697
file_path = self.hugging_face_model_id
9798

99+
elif self.model_type == ModelType.OPEN_AI_WHISPER_API:
100+
file_path = ""
101+
102+
else:
103+
raise Exception("Invalid model type: " + self.model_type.value)
104+
98105
self.finished.emit(file_path)
99106

100107
def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):

buzz/transcriber.py

+72-12
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import sys
1313
import tempfile
1414
import threading
15+
from abc import ABC, abstractmethod
1516
from dataclasses import dataclass, field
1617
from multiprocessing.connection import Connection
1718
from random import randint
1819
from threading import Thread
1920
from typing import Any, List, Optional, Tuple, Union
21+
import openai
2022

2123
import ffmpeg
2224
import numpy as np
@@ -61,10 +63,11 @@ class Segment:
6163
class TranscriptionOptions:
6264
language: Optional[str] = None
6365
task: Task = Task.TRANSCRIBE
64-
model: TranscriptionModel = TranscriptionModel()
66+
model: TranscriptionModel = field(default_factory=TranscriptionModel)
6567
word_level_timings: bool = False
6668
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
6769
initial_prompt: str = ''
70+
openai_access_token: Optional[str] = None
6871

6972

7073
@dataclass()
@@ -219,17 +222,34 @@ class OutputFormat(enum.Enum):
219222
VTT = 'vtt'
220223

221224

222-
class WhisperCppFileTranscriber(QObject):
225+
class FileTranscriber(QObject):
226+
transcription_task: FileTranscriptionTask
223227
progress = pyqtSignal(tuple) # (current, total)
224228
completed = pyqtSignal(list) # List[Segment]
225229
error = pyqtSignal(str)
230+
231+
def __init__(self, task: FileTranscriptionTask,
232+
parent: Optional['QObject'] = None):
233+
super().__init__(parent)
234+
self.transcription_task = task
235+
236+
@abstractmethod
237+
def run(self):
238+
...
239+
240+
@abstractmethod
241+
def stop(self):
242+
...
243+
244+
245+
class WhisperCppFileTranscriber(FileTranscriber):
226246
duration_audio_ms = sys.maxsize # max int
227247
segments: List[Segment]
228248
running = False
229249

230250
def __init__(self, task: FileTranscriptionTask,
231251
parent: Optional['QObject'] = None) -> None:
232-
super().__init__(parent)
252+
super().__init__(task, parent)
233253

234254
self.file_path = task.file_path
235255
self.language = task.transcription_options.language
@@ -332,22 +352,60 @@ def read_std_err(self):
332352
pass
333353

334354

335-
class WhisperFileTranscriber(QObject):
355+
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
356+
def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = None):
357+
super().__init__(task=task, parent=parent)
358+
self.file_path = task.file_path
359+
self.task = task.transcription_options.task
360+
361+
@pyqtSlot()
362+
def run(self):
363+
try:
364+
logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
365+
self.task)
366+
367+
wav_file = tempfile.mktemp() + '.wav'
368+
(
369+
ffmpeg.input(self.file_path)
370+
.output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
371+
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
372+
)
373+
374+
# TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
375+
audio_file = open(wav_file, "rb")
376+
openai.api_key = self.transcription_task.transcription_options.openai_access_token
377+
language = self.transcription_task.transcription_options.language
378+
response_format = "verbose_json"
379+
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
380+
transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
381+
language=language)
382+
else:
383+
transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
384+
language=language)
385+
386+
segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in
387+
transcript["segments"]]
388+
self.completed.emit(segments)
389+
except Exception as exc:
390+
self.error.emit(str(exc))
391+
logging.exception('')
392+
393+
def stop(self):
394+
pass
395+
396+
397+
class WhisperFileTranscriber(FileTranscriber):
336398
"""WhisperFileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file
337399
using the default program for opening txt files. """
338400

339401
current_process: multiprocessing.Process
340-
progress = pyqtSignal(tuple) # (current, total)
341-
completed = pyqtSignal(list) # List[Segment]
342-
error = pyqtSignal(str)
343402
running = False
344403
read_line_thread: Optional[Thread] = None
345404
READ_LINE_THREAD_STOP_TOKEN = '--STOP--'
346405

347406
def __init__(self, task: FileTranscriptionTask,
348407
parent: Optional['QObject'] = None) -> None:
349-
super().__init__(parent)
350-
self.transcription_task = task
408+
super().__init__(task, parent)
351409
self.segments = []
352410
self.started_process = False
353411
self.stopped = False
@@ -570,8 +628,7 @@ def __del__(self):
570628
class FileTranscriberQueueWorker(QObject):
571629
tasks_queue: multiprocessing.Queue
572630
current_task: Optional[FileTranscriptionTask] = None
573-
current_transcriber: Optional[WhisperFileTranscriber |
574-
WhisperCppFileTranscriber] = None
631+
current_transcriber: Optional[FileTranscriber] = None
575632
current_transcriber_thread: Optional[QThread] = None
576633
task_updated = pyqtSignal(FileTranscriptionTask)
577634
completed = pyqtSignal()
@@ -605,9 +662,12 @@ def run(self):
605662

606663
logging.debug('Starting next transcription task')
607664

608-
if self.current_task.transcription_options.model.model_type == ModelType.WHISPER_CPP:
665+
model_type = self.current_task.transcription_options.model.model_type
666+
if model_type == ModelType.WHISPER_CPP:
609667
self.current_transcriber = WhisperCppFileTranscriber(
610668
task=self.current_task)
669+
elif model_type == ModelType.OPEN_AI_WHISPER_API:
670+
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(task=self.current_task)
611671
else:
612672
self.current_transcriber = WhisperFileTranscriber(
613673
task=self.current_task)

0 commit comments

Comments
 (0)