diff --git a/src/extractors/pdf_to_multi_option_extractor/PdfToMultiOptionExtractor.py b/src/extractors/pdf_to_multi_option_extractor/PdfToMultiOptionExtractor.py index 9943be9..636dbdd 100644 --- a/src/extractors/pdf_to_multi_option_extractor/PdfToMultiOptionExtractor.py +++ b/src/extractors/pdf_to_multi_option_extractor/PdfToMultiOptionExtractor.py @@ -58,6 +58,9 @@ from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.SentenceSelectorFuzzyCommas import ( SentenceSelectorFuzzyCommas, ) +from extractors.segment_selector.FastAndPositionsSegmentSelector import FastAndPositionsSegmentSelector +from extractors.segment_selector.FastSegmentSelector import FastSegmentSelector +from extractors.segment_selector.SegmentSelector import SegmentSelector from send_logs import send_logs RETRAIN_SAMPLES_THRESHOLD = 250 @@ -101,6 +104,10 @@ def create_model(self, extraction_data: ExtractionData): self.options = extraction_data.options self.multi_value = extraction_data.multi_value + SegmentSelector(self.extraction_identifier).prepare_model_folder() + FastSegmentSelector(self.extraction_identifier).prepare_model_folder() + FastAndPositionsSegmentSelector(self.extraction_identifier).prepare_model_folder() + send_logs(self.extraction_identifier, self.get_stats(extraction_data)) performance_train_set, performance_test_set = ExtractorBase.get_train_test_sets(extraction_data) diff --git a/src/extractors/pdf_to_multi_option_extractor/multi_option_extraction_methods/FastSegmentSelectorFuzzy95.py b/src/extractors/pdf_to_multi_option_extractor/multi_option_extraction_methods/FastSegmentSelectorFuzzy95.py index 2108ecf..e921682 100644 --- a/src/extractors/pdf_to_multi_option_extractor/multi_option_extraction_methods/FastSegmentSelectorFuzzy95.py +++ b/src/extractors/pdf_to_multi_option_extractor/multi_option_extraction_methods/FastSegmentSelectorFuzzy95.py @@ -40,7 +40,7 @@ def train(self, multi_option_data: ExtractionData): for sample in multi_option_data.samples: marked_segments.extend(self.get_marked_segments(sample)) - FastSegmentSelector(self.extraction_identifier).create_model(marked_segments) + FastSegmentSelector(self.extraction_identifier, self.get_name()).create_model(marked_segments) def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]: self.set_parameters(multi_option_data) @@ -48,7 +48,7 @@ def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]: return FuzzyAll95().predict(self.extraction_data) def get_prediction_data(self, extraction_data: ExtractionData) -> ExtractionData: - fast_segment_selector = FastSegmentSelector(self.extraction_identifier) + fast_segment_selector = FastSegmentSelector(self.extraction_identifier, self.get_name()) predict_samples = list() for sample in extraction_data.samples: selected_segments = fast_segment_selector.predict(self.fix_two_pages_segments(sample)) diff --git a/src/extractors/segment_selector/FastSegmentSelector.py b/src/extractors/segment_selector/FastSegmentSelector.py index 72a5bf1..6b312fc 100644 --- a/src/extractors/segment_selector/FastSegmentSelector.py +++ b/src/extractors/segment_selector/FastSegmentSelector.py @@ -14,12 +14,18 @@ class FastSegmentSelector: - def __init__(self, extraction_identifier: ExtractionIdentifier): + def __init__(self, extraction_identifier: ExtractionIdentifier, method_name: str = ""): self.extraction_identifier = extraction_identifier self.text_types = [TokenType.TEXT, TokenType.LIST_ITEM, TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.CAPTION] self.previous_words, self.next_words, self.text_segments = [], [], [] + self.method_name = method_name - self.fast_segment_selector_path = Path(self.extraction_identifier.get_path(), self.__class__.__name__) + if method_name: + self.fast_segment_selector_path = Path( + self.extraction_identifier.get_path(), method_name, self.__class__.__name__ + ) + else: + self.fast_segment_selector_path = Path(self.extraction_identifier.get_path(), self.__class__.__name__) if not self.fast_segment_selector_path.exists(): os.makedirs(self.fast_segment_selector_path, exist_ok=True) @@ -95,7 +101,10 @@ def save_predictive_common_words(self, segments): Path(self.next_words_path).write_text(json.dumps(self.next_words)) def create_model(self, segments: list[PdfDataSegment]): - if not segments or Path(self.model_path).exists(): + if not segments: + return + + if not self.method_name and Path(self.model_path).exists(): return self.text_segments = [x for x in segments if x.segment_type in self.text_types] diff --git a/src/start_queue_processor.py b/src/start_queue_processor.py index 334f450..6980522 100644 --- a/src/start_queue_processor.py +++ b/src/start_queue_processor.py @@ -84,9 +84,7 @@ def task_to_string(extraction_task: ExtractionTask): except Exception: pass - config_logger.info("Is GPU used?") - config_logger.info(torch.cuda.is_available()) - + config_logger.info(f"Waiting for messages. Is GPU used? {torch.cuda.is_available()}") queues_names = QUEUES_NAMES.split(" ") queue_processor = QueueProcessor(REDIS_HOST, REDIS_PORT, queues_names, config_logger) - queue_processor.start(process, run_once=True) + queue_processor.start(process)