Skip to content

Commit

Permalink
Remove TensorRT Execution Provider from available providers in ASR, T…
Browse files Browse the repository at this point in the history
…TS, and VAD modules
  • Loading branch information
dnhkng committed Jan 12, 2025
1 parent 60e5fa3 commit 77cff97
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
6 changes: 5 additions & 1 deletion glados/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ def __init__(
sample_rate: int = SAMPLE_RATE,
) -> None:
self.sample_rate = sample_rate

providers = ort.get_available_providers()
if "TensorrtExecutionProvider" in providers:
providers.remove("TensorrtExecutionProvider")

self.session = ort.InferenceSession(
model_path,
sess_options=ort.SessionOptions(),
providers=ort.get_available_providers(),
providers=providers,
)
self.vocab = self._load_vocabulary(tokens_file)

Expand Down
6 changes: 5 additions & 1 deletion glados/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,14 @@ class Synthesizer:
"""

def __init__(self, model_path: str, speaker_id: Optional[int] = None):
providers = ort.get_available_providers()
if "TensorrtExecutionProvider" in providers:
providers.remove("TensorrtExecutionProvider")

self.session = ort.InferenceSession(
model_path,
sess_options=ort.SessionOptions(),
providers=ort.get_available_providers(),
providers=providers,
)
self.phonemizer = phonemizer.Phonemizer()
# self.id_map = PHONEME_ID_MAP
Expand Down
6 changes: 5 additions & 1 deletion glados/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ class VAD:
_initial_c = np.zeros((2, 1, 64)).astype("float32")

def __init__(self, model_path, window_size_samples: int = int(SAMPLE_RATE / 10)):
providers = ort.get_available_providers()
if "TensorrtExecutionProvider" in providers:
providers.remove("TensorrtExecutionProvider")

self.ort_sess = ort.InferenceSession(
model_path,
sess_options=ort.SessionOptions(),
providers=ort.get_available_providers(),
providers=providers,
)
self.window_size_samples = window_size_samples
self.sr = SAMPLE_RATE
Expand Down

0 comments on commit 77cff97

Please sign in to comment.