From c198b377232fc0eb8dffd016d1bd47afeb395ec1 Mon Sep 17 00:00:00 2001 From: Philip May Date: Fri, 23 Feb 2024 12:41:40 +0100 Subject: [PATCH] add extract_token_set (#152) --- mltb2/somajo.py | 19 +++++++++++++++++-- tests/test_somajo.py | 9 +++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mltb2/somajo.py b/mltb2/somajo.py index 4fd7f88..7d17677 100644 --- a/mltb2/somajo.py +++ b/mltb2/somajo.py @@ -12,7 +12,7 @@ from abc import ABC from dataclasses import dataclass, field -from typing import Container, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Literal, Optional, Set, Tuple, Union from somajo import SoMaJo from tqdm import tqdm @@ -62,7 +62,7 @@ def detokenize(tokens) -> str: return result -def extract_token_class_set(sentences: Iterable, keep_token_classes: Optional[Container[str]] = None) -> Set[str]: +def extract_token_class_set(sentences: Iterable, keep_token_classes: Optional[str] = None) -> Set[str]: """Extract token from sentences by token class. Args: @@ -187,6 +187,21 @@ def extract_url_set(self, text: Union[Iterable, str]) -> Set[str]: result = extract_token_class_set(sentences, keep_token_classes="URL") return result + def extract_token_set(self, text: Union[Iterable, str], keep_token_classes: Optional[str] = None) -> Set[str]: + """Extract tokens from text. + + Args: + text: the text + keep_token_classes: The token classes to keep. If ``None`` all will be kept. + Returns: + Set of tokens. + """ + if isinstance(text, str): + text = [text] + sentences = self.somajo.tokenize_text(text) + result = extract_token_class_set(sentences, keep_token_classes=keep_token_classes) + return result + @dataclass class UrlSwapper: diff --git a/tests/test_somajo.py b/tests/test_somajo.py index 0853e10..fe7ecc8 100644 --- a/tests/test_somajo.py +++ b/tests/test_somajo.py @@ -69,6 +69,15 @@ def test_JaccardSimilarity_call_no_overlap(): # noqa: N802 assert isclose(result, 0.0) +def test_TokenExtractor_extract_token_set(): # noqa: N802 + text = "Das ist ein Text. Er enthält keine URL." + token_extractor = TokenExtractor("de_CMC") + result = token_extractor.extract_token_set(text) + assert len(result) == 9 + assert "Das" in result + assert "." in result + + def test_TokenExtractor_extract_url_set_with_str(): # noqa: N802 url1 = "http://may.la" url2 = "github.com"