Skip to content

Commit

Permalink
CU-86956a6y5: Add optional doc limit to comparison tool
Browse files Browse the repository at this point in the history
  • Loading branch information
mart-r committed Aug 12, 2024
1 parent 9fc0c9b commit b34dc3a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
13 changes: 9 additions & 4 deletions medcat/compare_models/comp_nbhelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ipywidgets import widgets
from IPython.display import display
import os
from typing import List
from typing import List, Optional


from compare import get_diffs_for
Expand All @@ -16,11 +16,12 @@
class NBComparer:

def __init__(self, model_path_1: str, model_path_2: str,
documents_file: str, is_mct_export_compare: bool,
documents_file: str, doc_limit: int, is_mct_export_compare: bool,
cui_filter: str, filter_children: bool) -> None:
self.model_path_1 = model_path_1
self.model_path_2 = model_path_2
self.documents_file = documents_file
self.doc_limit = doc_limit
self.is_mct_export_compare = is_mct_export_compare
self.cui_filter = cui_filter
self.filter_children = filter_children
Expand All @@ -30,7 +31,7 @@ def _run_comparison(self):
(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs) = get_diffs_for(
self.model_path_1, self.model_path_2, self.documents_file,
cui_filter=self.cui_filter, include_children_in_filter=self.filter_children,
supervised_train_comparison_model=self.is_mct_export_compare)
supervised_train_comparison_model=self.is_mct_export_compare, doc_limit=self.doc_limit)

def show_all(self):
parse_and_show(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs)
Expand Down Expand Up @@ -71,6 +72,7 @@ class NBInputter:
mc1_title = "Choose model 1"
mc2_title = "Choose model 2 (or an MCT export)"
docs_title = "Choose the documents file (.csv with 'text' field)"
docs_limit_title = "Limit the number of documents to run (-1 to disable)"
mct_export_title = "Is the 2nd path an MCT export (instead of a model)?"
cui_filter_title_overall = "CUI Filter"
cui_filter_title_file_chooser = "Choose file with comma-separated CUIs"
Expand All @@ -81,6 +83,7 @@ def __init__(self) -> None:
self.model1_chooser = FileChooser(_def_path)
self.model2_chooser = FileChooser(_def_path)
self.documents_chooser = FileChooser(".")
self.doc_limit = widgets.IntText(-1)
self.ckbox = widgets.Checkbox(description="MCT export compare")

self.cui_filter_chooser = FileChooser(".", description="The CUI filter file")
Expand All @@ -93,6 +96,7 @@ def show_all(self):
widgets.VBox([widgets.Label(self.mc1_title), self.model1_chooser]),
widgets.VBox([widgets.Label(self.mc2_title), self.model2_chooser]),
widgets.VBox([widgets.Label(self.docs_title), self.documents_chooser]),
widgets.VBox([widgets.Label(self.docs_limit_title), self.doc_limit]),
widgets.VBox([widgets.Label(self.mct_export_title), self.ckbox])
])

