Skip to content

Commit

Permalink
Merge pull request #95 from huridocs/fix-segment-selector
Browse files Browse the repository at this point in the history
Fix segment selector
  • Loading branch information
gabriel-piles authored Oct 12, 2024
2 parents 173ae5a + 98c5bed commit 613f1dc
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.SentenceSelectorFuzzyCommas import (
SentenceSelectorFuzzyCommas,
)
from extractors.segment_selector.FastAndPositionsSegmentSelector import FastAndPositionsSegmentSelector
from extractors.segment_selector.FastSegmentSelector import FastSegmentSelector
from extractors.segment_selector.SegmentSelector import SegmentSelector
from send_logs import send_logs

RETRAIN_SAMPLES_THRESHOLD = 250
Expand Down Expand Up @@ -101,6 +104,10 @@ def create_model(self, extraction_data: ExtractionData):
self.options = extraction_data.options
self.multi_value = extraction_data.multi_value

SegmentSelector(self.extraction_identifier).prepare_model_folder()
FastSegmentSelector(self.extraction_identifier).prepare_model_folder()
FastAndPositionsSegmentSelector(self.extraction_identifier).prepare_model_folder()

send_logs(self.extraction_identifier, self.get_stats(extraction_data))

performance_train_set, performance_test_set = ExtractorBase.get_train_test_sets(extraction_data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def train(self, multi_option_data: ExtractionData):
for sample in multi_option_data.samples:
marked_segments.extend(self.get_marked_segments(sample))

FastSegmentSelector(self.extraction_identifier).create_model(marked_segments)
FastSegmentSelector(self.extraction_identifier, self.get_name()).create_model(marked_segments)

def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]:
self.set_parameters(multi_option_data)
self.extraction_data = self.get_prediction_data(multi_option_data)
return FuzzyAll95().predict(self.extraction_data)

def get_prediction_data(self, extraction_data: ExtractionData) -> ExtractionData:
fast_segment_selector = FastSegmentSelector(self.extraction_identifier)
fast_segment_selector = FastSegmentSelector(self.extraction_identifier, self.get_name())
predict_samples = list()
for sample in extraction_data.samples:
selected_segments = fast_segment_selector.predict(self.fix_two_pages_segments(sample))
Expand Down
15 changes: 12 additions & 3 deletions src/extractors/segment_selector/FastSegmentSelector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@


class FastSegmentSelector:
def __init__(self, extraction_identifier: ExtractionIdentifier):
def __init__(self, extraction_identifier: ExtractionIdentifier, method_name: str = ""):
self.extraction_identifier = extraction_identifier
self.text_types = [TokenType.TEXT, TokenType.LIST_ITEM, TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.CAPTION]
self.previous_words, self.next_words, self.text_segments = [], [], []
self.method_name = method_name

self.fast_segment_selector_path = Path(self.extraction_identifier.get_path(), self.__class__.__name__)
if method_name:
self.fast_segment_selector_path = Path(
self.extraction_identifier.get_path(), method_name, self.__class__.__name__
)
else:
self.fast_segment_selector_path = Path(self.extraction_identifier.get_path(), self.__class__.__name__)
if not self.fast_segment_selector_path.exists():
os.makedirs(self.fast_segment_selector_path, exist_ok=True)

Expand Down Expand Up @@ -95,7 +101,10 @@ def save_predictive_common_words(self, segments):
Path(self.next_words_path).write_text(json.dumps(self.next_words))

def create_model(self, segments: list[PdfDataSegment]):
if not segments or Path(self.model_path).exists():
if not segments:
return

if not self.method_name and Path(self.model_path).exists():
return

self.text_segments = [x for x in segments if x.segment_type in self.text_types]
Expand Down
6 changes: 2 additions & 4 deletions src/start_queue_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def task_to_string(extraction_task: ExtractionTask):
except Exception:
pass

config_logger.info("Is GPU used?")
config_logger.info(torch.cuda.is_available())

config_logger.info(f"Waiting for messages. Is GPU used? {torch.cuda.is_available()}")
queues_names = QUEUES_NAMES.split(" ")
queue_processor = QueueProcessor(REDIS_HOST, REDIS_PORT, queues_names, config_logger)
queue_processor.start(process, run_once=True)
queue_processor.start(process)

0 comments on commit 613f1dc

Please sign in to comment.