Skip to content

Commit

Permalink
Merge pull request #99 from huridocs/options-endpoint
Browse files Browse the repository at this point in the history
Options endpoint
  • Loading branch information
gabriel-piles authored Oct 22, 2024
2 parents 3525ade + 54c616a commit 6609801
Showing 10 changed files with 94 additions and 37 deletions.
16 changes: 16 additions & 0 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,8 @@
from config import config_logger, MONGO_HOST, MONGO_PORT
from data.ExtractionIdentifier import ExtractionIdentifier
from data.LabeledData import LabeledData
from data.Option import Option
from data.Options import Options
from data.PredictionData import PredictionData
from data.Suggestion import Suggestion
from XmlFile import XmlFile
@@ -126,3 +128,17 @@ async def get_suggestions(tenant: str, extraction_id: str):
except Exception:
config_logger.error("Error", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")


@app.post("/options")
def save_options(options: Options):
try:
extraction_identifier = ExtractionIdentifier(run_name=options.tenant, extraction_name=options.extraction_id)
options_list = [option.model_dump() for option in options.options]
extraction_identifier.get_options_path().write_text(json.dumps(options_list))
os.utime(extraction_identifier.get_options_path().parent)
config_logger.info(f"Options {options.options[:150]} saved for {extraction_identifier}")
return True
except Exception:
config_logger.error("Error", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")
6 changes: 6 additions & 0 deletions src/data/ExtractionIdentifier.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,12 @@ class ExtractionIdentifier(BaseModel):
def get_path(self):
return join(DATA_PATH, self.run_name, self.extraction_name)

def get_options_path(self):
path = Path(join(DATA_PATH, self.run_name, f"{self.extraction_name}_options.json"))
if not exists(path.parent):
os.makedirs(path.parent, exist_ok=True)
return path

def get_extractor_used_path(self) -> Path:
path = Path(join(DATA_PATH, self.run_name, f"{self.extraction_name}.txt"))
if not exists(path.parent):
9 changes: 9 additions & 0 deletions src/data/Options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel

from data.Option import Option


class Options(BaseModel):
tenant: str
extraction_id: str
options: list[Option]
1 change: 0 additions & 1 deletion src/extractors/ExtractorBase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import random
from abc import abstractmethod
from os import makedirs
from os.path import exists
Original file line number Diff line number Diff line change
@@ -54,7 +54,11 @@ def get_performance(self, train_set: ExtractionData, test_set: ExtractionData) -
predictions = [x[:1] for x in predictions]

predictions_one_hot = self.one_hot_to_options_list(predictions, self.options)
score = f1_score(truth_one_hot, predictions_one_hot, average="micro")

try:
score = f1_score(truth_one_hot, predictions_one_hot, average="micro")
except ValueError:
score = 0

return 100 * score

Original file line number Diff line number Diff line change
@@ -93,16 +93,16 @@ class PdfToMultiOptionExtractor(ExtractorBase):
def __init__(self, extraction_identifier: ExtractionIdentifier):
super().__init__(extraction_identifier)
self.base_path = join(self.extraction_identifier.get_path(), "multi_option_extractor")
self.options_path = join(self.base_path, "options.json")
self.multi_value_path = join(self.base_path, "multi_value.json")
self.method_name_path = Path(join(self.base_path, "method_name.json"))

self.options: list[Option] = list()
self.multi_value = False

def create_model(self, extraction_data: ExtractionData):
self.options = extraction_data.options
self.options = self.load_options(extraction_data.options)
self.multi_value = extraction_data.multi_value
send_logs(self.extraction_identifier, f"options {[x.model_dump() for x in self.options]}")

SegmentSelector(self.extraction_identifier).prepare_model_folder()
FastSegmentSelector(self.extraction_identifier).prepare_model_folder()
@@ -123,7 +123,6 @@ def create_model(self, extraction_data: ExtractionData):
if len(extraction_data.samples) < RETRAIN_SAMPLES_THRESHOLD:
method.train(extraction_data)

self.save_json(self.options_path, [x.model_dump() for x in extraction_data.options])
self.save_json(self.multi_value_path, extraction_data.multi_value)
self.save_json(str(self.method_name_path), method.get_name())

@@ -146,7 +145,8 @@ def get_suggestions(self, predictions_samples: list[PredictionSample]) -> list[S
return suggestions

def get_predictions(self, predictions_samples: list[PredictionSample]) -> (list[TrainingSample], list[list[Option]]):
self.load_options()
self.options = self.load_options()
self.multi_value = self.load_multi_value()
training_samples = [TrainingSample(pdf_data=sample.pdf_data) for sample in predictions_samples]
extraction_data = ExtractionData(
multi_value=self.multi_value,
@@ -165,15 +165,22 @@ def get_predictions(self, predictions_samples: list[PredictionSample]) -> (list[

return method.get_samples_for_context(extraction_data), prediction

def load_options(self):
if not exists(self.options_path) or not exists(self.multi_value_path):
return
def load_options(self, options: list[Option] = None) -> list[Option]:
if options:
self.extraction_identifier.get_options_path().write_text(json.dumps([x.model_dump() for x in options]))
return options

with open(self.options_path, "r") as file:
self.options = [Option(**x) for x in json.load(file)]
if not exists(self.extraction_identifier.get_options_path()):
return []

return [Option(**x) for x in json.loads(self.extraction_identifier.get_options_path().read_text())]

def load_multi_value(self) -> bool:
if not exists(self.multi_value_path):
return False

with open(self.multi_value_path, "r") as file:
self.multi_value = json.load(file)
return json.load(file)

def get_best_method(self, multi_option_data: ExtractionData) -> PdfMultiOptionMethod:
best_method_instance = self.METHODS[0]
@@ -228,7 +235,7 @@ def get_predictions_method(self):
return self.METHODS[0]

def can_be_used(self, extraction_data: ExtractionData) -> bool:
if not extraction_data.options:
if not extraction_data.options and not extraction_data.extraction_identifier.get_options_path().exists():
return False

for sample in extraction_data.samples:
Original file line number Diff line number Diff line change
@@ -122,8 +122,9 @@ def get_aliases_path(self) -> Path:
def get_aliases(self, sample: TrainingSample) -> dict[str, str]:
segments = [segment for segment in sample.pdf_data.pdf_data_segments if segment.ml_label]
appearances, not_found_texts = self.get_appearances_for_segments(segments, dict())
truth_options = self.clean_texts([option.label for option in sample.labeled_data.values], False)

values_ids = [option.id for option in sample.labeled_data.values]
values_labels = [option.label for option in self.options if option.id in values_ids]
truth_options = self.clean_texts(values_labels, False)
not_found_options = [option for option in truth_options if option not in appearances]
return self.find_aliases(not_found_options, not_found_texts)

Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
import shutil
from os.path import join, exists
from pathlib import Path
from typing import Type
@@ -54,9 +53,8 @@ def __init__(self, extraction_identifier):
super().__init__(extraction_identifier)

self.base_path = join(self.extraction_identifier.get_path(), "text_to_multi_option")
self.options_path = join(self.base_path, "options.json")
self.multi_value_path = join(self.base_path, "multi_value.json")
self.method_name_path = Path(join(self.base_path, "method_name.json"))
self.multi_value_path = Path(self.base_path, "multi_value.json")
self.method_name_path = Path(self.base_path, "method_name.json")

self.options: list[Option] = list()
self.multi_value = False
@@ -80,7 +78,8 @@ def get_suggestions(self, predictions_samples: list[PredictionSample]) -> list[S
return suggestions

def get_predictions_method(self):
self.load_options()
self.options = self.load_options()
self.multi_value = self.load_multi_value()
method_name = json.loads(self.method_name_path.read_text())
for method in self.METHODS:
method_instance = method(self.extraction_identifier, self.options, self.multi_value)
@@ -89,25 +88,30 @@ def get_predictions_method(self):

return self.METHODS[0](self.extraction_identifier, self.options, self.multi_value)

def load_options(self):
if not exists(self.options_path) or not exists(self.multi_value_path):
return
def load_options(self, options: list[Option] = None) -> list[Option]:
if options:
self.extraction_identifier.get_options_path().write_text(json.dumps([x.model_dump() for x in options]))
return options

with open(self.options_path, "r") as file:
self.options = [Option(**x) for x in json.load(file)]
if not exists(self.extraction_identifier.get_options_path()):
return []

return [Option(**x) for x in json.loads(self.extraction_identifier.get_options_path().read_text())]

def load_multi_value(self) -> bool:
if not self.multi_value_path.exists():
return False

with open(self.multi_value_path, "r") as file:
self.multi_value = json.load(file)
return json.loads(self.multi_value_path.read_text())

def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]:
self.options = extraction_data.options
self.options = self.load_options(extraction_data.options)
self.multi_value = extraction_data.multi_value

best_method_instance = self.get_best_method(extraction_data)
best_method_instance.train(extraction_data)

self.save_json(self.options_path, [x.model_dump() for x in extraction_data.options])
self.save_json(self.multi_value_path, extraction_data.multi_value)
self.save_json(str(self.multi_value_path), extraction_data.multi_value)
self.save_json(str(self.method_name_path), best_method_instance.get_name())
return True, ""

@@ -150,7 +154,7 @@ def remove_models(self):
method_instance.remove_model()

def can_be_used(self, extraction_data: ExtractionData) -> bool:
if not extraction_data.options:
if not extraction_data.options and not extraction_data.extraction_identifier.get_options_path().exists():
return False

for sample in extraction_data.samples:
Original file line number Diff line number Diff line change
@@ -13,16 +13,22 @@ class TestTextToMultiOptionExtraction(TestCase):
extraction_id = "extraction_id"

def test_is_valid(self):
extraction_identifier = ExtractionIdentifier(run_name=self.TENANT, extraction_name=self.extraction_id)
extraction_identifier = ExtractionIdentifier(run_name=self.TENANT, extraction_name="other")
options = [Option(id="1", label="1"), Option(id="2", label="2"), Option(id="3", label="3")]

samples_text = [TrainingSample(labeled_data=LabeledData(source_text="1"))]
samples_no_text = [TrainingSample(labeled_data=LabeledData(source_text=""))]

multi_option_extraction = TextToMultiOptionExtractor(extraction_identifier)
self.assertFalse(multi_option_extraction.can_be_used(ExtractionData(samples=samples_text)))
self.assertFalse(multi_option_extraction.can_be_used(ExtractionData(options=options, samples=samples_no_text)))
self.assertTrue(multi_option_extraction.can_be_used(ExtractionData(options=options, samples=samples_text)))
no_options = ExtractionData(extraction_identifier=extraction_identifier, samples=samples_text)
no_text = ExtractionData(extraction_identifier=extraction_identifier, options=options, samples=samples_no_text)
valid_extraction_data = ExtractionData(
extraction_identifier=extraction_identifier, options=options, samples=samples_text
)

self.assertFalse(multi_option_extraction.can_be_used(no_options))
self.assertFalse(multi_option_extraction.can_be_used(no_text))
self.assertTrue(multi_option_extraction.can_be_used(valid_extraction_data))

def test_single_value(self):
extraction_identifier = ExtractionIdentifier(run_name=self.TENANT, extraction_name=self.extraction_id)
9 changes: 7 additions & 2 deletions src/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -183,7 +183,11 @@ def test_pdf_to_multi_option(self):
files = {"file": stream}
requests.post(f"{SERVER_URL}/xml_to_train/{tenant}/{extraction_id}", files=files)

options = [Option(id="1", label="United Nations"), Option(id="2", label="Other")]
options = {
"tenant": tenant,
"extraction_id": extraction_id,
"options": [Option(id="1", label="United Nations").model_dump(), Option(id="2", label="Other").model_dump()],
}

labeled_data_json = {
"id": extraction_id,
@@ -197,6 +201,7 @@ def test_pdf_to_multi_option(self):
}

requests.post(f"{SERVER_URL}/labeled_data", json=labeled_data_json)
requests.post(f"{SERVER_URL}/options", json=options)

with open(test_xml_path, mode="rb") as stream:
files = {"file": stream}
@@ -216,7 +221,7 @@ def test_pdf_to_multi_option(self):
task = ExtractionTask(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id, options=options, multi_value=False, metadata={"name": "test"}),
params=Params(id=extraction_id, multi_value=False, metadata={"name": "test"}),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()

0 comments on commit 6609801

Please sign in to comment.