diff --git a/speechmetrics/relative/__init__.py b/speechmetrics/relative/__init__.py index bf2e691..326ac5a 100644 --- a/speechmetrics/relative/__init__.py +++ b/speechmetrics/relative/__init__.py @@ -1,3 +1,4 @@ from . import bsseval from . import pesq from . import stoi +from . import estoi diff --git a/speechmetrics/relative/estoi.py b/speechmetrics/relative/estoi.py new file mode 100644 index 0000000..0382a97 --- /dev/null +++ b/speechmetrics/relative/estoi.py @@ -0,0 +1,10 @@ +from .stoi import STOI + + +class ESTOI(STOI): + def __init__(self, *args, **kwargs): + super(ESTOI, self).__init__(*args, **kwargs, estoi=True) + + +def load(window, hop=None): + return ESTOI(window, hop) diff --git a/speechmetrics/relative/stoi.py b/speechmetrics/relative/stoi.py index 38a248f..fbde602 100644 --- a/speechmetrics/relative/stoi.py +++ b/speechmetrics/relative/stoi.py @@ -2,16 +2,19 @@ class STOI(Metric): - def __init__(self, window, hop=None): - super(STOI, self).__init__(name='STOI', window=window, hop=hop) + def __init__(self, window, hop=None, estoi=False): + name = 'ESTOI' if estoi else 'STOI' + super(STOI, self).__init__(name=name, window=window, hop=hop) self.mono = True + self.estoi = estoi def test_window(self, audios, rate): from pystoi.stoi import stoi if len(audios) != 2: raise ValueError('STOI needs a reference and a test signals.') - return {'stoi':stoi(audios[1], audios[0], rate, extended=False)} + key = 'estoi' if self.estoi else 'stoi' + return {key: stoi(audios[1], audios[0], rate, extended=self.estoi)} def load(window, hop=None):