Skip to content

Commit

Permalink
CU-86956a6y5 improve comparison (#15)
Browse files Browse the repository at this point in the history
* CU-86956a6y5: Add helper for comparison in notebook

* CU-86956a6y5: Update model comparison notebook

* CU-86956a6y5: Update per document view to allow ignoring empty documents

* CU-86956a6y5: Update notebook with new demo

* CU-86956a6y5: Add (truncated) demo data

* CU-86956a6y5: Add optional doc limit to comparison tool
  • Loading branch information
mart-r authored Sep 6, 2024
1 parent cf0908e commit 69a225d
Show file tree
Hide file tree
Showing 5 changed files with 18,635 additions and 671 deletions.
143 changes: 143 additions & 0 deletions medcat/compare_models/comp_nbhelper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from ipyfilechooser import FileChooser
from ipywidgets import widgets
from IPython.display import display
import os
from typing import List, Optional


from compare import get_diffs_for
from output import parse_and_show, show_dict_deep, compare_dicts


_def_path = '../../models/modelpack'
_def_path = _def_path if os.path.exists(_def_path) else '.'


class NBComparer:

def __init__(self, model_path_1: str, model_path_2: str,
documents_file: str, doc_limit: int, is_mct_export_compare: bool,
cui_filter: str, filter_children: bool) -> None:
self.model_path_1 = model_path_1
self.model_path_2 = model_path_2
self.documents_file = documents_file
self.doc_limit = doc_limit
self.is_mct_export_compare = is_mct_export_compare
self.cui_filter = cui_filter
self.filter_children = filter_children
self._run_comparison()

def _run_comparison(self):
(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs) = get_diffs_for(
self.model_path_1, self.model_path_2, self.documents_file,
cui_filter=self.cui_filter, include_children_in_filter=self.filter_children,
supervised_train_comparison_model=self.is_mct_export_compare, doc_limit=self.doc_limit)

def show_all(self):
parse_and_show(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs)

def show_per_document(self, limit: int = -1, print_delimiter: bool = True,
ignore_empty: bool = True):
cnt = 0
for key in self.ann_diffs.per_doc_results.keys():
comp_dict = self.ann_diffs.per_doc_results[key].nr_of_comparisons
if not ignore_empty or comp_dict: # ignore empty ones
if print_delimiter:
print('='*20,f'\n{key}', f'\n{"="*20}')
show_dict_deep(self.ann_diffs.per_doc_results[key].nr_of_comparisons)
cnt += 1
if limit > -1 and cnt == limit:
break

def diffs_to_csv(self, file_path: str) -> None:
self.ann_diffs.to_csv(file_path)

def compare_for_cui(self, cui: str, include_children: int = 2) -> None:
per_cui1 = self.tally1.get_for_cui(cui, include_children=include_children)
per_cui2 = self.tally2.get_for_cui(cui, include_children=include_children)
compare_dicts(per_cui1, per_cui2)

def show_docs(self, docs: List[str], show_delimiter: bool = True,
omit_identical: bool = True):
for doc_name, pair in self.ann_diffs.iter_ann_pairs(docs=docs, omit_identical=omit_identical):
if show_delimiter:
print('='*20,f'\n{doc_name} ({pair.comparison_type})', f'\n{"="*20}')
# NOTE: if only one of the two has an annotation, the other one will be None
# the following will deal with that automatically, though
compare_dicts(pair.one, pair.two)


class NBInputter:
models_overall_title = "Models and data"
mc1_title = "Choose model 1"
mc2_title = "Choose model 2 (or an MCT export)"
docs_title = "Choose the documents file (.csv with 'text' field)"
docs_limit_title = "Limit the number of documents to run (-1 to disable)"
mct_export_title = "Is the 2nd path an MCT export (instead of a model)?"
cui_filter_title_overall = "CUI Filter"
cui_filter_title_file_chooser = "Choose file with comma-separated CUIs"
cui_filter_title_text = "List comma-separated CUIs"
cui_children_title = "How many layers of children of concepts to include?"

def __init__(self) -> None:
self.model1_chooser = FileChooser(_def_path)
self.model2_chooser = FileChooser(_def_path)
self.documents_chooser = FileChooser(".")
self.doc_limit = widgets.IntText(-1)
self.ckbox = widgets.Checkbox(description="MCT export compare")

self.cui_filter_chooser = FileChooser(".", description="The CUI filter file")
self.cui_filter_box = widgets.Textarea(description="CUI list")
self.cui_children = widgets.IntText(description="Children", value=-1)

def show_all(self):
model_choosers = widgets.VBox([
widgets.HTML(f"<h2>{self.models_overall_title}</h2>"),
widgets.VBox([widgets.Label(self.mc1_title), self.model1_chooser]),
widgets.VBox([widgets.Label(self.mc2_title), self.model2_chooser]),
widgets.VBox([widgets.Label(self.docs_title), self.documents_chooser]),
widgets.VBox([widgets.Label(self.docs_limit_title), self.doc_limit]),
widgets.VBox([widgets.Label(self.mct_export_title), self.ckbox])
])

cui_filter = widgets.VBox([
widgets.HTML(f"<h2>{self.cui_filter_title_overall}</h2>"),
widgets.VBox([widgets.Label(self.cui_filter_title_file_chooser), self.cui_filter_chooser]),
widgets.VBox([widgets.Label(self.cui_filter_title_text), self.cui_filter_box]),
widgets.VBox([widgets.Label(self.cui_children_title), self.cui_children])
])

# Combine all sections into a main VBox
main_box = widgets.VBox([
model_choosers,
cui_filter
])
display(main_box)


def _get_params(self):
model_path_1 = self.model1_chooser.selected
model_path_2 = self.model2_chooser.selected
documents_file = self.documents_chooser.selected
doc_limit = self.doc_limit.value
is_mct_export_compare = self.ckbox.value
if not is_mct_export_compare:
print(f"For models, selected:\nModel1: {model_path_1}\nModel2: {model_path_2}"
f"\nDocuments: {documents_file}")
else:
print(f"Selected:\nModel: {model_path_1}\nMCT export: {model_path_2}"
f"\nDocuments: {documents_file}")
# CUI filter
cui_filter = None
filter_children = None
if self.cui_filter_chooser.selected:
cui_filter = self.cui_filter_chooser.selected
elif self.cui_filter_box.value:
cui_filter = self.cui_filter_box.value
if self.cui_children.value and self.cui_children.value > 0:
filter_children = self.cui_children.value
print(f"For CUI filter, selected:\nFilter: {cui_filter}\nChildren: {filter_children}")
return (model_path_1, model_path_2, documents_file, doc_limit, is_mct_export_compare, cui_filter, filter_children)

def get_comparison(self) -> NBComparer:
return NBComparer(*self._get_params())
27 changes: 18 additions & 9 deletions medcat/compare_models/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import tqdm
import tempfile
import os
from itertools import islice

from compare_cdb import compare as compare_cdbs, CDBCompareResults
from compare_annotations import ResultsTally, PerAnnotationDifferences
Expand All @@ -17,18 +17,22 @@



def load_documents(file_name: str) -> Iterator[Tuple[str, str]]:
def load_documents(file_name: str, doc_limit: int = -1) -> Iterator[Tuple[str, str]]:
with open(file_name) as f:
df = pd.read_csv(f, names=["id", "text"])
if df.iloc[0].id == "id" and df.iloc[0].text == "text":
# removes the header
# but also messes up the index a little
df = df.iloc[1:, :]
yield from df.itertuples(index=False)
if doc_limit == -1:
yield from df.itertuples(index=False)
else:
yield from islice(df.itertuples(index=False), doc_limit)


def do_counting(cat1: CAT, cat2: CAT,
ann_diffs: PerAnnotationDifferences) -> ResultsTally:
ann_diffs: PerAnnotationDifferences,
doc_limit: int = -1) -> ResultsTally:
def cui2name(cat, cui):
if cui in cat.cdb.cui2preferred_name:
return cat.cdb.cui2preferred_name[cui]
Expand All @@ -39,7 +43,8 @@ def cui2name(cat, cui):
cui2name=partial(cui2name, cat1))
res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.make_stats(),
cui2name=partial(cui2name, cat2))
for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values()):
total = doc_limit if doc_limit != -1 else None
for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values(), total=total):
res1.count(per_doc.raw1)
res2.count(per_doc.raw2)
return res1, res2
Expand All @@ -52,6 +57,7 @@ def _get_pt2ch(cat: CAT) -> Optional[Dict]:
def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str, str]],
show_progress: bool = True,
keep_raw: bool = True,
doc_limit: int = -1
) -> PerAnnotationDifferences:
pt2ch1: Optional[Dict] = _get_pt2ch(cat1)
pt2ch2: Optional[Dict] = _get_pt2ch(cat2)
Expand All @@ -63,7 +69,8 @@ def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str
model2_cuis=set(cat2.cdb.cui2names),
keep_raw=keep_raw,
save_options=save_opts)
for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress):
total = doc_limit if doc_limit != -1 else None
for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress, total=total):
pad.look_at_doc(cat1.get_entities(doc), cat2.get_entities(doc), doc_id, doc)
pad.finalise()
return pad
Expand Down Expand Up @@ -107,9 +114,10 @@ def get_diffs_for(model_pack_path_1: str,
include_children_in_filter: Optional[int] = None,
supervised_train_comparison_model: bool = False,
keep_raw: bool = True,
doc_limit: int = -1,
) -> Tuple[CDBCompareResults, ResultsTally, ResultsTally, PerAnnotationDifferences]:
validate_input(model_pack_path_1, model_pack_path_2, documents_file, cui_filter, supervised_train_comparison_model)
documents = load_documents(documents_file)
documents = load_documents(documents_file, doc_limit=doc_limit)
if show_progress:
print("Loading [1]", model_pack_path_1)
cat1 = CAT.load_model_pack(model_pack_path_1)
Expand Down Expand Up @@ -145,10 +153,11 @@ def get_diffs_for(model_pack_path_1: str,
len(cui_filter), "CUIs")
cat1.config.linking.filters.cuis = cui_filter
cat2.config.linking.filters.cuis = cui_filter
ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw)
ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw,
doc_limit=doc_limit)
if show_progress:
print("Counting [1&2]")
res1, res2 = do_counting(cat1, cat2, ann_diffs)
res1, res2 = do_counting(cat1, cat2, ann_diffs, doc_limit=doc_limit)
if show_progress:
print("CDB compare")
cdb_diff = compare_cdbs(cat1.cdb, cat2.cdb)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
289001005,289004002,226207007,165224005,129043005,129040008,129041007,704440004,704439001,704437004,129065005,129039006,129035000,165232002,1080000000000000,716422006,45850009,284908004,129045003,129062008,714887007,714916007,719024002,715127003,714915006,282882001,302040002,302043000,165243005,365112008,165255004,105504002,301563003,301497008,160680006,301589003,165248001,165249009,362000000000000,270469004,160729004,248000000000000,160734000,405807003,160685001,285035003,285038001,285034004,428483004,1912002,431188001,864000000000000,301563003,1070000000000000,301497008,78459008,284915007,307439001,229798009,229799001,229797004,31000000000000,818000000000000,763692001,404934007,863721000000101,1100000017,1100000016,1100000030,1100000030,1100000011,863721000000101,863721000000101,863721000000101,282884000,282884000,282884000,1100000027,361721000000103,1100000028,361721000000103,1100000028,361721000000103,1100000027,361721000000103,361721000000103,1100000027,1100000028,1100000027,361721000000103,361721000000103,1100000028,1100000031,31031000119102,310131003,895488007,895488007,394923006,394923006,394923006,394923006,394923006,248171000000108,248171000000108,248171000000108,248171000000108,1100000015,1100000015,1100000012,1100000012,1100000013,1100000012,302046008,25711000087100,165233007,1100000029,718705001,718360006,282871009,895486006,895486006,895486006,895486006,699650006,699650006,8510008,273302005,306171006,257301003,404930003,404930003,404930003,224221006,261001000,184156005,184156005,183376001,154091000119106,301627005,301627005,725594005,445414007,165803005,323701000000101,72042002,24029004,282971008,10610811000001107,161903000,979501000000100,301477003,282966001,1149222004,371153006,311925007,225602000,763264000,249902000,249902000,223600005,386323002,37013008,205511000000108,325831000000100,1073861000000108,273469003,129032002,286489001,761481000000107,129072006,1073311000000100,286489001,286490005,129026007,160689007,286493007,129031009,1069991000000102,1071641000000109,960681000000109
Loading

0 comments on commit 69a225d

Please sign in to comment.