Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions ignite/metrics/nlp/bleu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from typing import Any, Callable, Sequence, Tuple, Union
from collections.abc import Callable, Sequence
from typing import Any

import torch
from torch import Tensor
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
ngram: int = 4,
smooth: str = "no_smooth",
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
average: str = "macro",
):
if ngram <= 0:
Expand All @@ -161,7 +162,7 @@ def _n_gram_counter(
candidates: Sequence[Sequence[Any]],
p_numerators: torch.Tensor,
p_denominators: torch.Tensor,
) -> Tuple[int, int]:
) -> tuple[int, int]:
if len(references) != len(candidates):
raise ValueError(
f"nb of candidates should be equal to nb of reference lists ({len(candidates)} != "
Expand Down Expand Up @@ -247,7 +248,7 @@ def reset(self) -> None:
self.ref_length_sum = 0

@reinit__is_reduced
def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
def update(self, output: tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
y_pred, y = output

if self.average == "macro":
Expand Down Expand Up @@ -279,7 +280,7 @@ def _compute_micro(self) -> float:
)
return bleu_score

def compute(self) -> Union[None, Tensor, float]:
def compute(self) -> None | Tensor | float:
if self.average == "macro":
return self._compute_macro()
elif self.average == "micro":
Expand Down
21 changes: 11 additions & 10 deletions ignite/metrics/nlp/rouge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
from collections.abc import Callable, Mapping, Sequence
from typing import Any, NamedTuple

import torch

Expand Down Expand Up @@ -128,7 +129,7 @@ def __init__(
multiref: str = "average",
alpha: float = 0,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
) -> None:
super().__init__(output_transform=output_transform, device=device)
self._alpha = alpha
Expand All @@ -153,7 +154,7 @@ def reset(self) -> None:
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
def update(self, output: tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
candidates, references = output
for _candidate, _reference in zip(candidates, references):
multiref_scores = [self._compute_score(candidate=_candidate, reference=_ref) for _ref in _reference]
Expand Down Expand Up @@ -247,7 +248,7 @@ def __init__(
multiref: str = "average",
alpha: float = 0,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
):
super().__init__(multiref=multiref, alpha=alpha, output_transform=output_transform, device=device)
self._ngram = ngram
Expand Down Expand Up @@ -318,7 +319,7 @@ def __init__(
multiref: str = "average",
alpha: float = 0,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
):
super().__init__(multiref=multiref, alpha=alpha, output_transform=output_transform, device=device)

Expand Down Expand Up @@ -386,17 +387,17 @@ class Rouge(Metric):

def __init__(
self,
variants: Optional[Sequence[Union[str, int]]] = None,
variants: Sequence[str | int] | None = None,
multiref: str = "average",
alpha: float = 0,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
):
if variants is None or len(variants) == 0:
variants = [1, 2, 4, "L"]
self.internal_metrics: List[_BaseRouge] = []
self.internal_metrics: list[_BaseRouge] = []
for m in variants:
variant: Optional[_BaseRouge] = None
variant: _BaseRouge | None = None
if isinstance(m, str) and m == "L":
variant = RougeL(multiref=multiref, alpha=alpha, output_transform=output_transform, device=device)
elif isinstance(m, int):
Expand All @@ -414,7 +415,7 @@ def reset(self) -> None:
m.reset()

@reinit__is_reduced
def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
def update(self, output: tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
for m in self.internal_metrics:
m.update(output)

Expand Down
5 changes: 3 additions & 2 deletions ignite/metrics/nlp/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import Counter
from typing import Any, Sequence, Tuple
from collections.abc import Sequence
from typing import Any

__all__ = ["ngrams", "lcs", "modified_precision"]

Expand Down Expand Up @@ -51,7 +52,7 @@ def lcs(seq_a: Sequence[Any], seq_b: Sequence[Any]) -> int:
return dp[m][n]


def modified_precision(references: Sequence[Sequence[Any]], candidate: Any, n: int) -> Tuple[int, int]:
def modified_precision(references: Sequence[Sequence[Any]], candidate: Any, n: int) -> tuple[int, int]:
"""
Compute the modified precision

Expand Down