Skip to content

Commit

Permalink
added a case to leave 、unmodified if phrase starts and ends with it
Browse files Browse the repository at this point in the history
  • Loading branch information
kanjieater committed Aug 11, 2024
1 parent 9a19458 commit 1152560
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
12 changes: 9 additions & 3 deletions subplz/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ def has_ending_punctuation(s: str) -> bool:
def has_punctuation(s: str) -> bool:
return bool(find_punctuation_index(s))

def has_double_comma(str_starts: str, str_ends: str) -> bool:
comma = """、"""
return str_starts[-1] == comma and str_ends[-1] == comma



def count_non_punctuation(s: str) -> int:
punctuation = """「"'“¿([{-.。,,!!??::”)]}、)」"""
Expand Down Expand Up @@ -391,6 +396,7 @@ def find_index_with_non_punctuation_end(indices: List[int]) -> List[int]:

return result


def print_modified_segments(segments, new_segments, final_segments,modified_new_segment_debug_log, modified_final_segment_debug_log):
print("Modified Start segments:")
for index in modified_new_segment_debug_log:
Expand Down Expand Up @@ -431,7 +437,7 @@ def shift_align(segments: List[Segment]) -> List[Segment]:
continue
if non_punc_count <= 2:
# If the first segment has 2 or fewer non-punctuation characters
if i > 0 and len(new_segments) > 0 and not has_ending_punctuation(new_segments[-1].text[-1]):
if i > 0 and len(new_segments) > 0 and not has_ending_punctuation(new_segments[-1].text[-1]) and not has_double_comma(new_segments[-1].text, text[:start_index + 1]):
# Move part of the text to the previous segment
prev_segment = new_segments.pop()
prev_segment.text += text[:start_index + 1] # Include the punctuation
Expand All @@ -457,7 +463,7 @@ def shift_align(segments: List[Segment]) -> List[Segment]:
continue
if count_non_punctuation(text[indices[-1]:]) <= 2:
# Move part of the text to the next segment
if i+1 < len(new_segments) and not has_ending_punctuation(new_segments[i+1].text[0]):
if i+1 < len(new_segments) and not has_ending_punctuation(new_segments[i+1].text[0]) and not has_double_comma(new_segments[i+1].text[0], text[:start_index + 1]):
next_segment = segments[i + 1]
next_segment.text = text[last_index+1:] + next_segment.text # Keep the punctuation
text = text[:last_index+1] # Exclude the punctuation
Expand All @@ -468,5 +474,5 @@ def shift_align(segments: List[Segment]) -> List[Segment]:
continue
final_segments.append(Segment(text, segment.start, segment.end))

print_modified_segments(segments, new_segments, final_segments, modified_new_segment_debug_log, modified_final_segment_debug_log)
# print_modified_segments(segments, new_segments, final_segments, modified_new_segment_debug_log, modified_final_segment_debug_log)
return final_segments
17 changes: 17 additions & 0 deletions subplz/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ def test_shift_align_wa_pattern():
assert result[1].text == '知っている。'


def test_shift_align_maintain_double_comma():
# Given segments that require complex alignment
segments = [
Segment(text='一群の騎馬が、', start=375.00, end=376.00),
Segment(text='東へ、ペシャワール城塞の方角へと疾駆している。', start=376.00, end=378.00),
Segment(text='一群の騎馬が、東へ、', start=375.00, end=376.00),
Segment(text='ペシャワール城塞の方角へと疾駆している。', start=376.00, end=378.00)
]
result = shift_align(segments)

assert len(result) == len(segments)
assert result[0] == segments[0]
assert result[1] == segments[1]
assert result[2] == segments[2]
assert result[3] == segments[3]


def test_shift_align_no_change():
# These lines shouldn't change
segments = [
Expand Down

0 comments on commit 1152560

Please sign in to comment.