-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CU-86956a6y5 improve comparison (#15)
* CU-86956a6y5: Add helper for comparison in notebook * CU-86956a6y5: Update model comparison notebook * CU-86956a6y5: Update per document view to allow ignoring empty documents * CU-86956a6y5: Update notebook with new demo * CU-86956a6y5: Add (truncated) demo data * CU-86956a6y5: Add optional doc limit to comparison tool
- Loading branch information
Showing
5 changed files
with
18,635 additions
and
671 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from ipyfilechooser import FileChooser | ||
from ipywidgets import widgets | ||
from IPython.display import display | ||
import os | ||
from typing import List, Optional | ||
|
||
|
||
from compare import get_diffs_for | ||
from output import parse_and_show, show_dict_deep, compare_dicts | ||
|
||
|
||
_def_path = '../../models/modelpack' | ||
_def_path = _def_path if os.path.exists(_def_path) else '.' | ||
|
||
|
||
class NBComparer: | ||
|
||
def __init__(self, model_path_1: str, model_path_2: str, | ||
documents_file: str, doc_limit: int, is_mct_export_compare: bool, | ||
cui_filter: str, filter_children: bool) -> None: | ||
self.model_path_1 = model_path_1 | ||
self.model_path_2 = model_path_2 | ||
self.documents_file = documents_file | ||
self.doc_limit = doc_limit | ||
self.is_mct_export_compare = is_mct_export_compare | ||
self.cui_filter = cui_filter | ||
self.filter_children = filter_children | ||
self._run_comparison() | ||
|
||
def _run_comparison(self): | ||
(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs) = get_diffs_for( | ||
self.model_path_1, self.model_path_2, self.documents_file, | ||
cui_filter=self.cui_filter, include_children_in_filter=self.filter_children, | ||
supervised_train_comparison_model=self.is_mct_export_compare, doc_limit=self.doc_limit) | ||
|
||
def show_all(self): | ||
parse_and_show(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs) | ||
|
||
def show_per_document(self, limit: int = -1, print_delimiter: bool = True, | ||
ignore_empty: bool = True): | ||
cnt = 0 | ||
for key in self.ann_diffs.per_doc_results.keys(): | ||
comp_dict = self.ann_diffs.per_doc_results[key].nr_of_comparisons | ||
if not ignore_empty or comp_dict: # ignore empty ones | ||
if print_delimiter: | ||
print('='*20,f'\n{key}', f'\n{"="*20}') | ||
show_dict_deep(self.ann_diffs.per_doc_results[key].nr_of_comparisons) | ||
cnt += 1 | ||
if limit > -1 and cnt == limit: | ||
break | ||
|
||
def diffs_to_csv(self, file_path: str) -> None: | ||
self.ann_diffs.to_csv(file_path) | ||
|
||
def compare_for_cui(self, cui: str, include_children: int = 2) -> None: | ||
per_cui1 = self.tally1.get_for_cui(cui, include_children=include_children) | ||
per_cui2 = self.tally2.get_for_cui(cui, include_children=include_children) | ||
compare_dicts(per_cui1, per_cui2) | ||
|
||
def show_docs(self, docs: List[str], show_delimiter: bool = True, | ||
omit_identical: bool = True): | ||
for doc_name, pair in self.ann_diffs.iter_ann_pairs(docs=docs, omit_identical=omit_identical): | ||
if show_delimiter: | ||
print('='*20,f'\n{doc_name} ({pair.comparison_type})', f'\n{"="*20}') | ||
# NOTE: if only one of the two has an annotation, the other one will be None | ||
# the following will deal with that automatically, though | ||
compare_dicts(pair.one, pair.two) | ||
|
||
|
||
class NBInputter: | ||
models_overall_title = "Models and data" | ||
mc1_title = "Choose model 1" | ||
mc2_title = "Choose model 2 (or an MCT export)" | ||
docs_title = "Choose the documents file (.csv with 'text' field)" | ||
docs_limit_title = "Limit the number of documents to run (-1 to disable)" | ||
mct_export_title = "Is the 2nd path an MCT export (instead of a model)?" | ||
cui_filter_title_overall = "CUI Filter" | ||
cui_filter_title_file_chooser = "Choose file with comma-separated CUIs" | ||
cui_filter_title_text = "List comma-separated CUIs" | ||
cui_children_title = "How many layers of children of concepts to include?" | ||
|
||
def __init__(self) -> None: | ||
self.model1_chooser = FileChooser(_def_path) | ||
self.model2_chooser = FileChooser(_def_path) | ||
self.documents_chooser = FileChooser(".") | ||
self.doc_limit = widgets.IntText(-1) | ||
self.ckbox = widgets.Checkbox(description="MCT export compare") | ||
|
||
self.cui_filter_chooser = FileChooser(".", description="The CUI filter file") | ||
self.cui_filter_box = widgets.Textarea(description="CUI list") | ||
self.cui_children = widgets.IntText(description="Children", value=-1) | ||
|
||
def show_all(self): | ||
model_choosers = widgets.VBox([ | ||
widgets.HTML(f"<h2>{self.models_overall_title}</h2>"), | ||
widgets.VBox([widgets.Label(self.mc1_title), self.model1_chooser]), | ||
widgets.VBox([widgets.Label(self.mc2_title), self.model2_chooser]), | ||
widgets.VBox([widgets.Label(self.docs_title), self.documents_chooser]), | ||
widgets.VBox([widgets.Label(self.docs_limit_title), self.doc_limit]), | ||
widgets.VBox([widgets.Label(self.mct_export_title), self.ckbox]) | ||
]) | ||
|
||
cui_filter = widgets.VBox([ | ||
widgets.HTML(f"<h2>{self.cui_filter_title_overall}</h2>"), | ||
widgets.VBox([widgets.Label(self.cui_filter_title_file_chooser), self.cui_filter_chooser]), | ||
widgets.VBox([widgets.Label(self.cui_filter_title_text), self.cui_filter_box]), | ||
widgets.VBox([widgets.Label(self.cui_children_title), self.cui_children]) | ||
]) | ||
|
||
# Combine all sections into a main VBox | ||
main_box = widgets.VBox([ | ||
model_choosers, | ||
cui_filter | ||
]) | ||
display(main_box) | ||
|
||
|
||
def _get_params(self): | ||
model_path_1 = self.model1_chooser.selected | ||
model_path_2 = self.model2_chooser.selected | ||
documents_file = self.documents_chooser.selected | ||
doc_limit = self.doc_limit.value | ||
is_mct_export_compare = self.ckbox.value | ||
if not is_mct_export_compare: | ||
print(f"For models, selected:\nModel1: {model_path_1}\nModel2: {model_path_2}" | ||
f"\nDocuments: {documents_file}") | ||
else: | ||
print(f"Selected:\nModel: {model_path_1}\nMCT export: {model_path_2}" | ||
f"\nDocuments: {documents_file}") | ||
# CUI filter | ||
cui_filter = None | ||
filter_children = None | ||
if self.cui_filter_chooser.selected: | ||
cui_filter = self.cui_filter_chooser.selected | ||
elif self.cui_filter_box.value: | ||
cui_filter = self.cui_filter_box.value | ||
if self.cui_children.value and self.cui_children.value > 0: | ||
filter_children = self.cui_children.value | ||
print(f"For CUI filter, selected:\nFilter: {cui_filter}\nChildren: {filter_children}") | ||
return (model_path_1, model_path_2, documents_file, doc_limit, is_mct_export_compare, cui_filter, filter_children) | ||
|
||
def get_comparison(self) -> NBComparer: | ||
return NBComparer(*self._get_params()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
289001005,289004002,226207007,165224005,129043005,129040008,129041007,704440004,704439001,704437004,129065005,129039006,129035000,165232002,1080000000000000,716422006,45850009,284908004,129045003,129062008,714887007,714916007,719024002,715127003,714915006,282882001,302040002,302043000,165243005,365112008,165255004,105504002,301563003,301497008,160680006,301589003,165248001,165249009,362000000000000,270469004,160729004,248000000000000,160734000,405807003,160685001,285035003,285038001,285034004,428483004,1912002,431188001,864000000000000,301563003,1070000000000000,301497008,78459008,284915007,307439001,229798009,229799001,229797004,31000000000000,818000000000000,763692001,404934007,863721000000101,1100000017,1100000016,1100000030,1100000030,1100000011,863721000000101,863721000000101,863721000000101,282884000,282884000,282884000,1100000027,361721000000103,1100000028,361721000000103,1100000028,361721000000103,1100000027,361721000000103,361721000000103,1100000027,1100000028,1100000027,361721000000103,361721000000103,1100000028,1100000031,31031000119102,310131003,895488007,895488007,394923006,394923006,394923006,394923006,394923006,248171000000108,248171000000108,248171000000108,248171000000108,1100000015,1100000015,1100000012,1100000012,1100000013,1100000012,302046008,25711000087100,165233007,1100000029,718705001,718360006,282871009,895486006,895486006,895486006,895486006,699650006,699650006,8510008,273302005,306171006,257301003,404930003,404930003,404930003,224221006,261001000,184156005,184156005,183376001,154091000119106,301627005,301627005,725594005,445414007,165803005,323701000000101,72042002,24029004,282971008,10610811000001107,161903000,979501000000100,301477003,282966001,1149222004,371153006,311925007,225602000,763264000,249902000,249902000,223600005,386323002,37013008,205511000000108,325831000000100,1073861000000108,273469003,129032002,286489001,761481000000107,129072006,1073311000000100,286489001,286490005,129026007,160689007,286493007,129031009,1069991000000102,1071641000000109,960681000000109 |
Oops, something went wrong.