Skip to content

Commit

Permalink
Fixing categorization of nested labels (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus authored Jul 13, 2023
1 parent 27194f7 commit c462a42
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 47 deletions.
54 changes: 50 additions & 4 deletions tests/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@


class MyTests(unittest.TestCase):
"""
Tests
"""

def test_equation(self):
text = r"""
foo bar
Expand Down Expand Up @@ -76,6 +72,56 @@ def test_custom(self):
tex.format_labels()
self.assertEqual(str(tex).strip(), text.strip())

def test_hybrid(self):
    """
    Check that ``format_labels`` returns the input unchanged for a ``figure`` whose caption and
    sub-float labels all carry the ``fig:`` prefix, when that figure is combined with another
    construct: once nested inside an ``itemize`` item, once following a labelled section.
    """

    # Figure (with sub-float labels) nested in an ``itemize`` item.
    nested_in_item = r"""
\begin{itemize}
\item
\begin{referee}
Some question
\end{referee}
Some response
\begin{figure}[htp]
\centering
\subfloat{\label{fig:1a}}
\subfloat{\label{fig:1b}}
\includegraphics[width=\linewidth]{foo}
\caption{Foo bar}
\label{fig:1}
\end{figure}
\end{itemize}
"""

    # Figure (with sub-float labels) placed after a section that has its own ``sec:`` label.
    after_section = r"""
Foo bar
\section{My section}
%
\label{sec:my}
Foo bar
\begin{figure}[htp]
\centering
\subfloat{\label{fig:1a}}
\subfloat{\label{fig:1b}}
\includegraphics[width=\linewidth]{foo}
\caption{Foo bar}
\label{fig:1}
\end{figure}
"""

    # Both scenarios are checked in the same way: the stripped output must equal the stripped
    # input.
    for source in (nested_in_item, after_section):
        document = texplain.TeX(source)
        document.format_labels()
        self.assertEqual(str(document).strip(), source.strip())


# Run the test cases of this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
68 changes: 25 additions & 43 deletions texplain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,30 +1812,22 @@ def _format_command(

def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
"""
Classify characters to identify to which environment a label belongs.
Classify each character.
This can be used for example to figure out to which environment a label belongs.
:param text: The text to classify.
:return:
``(categories, classification)`` where ``categories`` is the list of label categories
``"eq"``, ``"fig"``, etc.) and ``classification`` is an array of the same length as ``text``
where each element is the index of the category to which the character belongs.
``(categories, classification)`` where ``categories`` is the list of categories
(``"eq"``, ``"fig"``, etc.; with ``"misc"`` for unknown) and ``classification`` is an array
of the same length as ``text`` where each element is the index in ``categories`` to which
the character belongs.
"""

categories = ["misc", "eq", "item", "note", "sec", "ch", "fig", "tab"]
classification = np.zeros(len(text), dtype=int)
starting = -1 * np.ones((len(text), len(categories)), dtype=int)
braces = find_matching(text, "{", "}", ignore_escaped=True)

# ---

r = -1

for match in re.finditer(r"(\s*\\label\s*\{)", text):
i = match.span()[0]
j = braces[match.span()[1] - 1]
classification[i:j] = r
r -= 1

# ---
# "eq"

r = categories.index("eq")

Expand All @@ -1847,7 +1839,7 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

index = find_matching(
text,
Expand All @@ -1857,7 +1849,7 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

index = find_matching(
text,
Expand All @@ -1867,9 +1859,9 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

# ---
# "fig"

r = categories.index("fig")

Expand All @@ -1881,9 +1873,9 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

# ---
# "tab"

r = categories.index("tab")

Expand All @@ -1895,9 +1887,9 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

# ---
# "item"

r = categories.index("item")

Expand All @@ -1909,7 +1901,7 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

index = find_matching(
text,
Expand All @@ -1919,44 +1911,34 @@ def _classify_for_label(text: str) -> tuple[list[str], NDArray[np.int_]]:
closing_match=1,
)
for i, j in index.items():
classification[i:j] = r
starting[i:j, r] = i

# ---
# "note"

r = categories.index("note")

for match in re.finditer(r"(\\footnote\s*\{)", text):
i = match.span()[0]
j = braces[match.span()[1] - 1]
classification[i:j] = r
starting[i:j, r] = i

# ---
# "sec"

r = categories.index("sec")

for match in re.finditer(r"(\\)(sub)*(section\s*\{)", text):
i = match.span()[0]
j = braces[match.span()[1] - 1]
classification[i:j] = r

if classification[j + 1] < 0:
classification[classification == classification[j + 1]] = r
starting[i:, r] = i

# ---
# "ch"

r = categories.index("ch")

for match in re.finditer(r"(\\)(chapter\s*\{)", text):
i = match.span()[0]
j = braces[match.span()[1] - 1]
classification[i:j] = r

if classification[j + 1] < 0:
classification[classification == classification[j + 1]] = r

# ---
starting[i:, r] = i

return categories, np.where(classification < 0, categories.index("misc"), classification)
return categories, np.argmax(starting, axis=1)


class TeX:
Expand Down

0 comments on commit c462a42

Please sign in to comment.