CU-8693wtx4b: Add code and notebook to compare two models #13

Merged
merged 91 commits into from
Jul 18, 2024
Changes from 84 commits
e047255
CU-8693wtx4b: Add code and notebook to compare two models
mart-r Feb 22, 2024
11a7a05
CU-8693wtx4: Add working output (synthetic data) to notebook
mart-r Feb 22, 2024
59d34ae
CU-8693wtx4: Show cat data in raw format
mart-r Feb 22, 2024
9ed15bc
CU-8693wtx4: Add documentation to notebook
mart-r Feb 22, 2024
c2e12c6
CU-8693wtx4: Add per-cui output to code and notebook
mart-r Feb 22, 2024
6ff302a
CU-8693wtx4: Add test to GHA workflow
mart-r Feb 22, 2024
909f862
CU-8693wtx4: Remove old comments/commented code
mart-r Feb 22, 2024
8efdf62
CU-8693wtx4: Add option to filter tally results by CUI; along with re…
mart-r Feb 22, 2024
0018202
CU-8693wtx4: Add option to filter by CUI when comparing
mart-r Feb 22, 2024
3dec0e5
CU-8693wtx4: Add option to filter by CUI when comparing
mart-r Feb 22, 2024
69155af
CU-8693wtx4: Fix typing issue
mart-r Feb 22, 2024
8caa8ac
CU-8693wtx4: Fix typing issue [#2]
mart-r Feb 22, 2024
5c0696e
CU-8693wtx4: Add typing temp-issue to non-related files
mart-r Feb 22, 2024
093b473
CU-8693wtx4: Add typing temp-issue to non-related files [#2]
mart-r Feb 22, 2024
2bf145c
CU-8693wtx4b: Remove unnecessary whitespace
mart-r Feb 29, 2024
0c0c4ab
CU-8693wtx4b: Add option to iterate over annotation pairs
mart-r Feb 29, 2024
2913da8
CU-8693wtx4b: Add tests for iteration over annotation pairs
mart-r Feb 29, 2024
7b4f2a7
CU-8693wtx4b: Add option to iteration over subset of annotation pairs…
mart-r Feb 29, 2024
fe21967
CU-8693wtx4b: Add some documentation for annotation pair iteration
mart-r Feb 29, 2024
7d455b0
CU-8693wtx4b: Add annotation pair iteration to notebook
mart-r Feb 29, 2024
a0179e1
CU-8693wtx4b: Add option (defaulted to True) to omit identical to ann…
mart-r Feb 29, 2024
5184419
CU-8693wtx4b: Update notebook to omit identical in annotation pair it…
mart-r Feb 29, 2024
76ddd8d
CU-8693wtx4b: Update notebook with some extra documentation
mart-r Feb 29, 2024
7ebc80e
CU-8693wtx4b: Add further assertion to iteration with filter when tes…
mart-r Feb 29, 2024
72430c0
CU-8693wtx4b: Fix non-overlapping annotations
mart-r Feb 29, 2024
9080cb4
CU-8693wtx4b: Add more comprehensive test for non-overlapping annotat…
mart-r Feb 29, 2024
e423b65
CU-8693wtx4b: Allow none-valued dicst in dict comparison for output
mart-r Feb 29, 2024
c1ad97e
CU-8693wtx4b: Add tests for none-valued dicts in output
mart-r Feb 29, 2024
bd6e17e
CU-8693wtx4b: Add better handling of empty dicst in comparison in output
mart-r Feb 29, 2024
3be47b8
CU-8693wtx4b: Remove debug output
mart-r Feb 29, 2024
14fa15a
CU-8693wtx4b: Add more useful nulled dicts along with tests for them
mart-r Feb 29, 2024
1a2f976
CU-8693wtx4b: Update notebook with recent changes
mart-r Feb 29, 2024
9cd1fd5
CU-8693wtx4b: Remove debug output
mart-r Feb 29, 2024
0e5d148
CU-8693wtx4b: Fix typing issues
mart-r Feb 29, 2024
9f035d3
Merge branch 'main' of github.com:CogStack/working_with_cogstack into…
mart-r Mar 25, 2024
7edcb92
CU-8693wtx4b: Remove ignore file from mct_analysis
mart-r Mar 25, 2024
6b3f202
CU-8693wtx4b: Add option to recognise children and grandchildren as
mart-r Mar 26, 2024
9778085
CU-8693wtx4b: Fix typo
mart-r Mar 26, 2024
014b8d9
CU-8693wtx4b: Fix typo (#2)
mart-r Mar 26, 2024
51b5b6e
CU-8693wtx4b: Avoid annotating twice over the dataset
mart-r Mar 26, 2024
007ba9f
CU-8693wtx4b: Fix issue with raw annotations being empty at tally time
mart-r Apr 2, 2024
64a28f0
CU-869475m5q: Allow differentiating CUIs that one of the two models d…
mart-r Apr 2, 2024
ef09236
CU-869475m5q: Update notebook with latest output
mart-r Apr 3, 2024
976abf4
CU-869475h56: Allow CUI filter to be specified by a file with a list …
mart-r Apr 3, 2024
ac762c5
CU-869475h56: Update notebook with file based CUI filter information
mart-r Apr 3, 2024
38bb6fb
CU-869475m5q: Fix tests: add necessary per model CUI sets where appli…
mart-r Apr 3, 2024
098f473
CU-869475h38: Allow including children for per-cui view
mart-r Apr 5, 2024
83046b6
CU-869475h38: Make per-cui names easier to read
mart-r Apr 5, 2024
ef93016
CU-869475h38: Update notebook
mart-r Apr 5, 2024
ff2cf2c
CU-86948wv58: Add method to get CSV output of annotations
mart-r Apr 9, 2024
e4c37c3
CU-86948wv58: Improve documentation
mart-r Apr 9, 2024
77012ea
CU-86948wv58: Update model comparison with CSV output part
mart-r Apr 9, 2024
1729995
CU-86948wv58: Fix CSV output relative start and end for annotations
mart-r Apr 9, 2024
2f65a41
CU-86948wv58: Fix typing for annotations when creating CSV
mart-r Apr 9, 2024
65257d1
CU-86948wv58: Fix typing when creating sub-text for CSV
mart-r Apr 9, 2024
c25446f
CU-869498jf4: Add option to add children to CUI filter along with tes…
mart-r Apr 10, 2024
5942ce2
CU-869498jf4: Update notebook with notes regarding children in filter
mart-r Apr 10, 2024
a3cd63e
CU-8693wtx4b: Fix whitespace
mart-r Apr 10, 2024
eb2f4fd
CU-869498jf4: Add more output regarding filter after adding children
mart-r Apr 10, 2024
777f2e3
CU-869498jf4: Add children from both models if possible
mart-r Apr 10, 2024
5cd2284
CU-869475h0e: Add option to compare to a base model + supervised trainig
mart-r Apr 16, 2024
4a6ddd5
CU-869475h0e: Add tests for base model + supervised training; Add fak…
mart-r Apr 16, 2024
a4ee98e
CU-869475h0e: Add note for supervised training-based comparison
mart-r Apr 16, 2024
5d42045
CU-869475h0e: Add missing notes for supervised training-based comparison
mart-r Apr 16, 2024
0e7b6a3
CU-869475h78: Add markdown to dict comparison
mart-r Apr 25, 2024
cda586a
CU-869475h78: Fix notebook output for per-annotation differences
mart-r Apr 25, 2024
322aa21
CU-869475h78: Automatically use markdown when notebook output required
mart-r Apr 25, 2024
88e8966
CU-869475h78: Automatically detect when in notebook (by default)
mart-r Apr 25, 2024
d721f77
CU-869475h78: Update notebook with tabular / markdown output
mart-r Apr 25, 2024
79c7e0c
CU-8694ezhrn: Load documents as iterator instead of holding them in m…
mart-r Apr 30, 2024
262e373
CU-8694ezhrn: Add option to omit raw text
mart-r Apr 30, 2024
fd29bfe
CU-8694ezhrn: Add tests when omitting raw text
mart-r Apr 30, 2024
dd2c559
CU-8694ezhrn: Propgate keeping raw option properly
mart-r Apr 30, 2024
e6d6207
CU-8694ezhrn: Propgate keeping raw option properly (2)
mart-r Apr 30, 2024
ff1b1c4
CU-8694ezhrn: Do not store copies of existing dicts in annotation pairs
mart-r Apr 30, 2024
1c15312
CU-8693wtx4: Improve memory usage by saving the annotation difference…
mart-r Apr 30, 2024
f4366bc
CU-8693wtx4: Allow filtering by comparison type when converting to CSV
mart-r Apr 30, 2024
4547305
CU-8693wtx4: Add tests for filtering by comparison type when iteratin…
mart-r Apr 30, 2024
3035efb
CU-8693wtx4: Fix model type in case of using DB for annotation pairs
mart-r May 1, 2024
7497ac5
CU-8694ggtgm: Improve initial documentation for model input
mart-r May 7, 2024
aa7b850
CU-8694ggtgm: Add more accurate differences for two approaches in ini…
mart-r May 7, 2024
e54bdb2
CU-8694ggtgm: Add code for supervised training comparison (but commen…
mart-r May 7, 2024
841a770
CU-8694ggtgm: Add widget input
mart-r May 7, 2024
b792ed7
CU-8694ggtgm: Small documentation fix
mart-r May 7, 2024
87c189f
CU-8694ggtgm: Add new dependencies to requirementss file
mart-r May 16, 2024
0a5eaa6
CU-8694ggtgm: Add input data validation
mart-r May 16, 2024
d1344cf
CU-8694ggtgm: Move input data validation to its own module
mart-r May 16, 2024
e9135af
CU-8694ggtgm: Add further validation for medcat models
mart-r May 16, 2024
fed91b9
CU-8694ggtgm: Add further validation for documents file
mart-r May 16, 2024
9fd5e57
CU-8694ggtgm: Add wildcard MCT export validation
mart-r May 16, 2024
7e07a5a
CU-8694ggtgm: Replace use of 3.9+ exclusive string method
mart-r May 16, 2024
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -35,6 +35,7 @@ jobs:
      - name: Test
        run: |
          python -m unittest discover
          python -m unittest discover -s medcat/compare_models
          # TODO - in the future, we might want to add automated tests for notebooks as well
          # though it's not really possible right now since the notebooks are designed
          # in a way that assumes interaction (i.e specifying model pack names)
62 changes: 62 additions & 0 deletions medcat/compare_models/cmp_utils.py
@@ -0,0 +1,62 @@
from typing import Type, TypeVar, Generic, Iterator, Callable, Optional

import sqlite3
import re
from pydantic import BaseModel


T = TypeVar('T', bound=BaseModel)


def sanitize_table_name(name, max_length=64):
    # Replace any characters not allowed in table names with underscores
    name = re.sub(r'[^a-zA-Z0-9_$]', '_', name)
    # Truncate the name if it's too long
    name = name[:max_length]
    return name


class SaveOptions(BaseModel):
    use_db: bool = False
    db_file_name: Optional[str] = None
    clean_callback: Optional[Callable[[], None]] = None


class DifferenceDatabase(Generic[T]):

    def __init__(self, db_file: str, part: str, model_type: Type[T],
                 batch_size: int = 100):
        self.db_file = db_file
        self.part = sanitize_table_name(part)
        self.model_type = model_type
        self.conn = sqlite3.connect(self.db_file)
        self.cursor = self.conn.cursor()
        self._create_table()
        self._len = 0
        self._batch_size = batch_size

    def _create_table(self):
        self.cursor.execute(f'''CREATE TABLE IF NOT EXISTS differences_{self.part}
                            (id INTEGER PRIMARY KEY, data TEXT)''')
        self.conn.commit()

    def append(self, difference: T):
        data = difference.json()
        self.cursor.execute(f"INSERT INTO differences_{self.part} (data) VALUES (?)", (data,))
        self.conn.commit()
        self._len += 1

    def __iter__(self) -> Iterator[T]:
        self.cursor.execute(f"SELECT data FROM differences_{self.part}")
        while True:
            rows = self.cursor.fetchmany(self._batch_size)
            if not rows:
                break
            for row in rows:
                yield self.model_type.parse_raw(row[0])

    def __len__(self) -> int:
        return self._len

    def __del__(self):
        self.conn.close()
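The batched-read pattern in DifferenceDatabase can be tried without MedCAT or pydantic. The following is only a minimal sketch under stated assumptions: `Diff` and `DiffStore` are hypothetical names, a dataclass serialised with `json` stands in for the pydantic model, and an in-memory SQLite database stands in for the temporary file. Unlike the class above, it opens a fresh cursor for each iteration and orders rows by `id`; both are choices of this sketch, not something the PR does.

```python
import json
import re
import sqlite3
from dataclasses import asdict, dataclass
from typing import Iterator


@dataclass
class Diff:
    """Hypothetical stand-in for the pydantic model stored by the PR."""
    doc_id: str
    cui: str


def sanitize_table_name(name: str, max_length: int = 64) -> str:
    # Same rule as cmp_utils.py: keep [a-zA-Z0-9_$], then truncate.
    return re.sub(r'[^a-zA-Z0-9_$]', '_', name)[:max_length]


class DiffStore:
    """Hypothetical, simplified analogue of DifferenceDatabase."""

    def __init__(self, db_file: str, part: str, batch_size: int = 100):
        self.table = f"differences_{sanitize_table_name(part)}"
        self.conn = sqlite3.connect(db_file)
        self.conn.execute(
            f"CREATE TABLE IF NOT EXISTS {self.table} "
            "(id INTEGER PRIMARY KEY, data TEXT)")
        self._len = 0
        self._batch_size = batch_size

    def append(self, diff: Diff) -> None:
        # Each item is stored as one JSON string per row.
        self.conn.execute(f"INSERT INTO {self.table} (data) VALUES (?)",
                          (json.dumps(asdict(diff)),))
        self.conn.commit()
        self._len += 1

    def __iter__(self) -> Iterator[Diff]:
        # A fresh cursor per iteration, read back in insertion order.
        cursor = self.conn.execute(
            f"SELECT data FROM {self.table} ORDER BY id")
        while True:
            rows = cursor.fetchmany(self._batch_size)
            if not rows:
                break
            for (data,) in rows:
                yield Diff(**json.loads(data))

    def __len__(self) -> int:
        return self._len


store = DiffStore(":memory:", "doc 1/part-a", batch_size=2)
for i in range(5):
    store.append(Diff(doc_id=f"d{i}", cui=f"C{i:03d}"))
items = list(store)
```

Only `batch_size` rows are fetched at a time, so memory use while iterating stays bounded by the batch rather than by the number of stored differences.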
163 changes: 163 additions & 0 deletions medcat/compare_models/compare.py
@@ -0,0 +1,163 @@
from typing import List, Tuple, Dict, Set, Optional, Union, Iterator
from functools import partial
import glob

from medcat.cat import CAT

import pandas as pd
import tqdm
import tempfile

from compare_cdb import compare as compare_cdbs, CDBCompareResults
from compare_annotations import ResultsTally, PerAnnotationDifferences
from output import parse_and_show
from cmp_utils import SaveOptions



def load_documents(file_name: str) -> Iterator[Tuple[str, str]]:
    with open(file_name) as f:
        df = pd.read_csv(f, names=["id", "text"])
    if df.iloc[0].id == "id" and df.iloc[0].text == "text":
        # removes the header
        # but also messes up the index a little
        df = df.iloc[1:, :]
    yield from df.itertuples(index=False)


def do_counting(cat1: CAT, cat2: CAT,
                ann_diffs: PerAnnotationDifferences) -> Tuple[ResultsTally, ResultsTally]:
    def cui2name(cat, cui):
        if cui in cat.cdb.cui2preferred_name:
            return cat.cdb.cui2preferred_name[cui]
        all_names = cat.cdb.cui2names[cui]
        # longest name
        return sorted(all_names, key=lambda name: len(name), reverse=True)[0]
    res1 = ResultsTally(pt2ch=_get_pt2ch(cat1), cat_data=cat1.cdb.make_stats(),
                        cui2name=partial(cui2name, cat1))
    res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.make_stats(),
                        cui2name=partial(cui2name, cat2))
    for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values()):
        res1.count(per_doc.raw1)
        res2.count(per_doc.raw2)
    return res1, res2


def _get_pt2ch(cat: CAT) -> Optional[Dict]:
    return cat.cdb.addl_info.get("pt2ch", None)


def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str, str]],
                             show_progress: bool = True,
                             keep_raw: bool = True,
                             ) -> PerAnnotationDifferences:
    pt2ch1: Optional[Dict] = _get_pt2ch(cat1)
    pt2ch2: Optional[Dict] = _get_pt2ch(cat2)
    temp_file = tempfile.NamedTemporaryFile()
    save_opts = SaveOptions(use_db=True, db_file_name=temp_file.name,
                            clean_callback=temp_file.close)
    pad = PerAnnotationDifferences(pt2ch1=pt2ch1, pt2ch2=pt2ch2,
                                   model1_cuis=set(cat1.cdb.cui2names),
                                   model2_cuis=set(cat2.cdb.cui2names),
                                   keep_raw=keep_raw,
                                   save_options=save_opts)
    for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress):
        pad.look_at_doc(cat1.get_entities(doc), cat2.get_entities(doc), doc_id, doc)
    pad.finalise()
    return pad


def load_cui_filter(filter_file: str) -> Set[str]:
    with open(filter_file) as f:
        str_list = f.read().split(',')
    return set(item.strip() for item in str_list)


def _add_all_children(cat: CAT, cui_filter: Set[str], include_children: int) -> None:
    if include_children <= 0:
        return
    if "pt2ch" not in cat.cdb.addl_info:
        return
    pt2ch = cat.cdb.addl_info["pt2ch"]
    children = set(ch for cui in cui_filter for ch in pt2ch.get(cui, []))
    if include_children > 1:
        _add_all_children(cat, children, include_children=include_children-1)
    cui_filter.update(children)


def load_and_train(model_pack_path: str, mct_export_path: str) -> CAT:
    cat = CAT.load_model_pack(model_pack_path)
    # NOTE: Allowing mct_export_path to contain wildcard ("*").
    #       And in such a case, iterating over all matching files
    if "*" not in mct_export_path:
        cat.train_supervised_from_json(mct_export_path)
    else:
        for file in glob.glob(mct_export_path):
            cat.train_supervised_from_json(file)
    return cat


def get_diffs_for(model_pack_path_1: str,
                  model_pack_path_2: str,
                  documents_file: str,
                  cui_filter: Optional[Union[Set[str], str]] = None,
                  show_progress: bool = True,
                  include_children_in_filter: Optional[int] = None,
                  supervised_train_comparison_model: bool = False,
                  keep_raw: bool = True,
                  ) -> Tuple[CDBCompareResults, ResultsTally, ResultsTally, PerAnnotationDifferences]:
    documents = load_documents(documents_file)
    if show_progress:
        print("Loading [1]", model_pack_path_1)
    cat1 = CAT.load_model_pack(model_pack_path_1)
    if show_progress:
        print("Loading [2]", model_pack_path_2)
    if not supervised_train_comparison_model:
        cat2 = CAT.load_model_pack(model_pack_path_2)
    else:
        if show_progress:
            print("Reloading model pack 1", model_pack_path_1)
            print("And subsequently training on", model_pack_path_2)
            print("This may take a while, depending on the amount of "
                  "data being trained on")
        cat2 = load_and_train(model_pack_path_1, model_pack_path_2)
    if show_progress:
        print("Per annotations diff finding")
    if cui_filter:
        if isinstance(cui_filter, str):
            cui_filter = load_cui_filter(cui_filter)
        if show_progress:
            print("Applying filter to CATs:", len(cui_filter), 'CUIs')
        if include_children_in_filter:
            if show_progress:
                print("Adding all children of", include_children_in_filter,
                      "or lower level from first model")
            _add_all_children(cat1, cui_filter, include_children_in_filter)
            if show_progress:
                print("After adding children from 1st model have a total of",
                      len(cui_filter), "CUIs")
            _add_all_children(cat2, cui_filter, include_children_in_filter)
            if show_progress:
                print("After adding children from 2nd model have a total of",
                      len(cui_filter), "CUIs")
        cat1.config.linking.filters.cuis = cui_filter
        cat2.config.linking.filters.cuis = cui_filter
    ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw)
    if show_progress:
        print("Counting [1&2]")
    res1, res2 = do_counting(cat1, cat2, ann_diffs)
    if show_progress:
        print("CDB compare")
    cdb_diff = compare_cdbs(cat1.cdb, cat2.cdb)
    return cdb_diff, res1, res2, ann_diffs


def main(mpn1: str, mpn2: str, documents_file: str):
    cdb_diff, res1, res2, ann_diffs = get_diffs_for(mpn1, mpn2, documents_file, show_progress=False)
    print("Results:")
    parse_and_show(cdb_diff, res1, res2, ann_diffs)


if __name__ == "__main__":
    import sys
    main(*sys.argv[1:])
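The CUI filter handling (load_cui_filter plus _add_all_children) is easier to follow on toy data. The sketch below is illustrative only and rests on a few assumptions: `parse_cui_filter` and `add_children` are hypothetical names, a plain dict plays the role of `cat.cdb.addl_info["pt2ch"]`, and the filter is parsed from a string rather than read from a file. Unlike load_cui_filter, this version drops empty entries such as the one left by a trailing comma.

```python
from typing import Dict, List, Set


def parse_cui_filter(text: str) -> Set[str]:
    # Comma-separated CUIs; whitespace is stripped and empty items dropped.
    return {item.strip() for item in text.split(',') if item.strip()}


def add_children(pt2ch: Dict[str, List[str]], cui_filter: Set[str],
                 levels: int) -> None:
    # Mirrors _add_all_children: extend the filter in place with
    # descendants down to `levels` generations.
    if levels <= 0:
        return
    children = {ch for cui in cui_filter for ch in pt2ch.get(cui, [])}
    if levels > 1:
        # Recurse on the children so that grandchildren are collected too.
        add_children(pt2ch, children, levels - 1)
    cui_filter.update(children)


# Toy parent-to-children hierarchy: A -> B -> C -> D
pt2ch = {"A": ["B"], "B": ["C"], "C": ["D"]}

one_level = parse_cui_filter("A, \n")
add_children(pt2ch, one_level, 1)

two_levels = parse_cui_filter("A")
add_children(pt2ch, two_levels, 2)
```

With one level only the direct child B joins the filter; with two levels the grandchild C is added as well, while D (three levels down) stays out.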