Skip to content

Commit

Permalink
add extract_token_set (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay authored Feb 23, 2024
1 parent 81dd4ae commit c198b37
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
19 changes: 17 additions & 2 deletions mltb2/somajo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from abc import ABC
from dataclasses import dataclass, field
from typing import Container, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
from typing import Dict, Iterable, List, Literal, Optional, Set, Tuple, Union

from somajo import SoMaJo
from tqdm import tqdm
Expand Down Expand Up @@ -62,7 +62,7 @@ def detokenize(tokens) -> str:
return result


def extract_token_class_set(sentences: Iterable, keep_token_classes: Optional[Container[str]] = None) -> Set[str]:
def extract_token_class_set(sentences: Iterable, keep_token_classes: Optional[str] = None) -> Set[str]:
"""Extract token from sentences by token class.
Args:
Expand Down Expand Up @@ -187,6 +187,21 @@ def extract_url_set(self, text: Union[Iterable, str]) -> Set[str]:
result = extract_token_class_set(sentences, keep_token_classes="URL")
return result

def extract_token_set(self, text: Union[Iterable, str], keep_token_classes: Optional[str] = None) -> Set[str]:
    """Tokenize ``text`` and collect the resulting tokens into a set.

    Args:
        text: The text to process. A single string is wrapped in a
            one-element list before it is handed to the tokenizer; any
            other iterable is passed through unchanged.
        keep_token_classes: The token classes to keep. If ``None`` all
            will be kept. Forwarded as-is to ``extract_token_class_set``.

    Returns:
        Set of tokens.
    """
    # The tokenizer is always given an iterable, so a bare string is wrapped.
    texts = [text] if isinstance(text, str) else text
    tokenized_sentences = self.somajo.tokenize_text(texts)
    return extract_token_class_set(tokenized_sentences, keep_token_classes=keep_token_classes)


@dataclass
class UrlSwapper:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_somajo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def test_JaccardSimilarity_call_no_overlap(): # noqa: N802
assert isclose(result, 0.0)


def test_TokenExtractor_extract_token_set():  # noqa: N802
    """Check ``extract_token_set`` on a short German text without a class filter."""
    extractor = TokenExtractor("de_CMC")
    tokens = extractor.extract_token_set("Das ist ein Text. Er enthält keine URL.")
    # Both sentences end with ".", which counts once in the set: nine distinct tokens.
    assert len(tokens) == 9
    assert "Das" in tokens
    assert "." in tokens


def test_TokenExtractor_extract_url_set_with_str(): # noqa: N802
url1 = "http://may.la"
url2 = "github.com"
Expand Down

0 comments on commit c198b37

Please sign in to comment.