Expand All @@ -115,6 +119,7 @@ def _get_params(self):
model_path_1 = self.model1_chooser.selected
model_path_2 = self.model2_chooser.selected
documents_file = self.documents_chooser.selected
doc_limit = self.doc_limit.value
is_mct_export_compare = self.ckbox.value
if not is_mct_export_compare:
print(f"For models, selected:\nModel1: {model_path_1}\nModel2: {model_path_2}"
Expand All @@ -132,7 +137,7 @@ def _get_params(self):
if self.cui_children.value and self.cui_children.value > 0:
filter_children = self.cui_children.value
print(f"For CUI filter, selected:\nFilter: {cui_filter}\nChildren: {filter_children}")
return (model_path_1, model_path_2, documents_file, is_mct_export_compare, cui_filter, filter_children)
return (model_path_1, model_path_2, documents_file, doc_limit, is_mct_export_compare, cui_filter, filter_children)

def get_comparison(self) -> NBComparer:
    """Build an ``NBComparer`` from the values currently selected in the widgets.

    The tuple returned by ``_get_params`` is unpacked positionally, so its
    order must match the parameter order of ``NBComparer.__init__``.
    """
    selected_params = self._get_params()
    comparer = NBComparer(*selected_params)
    return comparer
27 changes: 18 additions & 9 deletions medcat/compare_models/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import tqdm
import tempfile
import os
from itertools import islice

from compare_cdb import compare as compare_cdbs, CDBCompareResults
from compare_annotations import ResultsTally, PerAnnotationDifferences
Expand All @@ -17,18 +17,22 @@



def load_documents(file_name: str, doc_limit: int = -1) -> Iterator[Tuple[str, str]]:
    """Yield ``(id, text)`` rows from a two-column CSV file.

    The file is read with the explicit column names ``id`` and ``text``.
    If the first row is a literal ``id,text`` header, it is dropped.

    Args:
        file_name: Path to the CSV file to read.
        doc_limit: Maximum number of rows to yield. Any negative value
            (conventionally ``-1``) or ``None`` disables the limit;
            ``0`` yields nothing.

    Yields:
        One named tuple per row with the fields ``id`` and ``text``.
    """
    with open(file_name) as f:
        df = pd.read_csv(f, names=["id", "text"])
    # Guard against an empty frame before peeking at the first row.
    if not df.empty and df.iloc[0].id == "id" and df.iloc[0].text == "text":
        # Explicit `names` means a header line is read as data: remove it.
        # The index then starts at 1, which is harmless because rows are
        # yielded without the index.
        df = df.iloc[1:, :]
    # islice raises ValueError for a negative stop, so map every
    # "no limit" value onto None (which islice treats as "until exhausted").
    stop = None if doc_limit is None or doc_limit < 0 else doc_limit
    yield from islice(df.itertuples(index=False), stop)


def do_counting(cat1: CAT, cat2: CAT,
ann_diffs: PerAnnotationDifferences) -> ResultsTally:
ann_diffs: PerAnnotationDifferences,
doc_limit: int = -1) -> ResultsTally:
def cui2name(cat, cui):
if cui in cat.cdb.cui2preferred_name:
return cat.cdb.cui2preferred_name[cui]
Expand All @@ -39,7 +43,8 @@ def cui2name(cat, cui):
cui2name=partial(cui2name, cat1))
res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.make_stats(),
cui2name=partial(cui2name, cat2))
for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values()):
total = doc_limit if doc_limit != -1 else None
for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values(), total=total):
res1.count(per_doc.raw1)
res2.count(per_doc.raw2)
return res1, res2
Expand All @@ -52,6 +57,7 @@ def _get_pt2ch(cat: CAT) -> Optional[Dict]:
def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str, str]],
show_progress: bool = True,
keep_raw: bool = True,
doc_limit: int = -1
) -> PerAnnotationDifferences:
pt2ch1: Optional[Dict] = _get_pt2ch(cat1)
pt2ch2: Optional[Dict] = _get_pt2ch(cat2)
Expand All @@ -63,7 +69,8 @@ def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str
model2_cuis=set(cat2.cdb.cui2names),
keep_raw=keep_raw,
save_options=save_opts)
for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress):
total = doc_limit if doc_limit != -1 else None
for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress, total=total):
pad.look_at_doc(cat1.get_entities(doc), cat2.get_entities(doc), doc_id, doc)
pad.finalise()
return pad
Expand Down Expand Up @@ -107,9 +114,10 @@ def get_diffs_for(model_pack_path_1: str,
include_children_in_filter: Optional[int] = None,
supervised_train_comparison_model: bool = False,
keep_raw: bool = True,
doc_limit: int = -1,
) -> Tuple[CDBCompareResults, ResultsTally, ResultsTally, PerAnnotationDifferences]:
validate_input(model_pack_path_1, model_pack_path_2, documents_file, cui_filter, supervised_train_comparison_model)
documents = load_documents(documents_file)
documents = load_documents(documents_file, doc_limit=doc_limit)
if show_progress:
print("Loading [1]", model_pack_path_1)
cat1 = CAT.load_model_pack(model_pack_path_1)
Expand Down Expand Up @@ -145,10 +153,11 @@ def get_diffs_for(model_pack_path_1: str,
len(cui_filter), "CUIs")
cat1.config.linking.filters.cuis = cui_filter
cat2.config.linking.filters.cuis = cui_filter
ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw)
ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw,
doc_limit=doc_limit)
if show_progress:
print("Counting [1&2]")
res1, res2 = do_counting(cat1, cat2, ann_diffs)
res1, res2 = do_counting(cat1, cat2, ann_diffs, doc_limit=doc_limit)
if show_progress:
print("CDB compare")
cdb_diff = compare_cdbs(cat1.cdb, cat2.cdb)
Expand Down

0 comments on commit b34dc3a

Please sign in to comment.