Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/asr/speech_to_text_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class EvaluationConfig(transcribe_speech.TranscriptionConfig):
separate_punctuation=False,
do_lowercase=False,
rm_punctuation=False,
substitutions="",
)
)

Expand Down Expand Up @@ -154,7 +155,7 @@ def main(cfg: EvaluationConfig):

predicted_text.append(data["pred_text"])

pc = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
pc = PunctuationCapitalization(cfg.text_processing.punctuation_marks, cfg.text_processing.substitutions)
if cfg.text_processing.separate_punctuation:
ground_truth_text = pc.separate_punctuation(ground_truth_text)
predicted_text = pc.separate_punctuation(predicted_text)
Expand All @@ -164,6 +165,9 @@ def main(cfg: EvaluationConfig):
if cfg.text_processing.rm_punctuation:
ground_truth_text = pc.rm_punctuation(ground_truth_text)
predicted_text = pc.rm_punctuation(predicted_text)
if cfg.text_processing.substitutions:
ground_truth_text = pc.substitute_equivalents(ground_truth_text)
predicted_text = pc.substitute_equivalents(predicted_text)

# Test for invalid manifest supplied
if invalid_manifest:
Expand Down
41 changes: 40 additions & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,20 +637,27 @@ def compute_metrics_per_sample(


class PunctuationCapitalization:
def __init__(self, punctuation_marks: str):
def __init__(self, punctuation_marks: str, substitutions: str):
"""
Class for text processing with punctuation and capitalization. Can be used with class TextProcessingConfig.

Args:
punctuation_marks (str): String with punctuation marks to process.
substitutions (str): String of equivalencies to substitute, separated with `;`.
Example: punctuation_marks = '.,?'
substitutions = "fi~fi"
"""
if punctuation_marks:
self.regex_punctuation = re.compile(fr"([{''.join(punctuation_marks)}])")
self.regex_extra_space = re.compile(r'\s{2,}')
else:
self.regex_punctuation = None

if substitutions:
self.substitutions = self._parse_substitutions(substitutions)
else:
self.substitutions = None

def separate_punctuation(self, lines: List[str]) -> List[str]:
if self.regex_punctuation is not None:
return [
Expand All @@ -668,6 +675,35 @@ def rm_punctuation(self, lines: List[str]) -> List[str]:
else:
return lines

def substitute_equivalents(self, lines: List[str]) -> List[str]:
if self.substitutions is not None:
return [line.replace(orig, sub) for orig, sub in self.substitutions.items() for line in lines]
else:
return lines

@staticmethod
def _parse_substitutions(s: str) -> dict[str, str]:
"""
Parse substitutions from a string: "src~dst;src2~dst2;..."

Supports either literal Unicode (preferred) or escape sequences like "\\u0587".
"""
# Decode \uXXXX / \UXXXXXXXX only when present (avoid surprising behavior).
decode = lambda t: t.encode("utf-8").decode("unicode_escape") if ("\\u" in t or "\\U" in t) else t

subs = {}
for raw in (p.strip() for p in s.split(";")):
if not raw:
continue
if "~" not in raw:
raise ValueError(f"Invalid substitution '{raw}'. Expected 'SRC~DST'.")
src, dst = (x.strip() for x in raw.split("~", 1))
if not src:
raise ValueError(f"Invalid substitution '{raw}'. SRC must be non-empty.")
subs[decode(src)] = decode(dst)

return subs


@dataclass
class TextProcessingConfig:
Expand All @@ -682,3 +718,6 @@ class TextProcessingConfig:

# Whether to separate punctuation with the previouse word by space.
separate_punctuation: bool = True

# Characters (or combinations of) to treat as equivalent
substitutions: str = ""