diff --git a/medcat-v2/paper/data/supervised/MDACE/raw/README.md b/medcat-v2/paper/data/supervised/MDACE/raw/README.md new file mode 100644 index 000000000..1753785fe --- /dev/null +++ b/medcat-v2/paper/data/supervised/MDACE/raw/README.md @@ -0,0 +1,18 @@ +First we download the MDACE dataset and prepare it with MIMIC-IV as per instructions: +https://github.com/3mcloud/MDACE + +Then, we need to convert the data to a format MedCAT can understand using: +```python +python convert_to_mct_export.py # no need for arguments if in this folder +``` + +However, that still only has ICD-10 codes. +Yet the models we're comparing to use SNOMED. + +So we then need to convert to SNOMED (the first argument is the path to a MedCAT model pack that provides the ICD-10 mappings) by doing: +```python +python map_from_icd_to_snomed.py <path/to/model_pack.zip> ../icd10_convert.json ../mct_export_with_candidates.json +``` + +This will create a trainer export that has multiple CUIs as options for each annotation. +That is because ICD-10 codes can map to multiple different SNOMED concepts and there is no automated way to create a 1-to-1 mapping. 
def get_all_jsons(input_dir: str) -> Iterator[str]:
    """Recursively yield the paths of all ``.json`` files below *input_dir*.

    Directory entries are visited in sorted name order. ``os.listdir``
    returns entries in arbitrary order, which would make the order of the
    projects in the generated export differ between machines and runs.

    Args:
        input_dir: Directory to search (sub-directories are descended into).

    Yields:
        The path (``input_dir`` joined with the relative location) of every
        file whose name ends with ``.json``.
    """
    for fn in sorted(os.listdir(input_dir)):
        path = os.path.join(input_dir, fn)
        if os.path.isdir(path):
            yield from get_all_jsons(path)
        elif path.endswith(".json"):
            yield path
"start": ann["begin"], + "end": ann["end"], + # NOTE: this is currently in ICD + "cui": ann["code"], + "value": ann["covered_text"], + "id": f"{proj_name}_{doc_name}_{ann_num}", + "meta_anns": [], + "validated": True, + } + ) + print("GOT", len(all_out["projects"]), "projects", + "with", count_all_annotations(all_out), "annotations", + "across", count_all_docs(all_out), "documents") + + with open(output_file, "w") as of: + json.dump(all_out, of, indent=2) + + +if __name__ == "__main__": + do_conversion(*sys.argv[1:]) diff --git a/medcat-v2/paper/data/supervised/MDACE/raw/map_from_icd_to_snomed.py b/medcat-v2/paper/data/supervised/MDACE/raw/map_from_icd_to_snomed.py new file mode 100644 index 000000000..0702238ba --- /dev/null +++ b/medcat-v2/paper/data/supervised/MDACE/raw/map_from_icd_to_snomed.py @@ -0,0 +1,104 @@ +import sys +import json +from collections import defaultdict + +from medcat.cat import CAT +from medcat.data.mctexport import ( + MedCATTrainerExport, MedCATTrainerExportAnnotation, + count_all_annotations, count_all_docs) + + +def load_export(path: str) -> MedCATTrainerExport: + with open(path) as f: + return json.load(f) + + +def icd2snomed(cat: CAT) -> dict[str, list[str]]: + code2snomed: dict[str, list[str]] = defaultdict(list) + cui2icd10 = cat.cdb.addl_info["cui2icd10"] + for cui_info in cat.cdb.cui2info.values(): + cui = cui_info["cui"] + for icd10 in cui2icd10.get(cui, []): + code2snomed[icd10].append(cui) + print("GOT", len(code2snomed), "ICD codes") + print("Mapped to", sum(len(v) for v in code2snomed.values()), + "total Snomed CUIs") + return code2snomed + + +def pick_concept(cat: CAT, + mapper: dict[str, list[str]], + ann: MedCATTrainerExportAnnotation) -> str | None: + # NOTE: I could try and select 1 - the best + # but there isn't really a good way to do that. 
def main(model_pack_path: str,
         icd10_export_path: str,
         final_export_path: str):
    """Convert an ICD-10 based trainer export into one with SNOMED candidates.

    Loads the model pack, loads the ICD-10 export, converts it with
    ``convert_export`` and writes the result as JSON.

    Args:
        model_pack_path: Path handed to ``CAT.load_model_pack``.
        icd10_export_path: Path of the JSON trainer export whose annotation
            ``cui`` values are ICD-10 codes.
        final_export_path: Path the converted export is written to.
    """
    # Kept as a function-scope import so this block does not depend on the
    # file's top-level import list.
    from medcat.data.mctexport import iter_anns

    print("Loading model pack", model_pack_path)
    cat = CAT.load_model_pack(model_pack_path)
    print("Loading export")
    export = load_export(icd10_export_path)
    print("Initial import has", count_all_docs(export), "docs",
          "and", count_all_annotations(export), "anns within",
          len(export["projects"]), "projects")
    print("Converting...")
    converted = convert_export(cat, export)
    print("CONVERTED export HAS", count_all_docs(converted), "docs",
          "and", count_all_annotations(converted), "anns within",
          len(converted["projects"]), "projects")
    # Number of candidate CUIs per annotation (a single non-list CUI counts
    # as one candidate).
    lens = [
        len(ann["cui"]) if isinstance(ann["cui"], list) else 1
        for _, _, ann in iter_anns(converted)
    ]
    if lens:
        print("Total", len(lens), "annotations with", sum(lens) / len(lens),
              "values on average")
    else:
        # Previously this case raised ZeroDivisionError before saving.
        print("Total 0 annotations - no ICD-10 code could be mapped")
    print("Saving to", final_export_path)
    with open(final_export_path, 'w') as f:
        json.dump(converted, f)
def find_annotations(value: str, text: str, cui: str
                     ) -> "list[MedCATTrainerExportAnnotation]":
    """Create one annotation per case-insensitive occurrence of *value*.

    Occurrences are searched left to right and do not overlap. The
    annotation ``value`` is the slice of the original (not lower-cased) text.

    Args:
        value: The term to look for.
        text: The example text to search in.
        cui: The concept ID to attach; stored as ``str(cui)``.

    Returns:
        Annotations with the keys ``cui``, ``value``, ``start`` and ``end``.

    Raises:
        ValueError: If *value* is empty or does not occur in *text*.
        AttributeError: If *value* or *text* has no ``lower`` method
            (for example a missing cell that is not a string).
        KeyError: If more than 100 occurrences are found.
    """
    value = value.lower()
    orig_text = text
    # NOTE: offsets are computed on the lower-cased text and then applied to
    #       the original; this assumes lower-casing keeps the length
    #       unchanged (always true for ASCII).
    text = text.lower()
    if not value:
        # An empty needle "matches" at every position with zero width; the
        # search loop below would never advance and would end in the
        # too-many-annotations KeyError. Report it as unusable input instead.
        raise ValueError(f"Empty value cannot be located in text ({repr(text)})")
    if value not in text:
        raise ValueError(f"{repr(value)} not in text ({repr(text)})")
    cur_start = 0
    anns: list[MedCATTrainerExportAnnotation] = []
    while (cur_index := text.find(value, cur_start)) >= 0:
        start = cur_index
        end = cur_index + len(value)
        anns.append(
            {
                "cui": str(cui),
                "value": orig_text[start: end],
                "start": start,
                "end": end,
            }
        )
        cur_start = end
        if len(anns) > 100:
            raise KeyError(
                f"Too many annotations!, {start}, {end}, for {value}. "
                f"cur start at {cur_start}")
    return anns
__name__ == "__main__": + main(*argv[1:]) diff --git a/medcat-v2/paper/data/supervised/distemist/raw/README.md b/medcat-v2/paper/data/supervised/distemist/raw/README.md new file mode 100644 index 000000000..0001b4559 --- /dev/null +++ b/medcat-v2/paper/data/supervised/distemist/raw/README.md @@ -0,0 +1,11 @@ +First we need to download and extract the distemist dataset: +https://temu.bsc.es/distemist/distemist-linking/ + +Subsequently, we convert to MedCAT supported format: +```python +python convert_to_mct_export.py distemist_zenodo/multilingual_resources/training_text_files/en distemist_zenodo/multilingual_resources/en ../mct_export.json +``` + +NOTE: +The underlying dataset (at least in some cases) links to multiple concepts per annotation. +And because of that the output also allows a subset of concepts. diff --git a/medcat-v2/paper/data/supervised/distemist/raw/convert_to_mct_export.py b/medcat-v2/paper/data/supervised/distemist/raw/convert_to_mct_export.py new file mode 100644 index 000000000..04ebfef32 --- /dev/null +++ b/medcat-v2/paper/data/supervised/distemist/raw/convert_to_mct_export.py @@ -0,0 +1,130 @@ +import sys +import os +from typing import Iterator +from datetime import datetime +from functools import lru_cache +import json + +import pandas as pd + +from medcat.data.mctexport import ( + MedCATTrainerExport, MedCATTrainerExportDocument) +from medcat.data.mctexport import count_all_annotations, count_all_docs + + +DEFAULT_TEXT_FOLDER = ( + "distemist_zenodo/multilingual_resources/training_text_files/en") +DEFAULT_ANN_FOLDER = ( + "distemist_zenodo/multilingual_resources/en") +DEFAULT_MOD_DATE = datetime.now().isoformat() +DEFAULT_DTYPE = { + "filename": str, + "mart": str, + "label": str, + "offset1": int, + "offset2": int, + "span": str, + "code": str, +} + + +def find_text_file(folder: str, base_name: str) -> str: + path = os.path.join(folder, base_name + ".txt") + if not os.path.exists(path): + raise ValueError(f"No such file/folder: {path}") + 
# The cache must be unbounded: callers append annotations to the returned
# dict, so the same object has to come back for the same document every time.
# The default bound (128 entries) can evict a document, after which a fresh
# dict with an empty annotation list would be created for it.
@lru_cache(maxsize=None)
def get_doc(folder: str, base_name: str) -> MedCATTrainerExportDocument:
    """Return the (shared, mutable) export document for one text file.

    Args:
        folder: Folder containing the ``<base_name>.txt`` files.
        base_name: File name without the ``.txt`` extension.

    Returns:
        A document dict with an initially empty ``annotations`` list. Repeated
        calls with the same arguments return the very same dict.
    """
    # Function-scope import: zlib is not among the file's existing imports.
    import zlib

    text = find_text(folder, base_name)
    return {
        # hash(str) is salted per interpreter process, so it would produce a
        # different ID on every run; CRC32 of the name is stable.
        "id": zlib.crc32(base_name.encode("utf-8")),
        "name": base_name,
        "last_modified": DEFAULT_MOD_DATE,
        "text": text,
        "annotations": []
    }
def build_export(text_folder: str, annotation_folder: str
                 ) -> MedCATTrainerExport:
    """Build a single-project trainer export from the annotation files.

    Args:
        text_folder: Folder with the raw ``.txt`` documents.
        annotation_folder: Folder with the ``.tsv`` annotation files.

    Returns:
        An export with one project named ``distemist`` that holds every
        document yielded by ``get_docs`` exactly once.
    """
    # Function-scope import: zlib is not among the file's existing imports.
    import zlib

    docs: list[MedCATTrainerExportDocument] = []
    out: MedCATTrainerExport = {
        "projects": [
            {
                # hash(str) is salted per process and so changes between
                # runs; CRC32 gives a reproducible integer ID.
                "id": zlib.crc32(b"distemist"),
                "name": "distemist",
                "cuis": "",
                "tuis": "",
                "documents": docs
            }
        ]
    }
    # get_docs yields the same document object once per annotation row;
    # remember the objects already added instead of scanning the list.
    # The docs list keeps every added object alive, so id() stays unique.
    seen: set[int] = set()
    for cur_doc in get_docs(annotation_folder, text_folder):
        doc_key = id(cur_doc)
        if doc_key not in seen:
            seen.add(doc_key)
            docs.append(cur_doc)
    return out
b/medcat-v2/paper/data/supervised/linking_challenge/raw/convert_to_mct_export.py new file mode 120000 index 000000000..fd957f236 --- /dev/null +++ b/medcat-v2/paper/data/supervised/linking_challenge/raw/convert_to_mct_export.py @@ -0,0 +1 @@ +/Users/martratas/Documents/CogStack/.MedCAT.nosync/MedCATv2_new/.temp/validation_datasets/linking_challenge_2023/convert_to_mct_export.py \ No newline at end of file diff --git a/medcat-v2/paper/data/supervised/medmentions/raw/README.md b/medcat-v2/paper/data/supervised/medmentions/raw/README.md new file mode 100644 index 000000000..8ed76468c --- /dev/null +++ b/medcat-v2/paper/data/supervised/medmentions/raw/README.md @@ -0,0 +1,15 @@ +First we need to download the MedMentions dataset: +https://github.com/chanzuckerberg/MedMentions + +Then, we need to convert it to a format MedCAT understands: +```python +python src/medmentions_converter.py corpus_pubtator.txt medmentions_umls.json +``` + +However, this still has UMLS codes instead of SNOMED ones. +For that we also need UMLS (`MRCONSO.RRF`) to do the mapping. 
+ +To do the conversion into Snomed we do: +```python +python src/medmen_umls2snomed_converter.py medmentions_umls.json ../medmentions_snomed_stricter.json +``` \ No newline at end of file diff --git a/medcat-v2/paper/data/supervised/medmentions/raw/src/conversion_mapper.py b/medcat-v2/paper/data/supervised/medmentions/raw/src/conversion_mapper.py new file mode 100644 index 000000000..4ec76ed09 --- /dev/null +++ b/medcat-v2/paper/data/supervised/medmentions/raw/src/conversion_mapper.py @@ -0,0 +1,73 @@ +import os +import json +import pandas as pd +from medcat.model_creation.preprocess_umls import _DEFAULT_COLUMNS + + +def get_umls_df(umls_path: str) -> pd.DataFrame: + mrconso = os.path.join(umls_path, "MRCONSO.RRF") + df = pd.read_csv(mrconso, names=_DEFAULT_COLUMNS, sep="|", index_col=False) + print("INIT", len(df.index)) + df = df[df["LAT"] == "ENG"] + print("After LANG", len(df.index)) + df = df[df["SAB"].str.contains("SNOMEDCT")] + print("After SNOMED", len(df.index)) + df = df[df["SCUI"].notna()] + print("After non-none Snomed CUIs", len(df.index)) + return df + + +def load_cuis(needed_path: str) -> list[str]: + with open(needed_path) as f: + return [cui for line in f.readlines() if line for cui in line.split(",")] + + +def get_mappings(df: pd.DataFrame, umls_cuis: list[str], + status_order: list[str] = ['P', 'p', 'S', 's']) -> dict[str, str]: + print("GM") + custom_order = pd.CategoricalDtype(status_order, ordered=True) + out_dict = {} + for nr, cui in enumerate(umls_cuis): + print(nr, cui) + per_cui = df[df['CUI'] == cui] + per_cui['TS'] = per_cui['TS'].astype(custom_order) + per_cui = per_cui.sort_values('TS') + # print("PCUI", per_cui) + cui_and_status = per_cui[['CUI', 'TS']] + print("CUI and status", cui_and_status) + ordered_cuis = [row['CUI'] for _, row in cui_and_status.iterrows()] + # ordered_cuis = sorted([ + # (row['CUI'], row['TS']) for _, row in + # cui_and_status.iterrows() + # ], key=lambda cs: status_order.index(cs[1])) + # # remove duplicates 
+ # ordered_cuis = [cui for nr, cui in enumerate(ordered_cuis) if cui not in ordered_cuis[:nr]] + print(cui, "Ordered CUIs", len(ordered_cuis)) + # scuis = per_cui['SCUI'].unique().tolist() + # if nr >= 25: + # raise + if len(ordered_cuis) == 0: + continue + if len(ordered_cuis) > 1: + print(f"{cui}:", len(ordered_cuis) if len(ordered_cuis) > 10 else ordered_cuis) + print("CONTEXT:") + for nr, row in per_cui.iterrows(): + print(row) + out_dict[cui] = ordered_cuis[0] + return out_dict + + +def main(*args: str): + umls_path, needed_path, json_path = args + umls_df = get_umls_df(umls_path) + needed_umls_cuis = load_cuis(needed_path) + print("Getting mappings") + map_dict = get_mappings(umls_df, needed_umls_cuis) + print("SAVING to", json_path) + with open(json_path, 'w') as f: + json.dump(map_dict, f) + + +if __name__ == "__main__": + import sys + main(*sys.argv[1:]) diff --git a/medcat-v2/paper/data/supervised/medmentions/raw/src/medmen_umls2snomed_converter.py b/medcat-v2/paper/data/supervised/medmentions/raw/src/medmen_umls2snomed_converter.py new file mode 100644 index 000000000..c28c517d4 --- /dev/null +++ b/medcat-v2/paper/data/supervised/medmentions/raw/src/medmen_umls2snomed_converter.py @@ -0,0 +1,258 @@ + +import json +import os +from copy import deepcopy +from functools import lru_cache +from typing import Callable + +import pandas as pd + +from medcat.data.mctexport import MedCATTrainerExport, iter_anns, iter_docs +from medcat.data.mctexport import MedCATTrainerExportAnnotation +from medcat.model_creation.preprocess_umls import _DEFAULT_COLUMNS +from medcat.cdb import CDB + + +UMLS_TYPE_TO_SNOMED_TYPE = { + "T033": ("67667581", "finding"), + "T059": ("28321150", "procedure"), + "T060": ("28321150", "procedure"), + # "T061": ("28321150", "procedure"), + "T090": ("16939031", "occupation"), + "T091": ("16939031", "occupation"), + "T071": ("2680757", "observable entity"),# and T077/conceptual entity? 
+ # "" : ("40357424", "foundation metadata concept"), + # "" : ("29422548", "core metadata concept"), + "T072": ("32816260", "physical object"), + # "" : ("7882689", "qualifier value"), + "T167": ("91187746", "substance"), + # "" : ("72706784", "nan"), + "T001": ("81102976", "organism"), + "T017": ("37552161", "body structure"), # I think? + "T047": ("9090192", "disorder"), + # "" : ("33782986", "morphologic abnormality"), + # "" : ("66527446", "cell structure"), + "T061": ("47503797", "regime/therapy"), + # "" : ("91776366", "product"), + # "" : ("37785117", "medicinal product"), + "T025": ("99220404", "cell"), + # "" : ("31601201", "person"), + # "" : ("20410104", "ethnic group"), + # "" : ("75168589", "environment"), + "T051": ("33797723", "event"), + # "" : ("46922199", "religion/philosophy"), + "T201": ("43039974", "attribute"), + # "" : ("3061879", "situation"), + # "" : ("9593000", "medicinal product form"), + # "" : ("82417248", "navigational concept"), + # "" : ("43857361", "physical force"), + "T200": ("27603525", "clinical drug"), + # "" : ("13371933", "social concept"), + # "" : ("30703196", "tumor staging"), + # "" : ("337250", "specimen"), + # "" : ("8067332", "basic dose form"), + # "" : ("21114934", "dose form"), + # "" : ("55540447", "linkage concept"), + # "" : ("31685163", "staging scale"), + # "" : ("90170645", "record artifact"), + # "" : ("17030977", "assessment scale"), + # "" : ("25624495", "SNOMED RT+CTV3"), + # "" : ("18854038", "geographic location"), + # "" : ("78096516", "environment / location"), + # "" : ("92873870", "special concept"), + # "" : ("70426313", "namespace concept"), + # "" : ("14654508", "racial group"), + # "" : ("28695783", "link assertion"), + # "" : ("46506674", "disposition"), + # "" : ("39041339", "unit of presentation"), + # "" : ("51885115", "OWL metadata concept"), + # "" : ("49144999", "state of matter"), + # "" : ("66203715", "transformation"), + # "" : ("51120815", "intended site"), + # "" : ("64755083", 
"release characteristic"), + # "" : ("45958968", "administration method"), + # "" : ("87776218", "role"), + # "" : ("43744943", "supplier"), + # "" : ("95475658", "product name"), + # "" : ("40584095", "metadata"), + # "" : ("3242456", "life style"), +} + +SNOMED_TYPE_ID2NAME = { + '67667581': 'finding', '28321150': 'procedure', '16939031': 'occupation', + '2680757': 'observable entity', '40357424': 'foundation metadata concept', + '29422548': 'core metadata concept', '32816260': 'physical object', + '7882689': 'qualifier value', '91187746': 'substance', '72706784': 'nan', + '81102976': 'organism', '37552161': 'body structure', '9090192': 'disorder', + '33782986': 'morphologic abnormality', '66527446': 'cell structure', + '47503797': 'regime/therapy', '91776366': 'product', '37785117': 'medicinal product', + '99220404': 'cell', '31601201': 'person', '20410104': 'ethnic group', + '75168589': 'environment', '33797723': 'event', '46922199': 'religion/philosophy', + '43039974': 'attribute', '3061879': 'situation', '9593000': 'medicinal product form', + '82417248': 'navigational concept', '43857361': 'physical force', + '27603525': 'clinical drug', '13371933': 'social concept', '30703196': 'tumor staging', + '337250': 'specimen', '8067332': 'basic dose form', '21114934': 'dose form', + '55540447': 'linkage concept', '31685163': 'staging scale', '90170645': 'record artifact', + '17030977': 'assessment scale', '25624495': 'SNOMED RT+CTV3', + '18854038': 'geographic location', '78096516': 'environment / location', + '92873870': 'special concept', '70426313': 'namespace concept', '14654508': 'racial group', + '28695783': 'link assertion', '46506674': 'disposition', '39041339': 'unit of presentation', + '51885115': 'OWL metadata concept', '49144999': 'state of matter', '66203715': 'transformation', + '51120815': 'intended site', '64755083': 'release characteristic', + '45958968': 'administration method', '87776218': 'role', '43744943': 'supplier', + '95475658': 'product 
name', '40584095': 'metadata', '3242456': 'life style' +} + + + +def load_export(path: str) -> MedCATTrainerExport: + with open(path) as f: + return json.load(f) + + +def load_umls(umls_path: str) -> pd.DataFrame: + mrconso = os.path.join(umls_path, "MRCONSO.RRF") + df = pd.read_csv(mrconso, names=_DEFAULT_COLUMNS, sep="|", index_col=False) + print("INIT", len(df.index)) + df = df[df["LAT"] == "ENG"] + print("After LANG", len(df.index)) + df = df[df["SAB"].str.contains("SNOMEDCT")] + print("After SNOMED", len(df.index)) + df = df[df["SCUI"].notna()] + print("After removing None-CUIs", len(df.index)) + # remove column I don't care about + df = df.drop(["LAT", # language - already selected + "LUI", # unique identifier for term + "SUI", # unique identifier for string + "AUI", # Unique identifier for atom - variable length field, 8 or 9 characters + # source stuff - will get CUI from CODE + "SAUI", # Source asserted atom identifier [optional] + "SCUI", # Source asserted concept identifier [optional] + "SDUI", # Source asserted descriptor identifier [optional] + ], axis='columns') + return df + + +class TPG: + + def __init__(self, pt2ch: dict) -> None: + self.pt2ch = pt2ch + + @lru_cache + def get_root_parent(self, cui: str) -> str | None: + for pt, children in self.pt2ch.items(): + if cui in children: + rp = self.get_root_parent(pt) + if rp is None: + return cui + # if not a child of anything, must be root + return None + + +def pick_snomed_cui(ann: MedCATTrainerExportAnnotation, + umls_df: pd.DataFrame, + get_cui_name: Callable[[str], str], + tpg: TPG) -> str | None: + umls_cui = ann['cui'] + snomed_candidates = umls_df[umls_df['CUI'] == umls_cui] + num_of_candidates = len(snomed_candidates.index) + if num_of_candidates == 0: + return None + elif num_of_candidates == 1: + return snomed_candidates['CODE'].to_list()[0] + preferred = snomed_candidates[snomed_candidates["TS"] == "P"] + num_of_candidates = len(preferred) + if num_of_candidates == 1 or 
len(preferred['CODE'].unique()) == 1: + return preferred['CODE'].to_list()[0] + return None + if num_of_candidates == 0: + # check all if no preferred + print("No preferred candidates...") + preferred = snomed_candidates + cuis_and_names = [(row['CODE'], get_cui_name(row['CODE'])) + for _, row in preferred.iterrows() + if get_cui_name(row['CODE']) != row['CODE'] is not None] + cuis_with_name = set(cui for cui, _ in cuis_and_names) + if len(cuis_with_name) == 1: + return list(cuis_with_name)[0] + # find one exact match (if present) + names_of_cuis = set((name, cui) for cui, name in cuis_and_names + if name.lower() == ann['value'].lower()) + if len(names_of_cuis) == 1: + return list(names_of_cuis)[0][1] + cuitid2 = [ + (cui, name, tid, + UMLS_TYPE_TO_SNOMED_TYPE[tid], + # SNOMED_TYPE_ID2NAME[UMLS_TYPE_TO_SNOMED_TYPE[tid]] + ) + for cui, name in cuis_and_names + for tid in ann['type_ids'].split(",") + if tid in UMLS_TYPE_TO_SNOMED_TYPE] + print("CUI typeIDs") + print(cuitid2) + num_of_candidates = len(preferred) + print("Picking from", num_of_candidates, 'for', umls_cui, 'from') + print(preferred) + print("Context:", ann) + cand_cuis = [row['CODE'] for _, row in preferred.iterrows()] + for cui in cand_cuis: + root = tpg.get_root_parent(cui) # the type id + root_name = get_cui_name(root) if root else None + print("CUI2name", cui, f"({get_cui_name(cui)})", + "->", root, "->", root_name) + # import time + # time.sleep(0.2) + + +def convert(export: MedCATTrainerExport, + umls_df: pd.DataFrame, + cui2name: Callable[[str], str], + pt2ch: dict) -> MedCATTrainerExport: + export = deepcopy(export) + total_initial_anns = len(list(iter_anns(export))) + tpg = TPG(pt2ch) + for _, doc, ann in iter_anns(export): + snomed_cui = pick_snomed_cui(ann, umls_df, cui2name, tpg) + if snomed_cui: + ann['cui'] = snomed_cui + else: + ann['cui'] = None + total_kept = 0 + total_removed = 0 + for _, doc in iter_docs(export): + to_remove = [] + for nr, ann in enumerate(doc['annotations']): + if 
ann['cui'] is None: + to_remove.append(nr) + total_removed += len(to_remove) + total_kept += len(doc['annotations']) - len(to_remove) + # print("Removing", to_remove, "annotations") + for nr in to_remove[::-1]: + # start from end to avoid changing order while iterating + doc["annotations"].pop(nr) + print("Total removed", total_removed) + print("Total kept", total_kept) + print("TOTAL TOTAL", total_removed + total_kept, 'vs', total_initial_anns) + return export + + +def main(export_path: str, + cdb_path: str, + umls_path: str, + target_path: str) -> None: + print("Loading original") + export = load_export(export_path) + print("Getting CDB") + cdb = CDB.load(cdb_path) + pt2ch = cdb.addl_info['pt2ch'] + print("Loading UMLS") + umls_df = load_umls(umls_path) + print("Converting...") + converted = convert(export, umls_df, cdb.get_name, pt2ch) + with open(target_path, 'w') as f: + json.dump(converted, f) + + +if __name__ == "__main__": + import sys + main(*sys.argv[1:]) diff --git a/medcat-v2/paper/data/supervised/medmentions/raw/src/medmentions_converter.py b/medcat-v2/paper/data/supervised/medmentions/raw/src/medmentions_converter.py new file mode 100644 index 000000000..69e14429a --- /dev/null +++ b/medcat-v2/paper/data/supervised/medmentions/raw/src/medmentions_converter.py @@ -0,0 +1,86 @@ +import os +import json +from typing import Iterator + +from medcat.data.mctexport import MedCATTrainerExportDocument, MedCATTrainerExport +from medcat.data.mctexport import MedCATTrainerExportAnnotation +from datetime import datetime + + +def _unwrap_ann_line(line: str) -> MedCATTrainerExportAnnotation: + _, start, end, value, type_ids, cui = line.split("\t") + return { + "cui": cui, + "start": start, + "end": end, + "value": value, + "type_ids": type_ids,# EXTRA + } + + +def unwrap_anns(ann_lines: list[str]) -> list[MedCATTrainerExportAnnotation]: + return [ + _unwrap_ann_line(line) for line in ann_lines + ] + + +def load_medmentions(file_name: str) -> Iterator[tuple[str, str, 
def save_export(mct_export: dict, file_name: str) -> None:
    """Write *mct_export* as JSON to *file_name*, refusing to overwrite.

    Args:
        mct_export: The JSON-serialisable export.
        file_name: Target path; it must not exist yet.

    Raises:
        ValueError: If *file_name* already exists.
    """
    try:
        # Exclusive-create mode makes "check and create" one atomic step,
        # so there is no window between an existence check and the open in
        # which an existing file could be truncated.
        f = open(file_name, 'x')
    except FileExistsError as err:
        raise ValueError(f"File exists: {file_name}") from err
    with f:
        json.dump(mct_export, f)
#!/bin/bash
# Runs the performance evaluation script for the v1 and then the v2 model
# pack, each inside its own virtual environment, against every supervised
# dataset export under data/supervised/.

script_path="scripts/performance/get_performance_for_model_and_ds.py"
# The model pack locations can be overridden through the environment
# (V1_MODEL_PACK / V2_MODEL_PACK); the defaults are the original paths.
v1_model_pack="${V1_MODEL_PACK:-/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_no_mc_d84c313f24311484.zip}"
v2_model_pack="${V2_MODEL_PACK:-/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/CONVERT_2023_model_no_mc_234dda1597f635e3.zip}"
# v1_model_pack="/Users/martratas/Documents/CogStack/MedCAT/medcat-snomed-model-creation/.creation_cache/out_snomed_2025/final_model_Snomed2025-07-11_de7cbec4a786e418.zip"
# v2_model_pack="/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/2025_11_19_issue228_add_meta_cat_to_other/models/v2_Snomed2025_MIMIC_IV_bbe806e192df009f.zip"

echo "*****************"
echo "running v1 stuff"
echo "*****************"

source .venv_v1/bin/activate

# the glob is intentionally unquoted so it expands to every export
python "$script_path" "$v1_model_pack" data/supervised/*/*.json

echo "*****************"
echo "running v2 stuff"
echo "*****************"

source ../.venv312/bin/activate

python "$script_path" "$v2_model_pack" data/supervised/*/*.json
def main(model_pack_path: str,
         *export_paths: str) -> None:
    """Evaluate one model pack on each export and print a summary table.

    Args:
        model_pack_path: Path of the model pack to load.
        *export_paths: Paths of the trainer exports to evaluate on.
    """
    cat = CAT.load_model_pack(model_pack_path)
    # one row per export: (dataset name, precision, recall, F1)
    out_data: list[tuple[str, float, float, float]] = []
    for export_path in export_paths:
        print("Exploring", export_path)
        data = load_data(export_path)
        new_metrics = get_overall_prec_rec_f1(cat, data)
        # the row is labelled with the name of the export's parent folder
        dataset_name = os.path.basename(os.path.dirname(export_path))
        out_data.append((dataset_name, *new_metrics))
        print(new_metrics)
    df = pd.DataFrame(
        out_data,
        columns=["filename", "prec", "rec", "F1"]
    )
    print(df.to_string())
class StatsCalculator:
    """Calculates precision/recall statistics for entity linking.

    Counts true positives, false positives and false negatives (overall
    and per CUI) while documents are processed, and keeps example records
    for each outcome.
    """

    def __init__(self, filters: LinkingFilters, cui2info: dict[str, CUIInfo]):
        self.filters = filters
        self.cui2info = cui2info
        self._reset()

    def _reset(self):
        """Zero all counters and drop all recorded examples."""
        self.tp = self.fp = self.fn = 0
        self.cui_tp: dict[str, int] = {}
        self.cui_fp: dict[str, int] = {}
        self.cui_fn: dict[str, int] = {}
        self.examples: dict[str, dict[str, list]] = {
            'tp': {}, 'fp': {}, 'fn': {}}

    def process_document(
        self,
        doc: MedCATTrainerExportDocument,
        predictions: list[MutableEntity]
    ) -> None:
        """
        Process a single document's annotations and predictions.

        Args:
            doc: Gold-standard annotated document
            predictions: Model's predicted entities
        """
        gold_anns = self._extract_gold_annotations(doc)
        pred_anns = self._extract_predictions(predictions)

        # Track which predictions have been matched
        matched_preds: set[int] = set()

        # Phase 1: Match gold annotations to predictions (find TPs and FNs)
        for gold in gold_anns:
            match_idx = self._find_matching_prediction(
                gold, pred_anns, matched_preds)

            if match_idx is not None:
                # True Positive
                matched_preds.add(match_idx)
                self._record_tp(gold, pred_anns[match_idx])
            else:
                # False Negative
                self._record_fn(gold)

        # Phase 2: Remaining predictions are False Positives
        for idx, pred in enumerate(pred_anns):
            if idx not in matched_preds:
                # the predictions were already filtered on extraction;
                # the check is kept as a safeguard
                if self.filters.check_filters(pred["cui"]):
                    self._record_fp(pred)

    def process_project(self, project: MedCATTrainerExportProject,
                        entity_getter: Callable[[str], list[MutableEntity]],
                        use_project_filters: bool = True,
                        extra_cui_filter: set[str] | None = None,
                        show_progress: bool = True,
                        ) -> None:
        """Process every document of a project under the project filters.

        Args:
            project: The project whose documents are processed.
            entity_getter: Returns the predicted entities for a text.
            use_project_filters: Passed on to ``project_filters``.
            extra_cui_filter: Passed on to ``project_filters``.
            show_progress: Whether to show a progress bar over documents.
        """
        with project_filters(self.filters,
                             project,
                             extra_cui_filter,
                             use_project_filters):
            for doc in tqdm(project["documents"], disable=not show_progress,
                            desc="Documents"):
                self.process_document(doc, entity_getter(doc["text"]))

    def _extract_gold_annotations(
        self,
        doc: MedCATTrainerExportDocument
    ) -> list[dict]:
        """Extract validated gold annotations, supporting multi-CUI options."""
        gold_anns = []

        for ann in doc['annotations']:
            if not ann.get('validated', True):
                continue
            if ann.get('killed', False) or ann.get('deleted', False):
                continue

            # Support both single CUI and multiple acceptable CUIs
            cuis = ann.get('acceptable_cuis', ann['cui'])
            if not isinstance(cuis, list):
                cuis = [cuis]

            # Filter to valid CUIs
            valid_cuis = [
                cui for cui in cuis
                if self.filters.check_filters(cui)]

            # annotations with no CUI passing the filters are ignored
            if valid_cuis:
                gold_anns.append({
                    'start': ann['start'],
                    'end': ann['end'],
                    'cuis': valid_cuis,  # List of acceptable CUIs
                    'primary_cui': valid_cuis[0],  # For counting
                    'text': ann['value'],
                    'raw': ann
                })

        return gold_anns

    def _extract_predictions(
        self,
        predictions: list[MutableEntity]
    ) -> list[dict]:
        """Extract relevant info from predicted entities passing the filters."""
        return [{
            'start': ent.base.start_char_index,
            'end': ent.base.end_char_index,
            'cui': ent.cui,
            'text': ent.base.text,
            'confidence': float(ent.context_similarity),
            'raw': ent
        } for ent in predictions if self.filters.check_filters(ent.cui)]

    def _find_matching_prediction(
        self,
        gold: dict,
        predictions: list[dict],
        matched_preds: set[int]
    ) -> int | None:
        """
        Find a prediction that matches this gold annotation.

        Matching criteria:
        - Same start position (can be relaxed for fuzzy matching)
        - Predicted CUI is in gold's acceptable CUIs
        - Not already matched

        Returns:
            The index of the first matching prediction, or None.
        """
        for idx, pred in enumerate(predictions):
            if idx in matched_preds:
                continue

            # Only the start offset is compared (the end is not)
            if pred['start'] == gold['start']:
                # Check if predicted CUI is acceptable
                if pred['cui'] in gold['cuis']:
                    return idx

        return None

    def _record_tp(self, gold: dict, pred: dict) -> None:
        """Record a true positive (counted under the predicted CUI)."""
        cui = pred['cui']
        self.tp += 1
        self.cui_tp[cui] = self.cui_tp.get(cui, 0) + 1
        self.examples['tp'].setdefault(cui, []).append({
            'gold_text': gold['text'],
            'pred_text': pred['text'],
            'cui': cui,
            'start': pred['start'],
            'confidence': pred['confidence']
        })

    def _record_fn(self, gold: dict) -> None:
        """Record a false negative (counted under the gold primary CUI)."""
        cui = gold['primary_cui']
        self.fn += 1
        self.cui_fn[cui] = self.cui_fn.get(cui, 0) + 1
        self.examples['fn'].setdefault(cui, []).append({
            'text': gold['text'],
            'acceptable_cuis': gold['cuis'],
            'start': gold['start']
        })

    def _record_fp(self, pred: dict) -> None:
        """Record a false positive (counted under the predicted CUI)."""
        cui = pred['cui']
        self.fp += 1
        self.cui_fp[cui] = self.cui_fp.get(cui, 0) + 1
        self.examples['fp'].setdefault(cui, []).append({
            'text': pred['text'],
            'cui': cui,
            'start': pred['start'],
            'confidence': pred['confidence']
        })

    def compute_metrics(self) -> dict:
        """Compute overall and per-CUI metrics.

        Returns:
            A dict with the ``overall`` precision/recall/F1 and, under
            ``per_cui``, the name, metrics and raw counts of every CUI
            that was counted at least once.
        """
        metrics = {
            'overall': self._compute_prf(self.tp, self.fp, self.fn),
            'per_cui': {}
        }

        all_cuis = set(self.cui_tp) | set(self.cui_fp) | set(self.cui_fn)

        for cui in all_cuis:
            tp = self.cui_tp.get(cui, 0)
            fp = self.cui_fp.get(cui, 0)
            fn = self.cui_fn.get(cui, 0)

            metrics['per_cui'][cui] = {
                'name': self._get_cui_name(cui),
                **self._compute_prf(tp, fp, fn),
                'tp': tp, 'fp': fp, 'fn': fn
            }

        return metrics

    @staticmethod
    def _compute_prf(tp: int, fp: int, fn: int) -> dict:
        """Compute precision, recall, F1 (0.0 where the ratio is undefined)."""
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0

        return {'precision': prec, 'recall': rec, 'f1': f1}

    def _get_cui_name(self, cui: str) -> str:
        """Get a display name for the CUI.

        Uses the preferred name if set, otherwise the first of the names,
        and falls back to the CUI itself when neither is available.
        """
        info = self.cui2info.get(cui)
        if not info:
            return cui
        preferred_name = info.get('preferred_name')
        if preferred_name:
            return preferred_name
        # the info may have no 'names' entry at all, or an empty one
        names = info.get('names')
        if names:
            return list(names)[0]
        return cui
class RegressionOverallResults(BaseModel):
    """Overall success/failure counts of a regression suite run."""
    total_cases: int
    successful_cases: int
    failed_cases: int

    def _fraction(self, count: int) -> float:
        """Return count / total_cases, or 0.0 when there are no cases."""
        # guards against ZeroDivisionError for an empty suite
        return count / self.total_cases if self.total_cases else 0.0

    def is_valid(self,
                 success_percent: float,
                 fail_percent: float,
                 tolerance: float = 0.005) -> bool:
        """Check the counts against the expected fractions.

        Args:
            success_percent: Expected successful fraction (0..1).
            fail_percent: Expected failed fraction (0..1).
            tolerance: Maximum allowed absolute difference.

        Returns:
            True if both fractions are within the tolerance.
        """
        got_good = self._fraction(self.successful_cases)
        got_bad = self._fraction(self.failed_cases)
        return (
            abs(got_good - success_percent) < tolerance and
            abs(got_bad - fail_percent) < tolerance)

    def final_comma_sep_out(self) -> str:
        """Return 'success fraction,fail fraction,total' as one line."""
        return ",".join([str(self._fraction(self.successful_cases)),
                         str(self._fraction(self.failed_cases)),
                         str(self.total_cases)])

    @classmethod
    def from_records(cls, records: list[tuple[str, int, float]]
                     ) -> 'RegressionOverallResults':
        """Build the results from the two captured records.

        Args:
            records: Two (kind, number of cases, percentage) tuples, one
                for the successful and one for the failing cases.

        Returns:
            The overall results.

        Raises:
            ValueError: If there are not exactly two records, or if the
                counts do not agree with the reported percentages.
        """
        if len(records) != 2:
            raise ValueError(f"Unbalanced records: {records}")
        good, bad = records
        if "successful" != good[0] and "successful" in bad[0]:
            # NOTE: swapping order - shouldn't be needed though
            good, bad = bad, good
        good_cases, good_perc = good[1:]
        bad_cases, bad_perc = bad[1:]
        inst = cls(total_cases=good_cases + bad_cases,
                   successful_cases=good_cases,
                   failed_cases=bad_cases)
        if not inst.is_valid(good_perc / 100, bad_perc / 100):
            raise ValueError(
                f"Unbalanced totals:\nRecords:\n{records}"
                f"\nvs\nOutcome:\n{inst}\n"
                f"Expected: {good_perc}% S, {bad_perc}% F\n"
                f"Got: {inst._fraction(inst.successful_cases)} S, "
                f"{inst._fraction(inst.failed_cases)} F")
        return inst
+ """ + def __init__(self, *args, + pattern: re.Pattern = CASES_PATTERN, + **kwargs): + super().__init__(*args, **kwargs) + self.pattern = pattern + self.records: list[tuple[str, int, float]] = [] + + def emit(self, record: logging.LogRecord): + """ + Format the record and append the resulting string to the records list. + """ + # Ensure the record is formatted before storing it + msg = self.format(record) + for line in msg.split("\n"): + match = self.pattern.match(line) + if match: + self.records.append( + (match.group(1), int(match.group(2)), + float(match.group(3)))) + + def get_captured_records(self) -> list[str]: + """ + Returns the list of captured formatted log messages. + """ + return self.records + + def get_results(self) -> RegressionOverallResults: + return RegressionOverallResults.from_records(self.records) + + def clear(self): + """ + Clears the list of captured records. + """ + self.records.clear() + + +def main(model_pack_path: str, + regression_suite_path: str = DEFAULT_REGRESSION_SUITE): + regr_l.setLevel(logging.INFO) + handler = CapturingHandler() + regr_l.addHandler(handler) + regr_main(Path(model_pack_path), Path(regression_suite_path)) + results = handler.get_results() + print(results.final_comma_sep_out()) + + +if __name__ == "__main__": + main(*argv[1:]) diff --git a/medcat-v2/paper/scripts/performance/v1_helper.py b/medcat-v2/paper/scripts/performance/v1_helper.py new file mode 100644 index 000000000..f6c830f9b --- /dev/null +++ b/medcat-v2/paper/scripts/performance/v1_helper.py @@ -0,0 +1,114 @@ +from typing import TypedDict, Any +from contextlib import contextmanager, nullcontext + +from pydantic import BaseModel +from spacy.tokens import Span + +from medcat.cdb import CDB +from medcat.config import LinkingFilters + +from medcat.stats.mctexport import MedCATTrainerExportProject + + +class CUIInfo(TypedDict): + preferred_name: str | None + + +class _FakeDict: + + def __init__(self, cdb: CDB): + self.cdb = cdb + + def get(self, cui: str, 
def_val: Any | None = None) -> CUIInfo | None: + if cui not in self.cdb.cui2preferred_name: + return def_val + return {"preferred_name": self.cdb.cui2preferred_name[cui]} + + def __getitem__(self, cui: str) -> CUIInfo: + if cui not in self.cdb.cui2preferred_name: + raise KeyError(cui) + return {"preferred_name": self.cdb.cui2preferred_name[cui]} + + def __contains__(self, cui: str) -> bool: + return cui in self.cdb.cui2preferred_name + + +def from_cdb(cdb: CDB) -> dict[str, 'CUIInfo']: + return _FakeDict(cdb) + + +class BaseMutableEntity(BaseModel): + start_char_index: int + end_char_index: int + text: str + + +class MutableEntity(BaseModel): + base: BaseMutableEntity + cui: str + context_similarity: float + + @classmethod + def from_spacy(cls, span: Span) -> 'MutableEntity': + base = BaseMutableEntity(start_char_index=span.start_char, + end_char_index=span.end_char, + text=span.text) + return cls(base=base, + cui=span._.cui, + context_similarity=span._.context_similarity) + + @classmethod + def from_spacy_list(cls, spans: list[Span]) -> list['MutableEntity']: + return [cls.from_spacy(span) for span in spans] + + +@contextmanager +def temp_changed_config(config: BaseModel, target: str, value: Any): + """Context manager to change the config temporarily (within). + + Args: + config (BaseModel): The config in question. + target (str): The attribute name to change. + value (Any): The temporary value to use. + + Raises: + IllegalConfigPathException: If no previous value is available. 
def project_filters(filters: LinkingFilters,
                    project: MedCATTrainerExportProject,
                    extra_cui_filter: set[str] | None,
                    use_project_filters: bool):
    """Pick the context manager that applies per-project CUI filters.

    With project filters in use, the project's comma separated 'cuis'
    temporarily replace the filter's cuis; if the project has none, the
    filters are left alone. Without project filters, the extra CUI filter
    is used when given and an empty set otherwise.

    Args:
        filters (LinkingFilters): The current config.
        project (MedCATTrainerExportProject): The trainer export.
        extra_cui_filter (Optional[set[str]]): Extra cui filters.
        use_project_filters (bool): Whether to use project filters.
    """
    if use_project_filters:
        project_cuis = project.get('cuis', None)
        if not project_cuis:
            # nothing to filter by - keep the filters as they are
            return nullcontext()
        return temp_changed_config(
            filters, 'cuis', set(project_cuis.split(",")))
    if extra_cui_filter is not None:
        return temp_changed_config(filters, 'cuis', extra_cui_filter)
    return temp_changed_config(filters, 'cuis', set())
def perform_work(setup: list[str],
                 worker: list[str],
                 warmup: int,
                 startup: bool,
                 verbose: bool,
                 profiling: bool,
                 lines_in_profile: int,
                 ) -> float:
    """Time (and optionally profile) the worker code once.

    NOTE: the setup and worker lines are run with exec / timeit, so they
          must only ever come from a trusted source.

    Args:
        setup: Source lines that prepare the state the worker needs.
        worker: Source lines of the work being timed.
        warmup: Number of untimed setup + worker rounds to run first.
        startup: Whether to time the setup as part of the work.
        verbose: Whether to log at debug level (otherwise critical only).
        profiling: Whether to profile the worker while timing it.
        lines_in_profile: Number of lines to show in the profile.

    Returns:
        The time the (single) timed run took, in seconds.

    Raises:
        ValueError: If both warmup rounds and startup timing are asked for.
    """
    # avoid stacking a duplicate handler on repeated calls
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    if verbose:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.CRITICAL)
    # NOTE: to make sure all the imports are done and so on
    if warmup > 0 and startup:
        raise ValueError("Timing warmed up from startup doesn't make sense")
    # do warmup if needed
    for cur_warmup in range(warmup):
        logger.info("Doing warmup step %d", cur_warmup)
        # each warmup round gets its own throwaway namespace
        exec("\n".join(setup + worker), {})
    if startup:
        logger.warning("For startup, will include warmup in timed work")
        worker = setup + worker
        setup = []
    if profiling:
        # NOTE: do it manually so I can profile only the worker part
        # The setup and the worker must share one explicit namespace:
        # with implicit namespaces inside a function, names created by
        # the first exec are not visible to the second one on
        # Python 3.13+.
        namespace: dict = {}
        exec("\n".join(setup), namespace)
        start_time = time.perf_counter()
        with show_profile(
                do_profiling=True,
                lines_in_profile=lines_in_profile):
            exec("\n".join(worker), namespace)
        took_time = time.perf_counter() - start_time
    else:
        took_time = timeit.repeat(
            "\n".join(worker),
            setup="\n".join(setup),
            repeat=1, number=1
        )[0]
    logger.info("Took a total of %ss", took_time)
    # NOTE: no units for easy automation
    return took_time
def main():
    """Time inference of a model pack over the texts of a CSV file.

    Parses the command line, runs the work through ``perform_work`` and
    prints the time taken as the last line of the output.

    Returns:
        The time taken, as returned by ``perform_work``.
    """
    parser = argparse.ArgumentParser(
        "get_inference_speed.py"
    )
    parser.add_argument("model_pack_path",
                        help="The path to the model pack",
                        type=str)
    parser.add_argument("csv_path",
                        help="Path to the csv with (at least a) 'text' field",
                        type=str)
    parser.add_argument("--verbose", "-v",
                        help="Whether to run in verbose mode",
                        action="store_true")
    parser.add_argument("--do-profiling", "-p",
                        help="Whether to run profiling on top of just timing",
                        action="store_true")
    parser.add_argument("--num-in-profile", "--np",
                        help="The number of lines in the profile.",
                        type=int, default=20)
    parser.add_argument("--startup", "-s",
                        help="Whether to use the startup as the start time. "
                        "This is useful when trying to include import times "
                        "as well - i.e real user experience",
                        action="store_true")
    parser.add_argument("--warmup", "-w",
                        help="The number of warmup rounds",
                        type=int, default=1)
    args = parser.parse_args()
    took_time = perform_work(
        setup=["from medcat.cat import CAT",
               "import pandas as pd",
               # NOTE: !r embeds the paths as valid string literals even
               #       if they contain quotes or backslashes
               f"cat = CAT.load_model_pack({args.model_pack_path!r})",
               # NOTE: this reset subnames - it is only required for models saved
               #       in v2 pre-beta releases
               "cat.cdb.has_subname('abc')" if mct_ver.startswith("2") else "",
               f"df = pd.read_csv({args.csv_path!r})"],
        worker=["for text in df.text:",
                "    cat.get_entities(text)"],
        warmup=args.warmup,
        startup=args.startup,
        verbose=args.verbose,
        profiling=args.do_profiling,
        lines_in_profile=args.num_in_profile
    )
    # the time is printed last so it can be parsed from the output
    print(took_time)
    return took_time
#!/bin/bash
# Runs get_inference_speed_all.py once per (name, model path, CSV path)
# triple and saves each result to "<save prefix>_<name>.json".

SAVE_PREFIX=$1
shift 1

# --- Input Validation ---
if (( $# == 0 )); then
    echo "Usage: $0 SAVE_PREFIX NAME MODEL_PATH CSV_PATH [NAME MODEL_PATH CSV_PATH ...]" >&2
    exit 1
fi

if (( $# % 3 != 0 )); then
    echo "Error: Arguments must be provided in triples (name, model path, and CSV path)." >&2
    exit 1
fi

echo "Starting triplet argument processing..."
echo "-----------------------------------------"

# The 'while' loop continues as long as there are arguments left ($# is non-zero)
while (( "$#" )); do
    MODEL_NAME="$1"
    MODEL_PATH="$2"
    CSV_PATH="$3"

    echo "Model: '$MODEL_NAME' with CSV '$CSV_PATH'"

    SAVE_PATH="${SAVE_PREFIX}_${MODEL_NAME}.json"
    echo "Will save to $SAVE_PATH"

    # The command is kept in an array so paths with spaces stay intact
    CMD=(python scripts/speed/get_inference_speed_all.py "$MODEL_PATH" "$CSV_PATH" --save-json "$SAVE_PATH")
    echo "Running: ${CMD[*]}"
    "${CMD[@]}"

    echo "---"

    # Shift discards the first N arguments.
    # We discard the three arguments we just processed ($1, $2, and $3)
    shift 3
done

echo "-----------------------------------------"
echo "Processing complete."
# Times inference for the two v1 model packs (without and with MetaCAT).
# An optional first argument replaces the default output prefix.
csv_path="data/unsupervised/mimic_iv_discharge_head20.csv"
ner1="2023_NER_no_MetaCAT"
ner2="2023_NER_w_MetaCAT"
ner_model_path_no_mc="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_no_mc_d84c313f24311484.zip"
ner_model_path_w_mc="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_494c3717f637bb89.zip"

# default prefix unless a non-empty first argument is given
out_prefix="out/inference_speed/v1"
if [[ -n "$1" ]]; then
    out_prefix=$1
    echo "Overwriting out prefix with: '"$1"'"
fi

bash scripts/speed/get_inference_speed_for_multiple.sh "$out_prefix" \
    "$ner1" "$ner_model_path_no_mc" "$csv_path" \
    "$ner2" "$ner_model_path_w_mc" "$csv_path"
def main():
    """Time how long it takes to load a MedCAT model pack.

    Parses the command line, hands a ``CAT.load_model_pack(...)`` snippet to
    ``perform_work`` and prints whatever ``perform_work`` returns.

    Returns:
        The value returned by ``perform_work``.
    """
    parser = argparse.ArgumentParser(
        "get_load_speed.py"
    )
    parser.add_argument("model_pack_path",
                        help="model_pack_path",
                        type=str)
    parser.add_argument("--verbose", "-v",
                        help="Whether to run in verbose mode",
                        action="store_true")
    parser.add_argument("--do-profiling", "-p",
                        help="Whether to run profiling on top of just timing",
                        action="store_true")
    parser.add_argument("--num-in-profile", "--np",
                        help="The number of lines in the profile.",
                        type=int, default=20)
    parser.add_argument("--startup", "-s",
                        help="Whether to use the startup as the start time. "
                        "This is useful when trying to include import times "
                        "as well - i.e real user experience",
                        action="store_true")
    parser.add_argument("--warmup", "-w",
                        help="The number of warmup rounds",
                        type=int, default=1)
    args = parser.parse_args()
    took_time = perform_work(
        setup=["from medcat.cat import CAT",],
        # NOTE: ``!r`` emits a valid Python string literal, so paths with
        #       quotes or backslashes do not corrupt the generated code
        worker=[f"CAT.load_model_pack({args.model_pack_path!r})"],
        warmup=args.warmup,
        startup=args.startup,
        verbose=args.verbose,
        profiling=args.do_profiling,
        lines_in_profile=args.num_in_profile
    )
    print(took_time)
    return took_time
#!/bin/bash
# Run the model-load-speed benchmark for several models in one go.
#
# Usage:
#   get_load_speed_for_multiple.sh SAVE_PREFIX NAME MODEL_PATH [NAME MODEL_PATH ...]
#
# For every (NAME, MODEL_PATH) pair the results are written to
# "${SAVE_PREFIX}_${NAME}.json".

SAVE_PREFIX="${1:-}"
# Only shift when there is something to shift away.
if (( $# > 0 )); then
    shift
fi

# --- Input Validation ---
if (( $# == 0 )); then
    echo "Usage: $0 SAVE_PREFIX NAME MODEL_PATH [NAME MODEL_PATH ...]"
    exit 0
fi

if (( $# % 2 != 0 )); then
    echo "Error: Arguments must be provided in pairs (name and path)." >&2
    exit 1
fi

echo "Starting pairwise argument processing..."
echo "-----------------------------------------"

# Remember whether any benchmark failed so the exit code reflects it,
# while still running the remaining models.
STATUS=0

# The 'while' loop continues as long as there are arguments left ($# is non-zero)
while (( $# )); do
    MODEL_NAME="$1"
    MODEL_PATH="$2"

    echo "Model: '$MODEL_NAME'"

    SAVE_PATH="${SAVE_PREFIX}_${MODEL_NAME}.json"
    echo "Will save to $SAVE_PATH"

    # Quote the paths so that spaces or glob characters in them survive.
    if ! python scripts/speed/get_load_speed_all.py "$MODEL_PATH" --save-json "$SAVE_PATH"; then
        echo "Error: benchmark failed for model '$MODEL_NAME'" >&2
        STATUS=1
    fi

    echo "---"

    # Shift discards the first N arguments.
    # We discard the two arguments we just processed ($1 and $2)
    shift 2
done

echo "-----------------------------------------"
echo "Processing complete."
exit "$STATUS"
\ No newline at end of file diff --git a/medcat-v2/paper/scripts/speed/get_load_speed_for_multiple_v1.sh b/medcat-v2/paper/scripts/speed/get_load_speed_for_multiple_v1.sh new file mode 100644 index 000000000..21ca257bc --- /dev/null +++ b/medcat-v2/paper/scripts/speed/get_load_speed_for_multiple_v1.sh @@ -0,0 +1,15 @@ +ner1="2023_NER_no_MetaCAT" +ner_model_path_no_mc="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_no_mc_d84c313f24311484.zip" +ner2="2023_NER_w_MetaCAT" +ner_model_path_w_mc="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_494c3717f637bb89.zip" +deid="n2c2_DeID" +deid_model_path="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/deid_medcat_n2c2_modelpack.zip" + +out_prefix="out/load_speed/v1" +if [[ ! -z "$1" ]] + then + out_prefix=$1 + echo "Overwriting out prefix with: "$1 +fi + +bash scripts/speed/get_load_speed_for_multiple.sh $out_prefix "$ner1" "$ner_model_path_no_mc" "$ner2" "$ner_model_path_w_mc" "$deid" "$deid_model_path" diff --git a/medcat-v2/paper/scripts/speed/get_load_speed_for_multiple_v2.sh b/medcat-v2/paper/scripts/speed/get_load_speed_for_multiple_v2.sh new file mode 100644 index 000000000..e1e22a857 --- /dev/null +++ b/medcat-v2/paper/scripts/speed/get_load_speed_for_multiple_v2.sh @@ -0,0 +1,16 @@ +ner1="2023_NER_no_MetaCAT" +ner_model_path_no_mc="/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/CONVERT_2023_model_no_mc_234dda1597f635e3.zip" +ner2="2023_NER_w_MetaCAT" +ner_model_path_w_mc="/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/20230227__kch_gstt_trained_model_f76d2121b77c3e9a.zip" +deid="n2c2_DeID" +deid_model_path="/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/CONVERT_deid_model_af31d2a9c5ccbe4d.zip.zip" + +out_prefix="out/load_speed/v2" +if [[ ! 
def main():
    """Time unsupervised training of a model pack on the texts in a CSV.

    The training call that gets timed depends on the installed MedCAT major
    version (``mct_ver``): ``cat.train`` for 1.x and
    ``cat.trainer.train_unsupervised`` for 2.x.

    Returns:
        The value returned by ``perform_work``.

    Raises:
        ValueError: If ``mct_ver`` is neither a 1.x nor a 2.x version.
    """
    parser = argparse.ArgumentParser(
        "get_unsup_train_speed.py"
    )
    parser.add_argument("model_pack_path",
                        help="The path to the model pack",
                        type=str)
    parser.add_argument("csv_path",
                        help="Path to the csv with (at least a) 'text' field",
                        type=str)
    parser.add_argument("--verbose", "-v",
                        help="Whether to run in verbose mode",
                        action="store_true")
    parser.add_argument("--do-profiling", "-p",
                        help="Whether to run profiling on top of just timing",
                        action="store_true")
    parser.add_argument("--num-in-profile", "--np",
                        help="The number of lines in the profile.",
                        type=int, default=20)
    parser.add_argument("--startup", "-s",
                        help="Whether to use the startup as the start time. "
                        "This is useful when trying to include import times "
                        "as well - i.e real user experience",
                        action="store_true")
    parser.add_argument("--warmup", "-w",
                        help="The number of warmup rounds",
                        type=int, default=1)
    args = parser.parse_args()
    if mct_ver.startswith("1."):
        work_string = "cat.train(df.text)"
    elif mct_ver.startswith("2."):
        work_string = "cat.trainer.train_unsupervised(df.text)"
    else:
        # fail with a clear message instead of an unbound local further down
        raise ValueError(
            f"Unsupported MedCAT version for this benchmark: {mct_ver}")
    took_time = perform_work(
        setup=["from medcat.cat import CAT",
               "import pandas as pd",
               # NOTE: ``!r`` emits a valid Python string literal even for
               #       paths that contain quotes or backslashes
               f"cat = CAT.load_model_pack({args.model_pack_path!r})",
               # NOTE: this reset subnames - it is only required for models saved
               #       in v2 pre-beta releases
               "cat.cdb.has_subname('abc')" if mct_ver.startswith("2") else "",
               f"df = pd.read_csv({args.csv_path!r})"],
        worker=[work_string],
        warmup=args.warmup,
        startup=args.startup,
        verbose=args.verbose,
        profiling=args.do_profiling,
        lines_in_profile=args.num_in_profile
    )
    print(took_time)
    return took_time
#!/bin/bash
# Run the unsupervised-training-speed benchmark for several models in one go.
#
# Usage:
#   get_unsup_train_speed_for_multiple.sh SAVE_PREFIX NAME MODEL_PATH CSV_PATH [NAME MODEL_PATH CSV_PATH ...]
#
# For every (NAME, MODEL_PATH, CSV_PATH) triple the results are written to
# "${SAVE_PREFIX}_${NAME}.json".

SAVE_PREFIX="${1:-}"
# Only shift when there is something to shift away.
if (( $# > 0 )); then
    shift
fi

# --- Input Validation ---
if (( $# == 0 )); then
    echo "Usage: $0 SAVE_PREFIX NAME MODEL_PATH CSV_PATH [NAME MODEL_PATH CSV_PATH ...]"
    exit 0
fi

if (( $# % 3 != 0 )); then
    echo "Error: Arguments must be provided in triples (name, model path, and CSV path)." >&2
    exit 1
fi

echo "Starting triplet argument processing..."
echo "-----------------------------------------"

# Remember whether any benchmark failed so the exit code reflects it,
# while still running the remaining models.
STATUS=0

# The 'while' loop continues as long as there are arguments left ($# is non-zero)
while (( $# )); do
    MODEL_NAME="$1"
    MODEL_PATH="$2"
    CSV_PATH="$3"

    echo "Model: '$MODEL_NAME' with CSV '$CSV_PATH'"

    SAVE_PATH="${SAVE_PREFIX}_${MODEL_NAME}.json"
    echo "Will save to $SAVE_PATH"

    # Build the command as an array so that paths containing spaces or glob
    # characters are passed through as single arguments.
    CMD=(python scripts/speed/get_unsup_train_speed_all.py
         "$MODEL_PATH" "$CSV_PATH" --save-json "$SAVE_PATH")
    echo "Running: ${CMD[*]}"
    if ! "${CMD[@]}"; then
        echo "Error: benchmark failed for model '$MODEL_NAME'" >&2
        STATUS=1
    fi

    echo "---"

    # Shift discards the first N arguments.
    # We discard the three arguments we just processed ($1, $2, and $3)
    shift 3
done

echo "-----------------------------------------"
echo "Processing complete."
exit "$STATUS"
\ No newline at end of file diff --git a/medcat-v2/paper/scripts/speed/get_unsup_train_speed_for_multiple_v1.sh b/medcat-v2/paper/scripts/speed/get_unsup_train_speed_for_multiple_v1.sh new file mode 100644 index 000000000..8dd0a480f --- /dev/null +++ b/medcat-v2/paper/scripts/speed/get_unsup_train_speed_for_multiple_v1.sh @@ -0,0 +1,14 @@ +ner1="2023_NER_no_MetaCAT" +ner_model_path_no_mc="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_no_mc_d84c313f24311484.zip" +ner2="2023_NER_w_MetaCAT" +ner_model_path_w_mc="/Users/martratas/Documents/CogStack/MedCAT/MedCAT/models/20230227__kch_gstt_trained_model_494c3717f637bb89.zip" +csv_path="data/unsupervised/mimic_iv_discharge_head20.csv" + +out_prefix="out/unsup_train_speed/v1" +if [[ ! -z "$1" ]] + then + out_prefix=$1 + echo "Overwriting out prefix with: '"$1"'" +fi + +bash scripts/speed/get_unsup_train_speed_for_multiple.sh "$out_prefix" "$ner1" "$ner_model_path_no_mc" "$csv_path" "$ner2" "$ner_model_path_w_mc" "$csv_path" diff --git a/medcat-v2/paper/scripts/speed/get_unsup_train_speed_for_multiple_v2.sh b/medcat-v2/paper/scripts/speed/get_unsup_train_speed_for_multiple_v2.sh new file mode 100644 index 000000000..0eb6e083b --- /dev/null +++ b/medcat-v2/paper/scripts/speed/get_unsup_train_speed_for_multiple_v2.sh @@ -0,0 +1,14 @@ +ner1="2023_NER_no_MetaCAT" +ner_model_path_no_mc="/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/CONVERT_2023_model_no_mc_234dda1597f635e3.zip" +ner2="2023_NER_w_MetaCAT" +ner_model_path_w_mc="/Users/martratas/Documents/CogStack/MedCAT/monorepo-nlp/medcat-v2/.temp/20230227__kch_gstt_trained_model_f76d2121b77c3e9a.zip" +csv_path="data/unsupervised/mimic_iv_discharge_head20.csv" + +out_prefix="out/unsup_train_speed/v2" +if [[ ! 
# File names look like "v1_<model name>.json"; the dot is escaped and the
# pattern anchored so that only real ".json" files are accepted.
VERSION_MODEL_PATTERN = re.compile(r"(v\d)_(.*)\.json$")
# Result folders look like "<check type>_speed".
FOLDER_NAME_PATTERN = re.compile(r"(.*)_speed")
# Column names of the summary table, in the order the values are produced.
DEFAULT_HEADER = (
    "Check Type", "Version", "Model", "Warm status",
    "Mean time", "# of repeats")


def extract_test_version_and_model(path: str) -> tuple[str, str, str]:
    """Split a result path into check type, version and model name.

    Args:
        path: Path of the form ``.../<check>_speed/v<N>_<model>.json``.

    Returns:
        The tuple ``(check type, version, model name)``.

    Raises:
        ValueError: If the folder or the file name does not follow the
            expected naming scheme.
    """
    dirname = os.path.basename(os.path.dirname(path))
    fnmatch = FOLDER_NAME_PATTERN.match(dirname)
    if not fnmatch:
        raise ValueError(f"Folder name unrecognised: {dirname}")
    basename = os.path.basename(path)
    match = VERSION_MODEL_PATTERN.match(basename)
    if not match:
        raise ValueError(f"Basename did not match: {basename}")
    return fnmatch.group(1), match.group(1), match.group(2)


def gather_data(json_paths: list[str],
                header=DEFAULT_HEADER
                ) -> pd.DataFrame:
    """Collect the per-run-type mean timings of several result files.

    Each JSON file maps a run type name to a dict holding at least ``mean``
    and ``all_times``; one output row is produced per run type.

    Args:
        json_paths: Paths of the result files to summarise.
        header: The six column names to use, in the order check type,
            version, model, warm status, mean time and number of repeats.

    Returns:
        A data frame sorted by check type, model and warm status with a
        fresh 0..n-1 index. It is empty (but has the header columns) when
        no paths are given.

    Raises:
        ValueError: If a path does not follow the expected naming scheme.
    """
    header = list(header)
    if not json_paths:
        # ``pd.concat`` refuses an empty list, so short-circuit here
        return pd.DataFrame(columns=header)
    dfs: list[pd.DataFrame] = []
    for path in json_paths:
        speed_type, version, model = extract_test_version_and_model(path)
        with open(path, encoding="utf-8") as f:
            cur_data = json.load(f)
        run_types = list(cur_data)
        mean = [cur_data[rt]['mean'] for rt in run_types]
        experiments = [len(cur_data[rt]['all_times']) for rt in run_types]
        # the scalar values are broadcast over the per-run-type lists
        vals = [speed_type, version, model, run_types, mean, experiments]
        dfs.append(pd.DataFrame(dict(zip(header, vals))))
    df = pd.concat(dfs)
    # sort on check type, model and warm status, whatever they are called
    df = df.sort_values(by=[header[0], header[2], header[3]])
    # drop the per-file positional index rather than keep it as a column
    return df.reset_index(drop=True)


def main(*file_paths: str):
    """Print the summary table for the given result files."""
    df = gather_data(list(file_paths))
    print(df.to_string())
class LinkerType(Enum):
    """The linker implementations that can be selected for a run.

    ``VECTOR_CONTEXT`` has the same value as ``DEFAULT`` and is therefore an
    alias of it rather than a separate member.
    """
    DEFAULT = 0
    VECTOR_CONTEXT = 0
    FASTER = 1
    EMBEDDING = 2

    @classmethod
    def get_type(cls, linker: str) -> 'LinkerType':
        """Resolve a (case-insensitive) linker name or nickname.

        Args:
            linker: A member name (e.g. ``"default"``, ``"FASTER"``) or one
                of the accepted nicknames (``"new"``, ``"old"``,
                anything containing ``"embed"``, ...).

        Returns:
            The matching linker type.

        Raises:
            ValueError: If the name is not recognised.
        """
        name = linker.upper()
        # NOTE: look the name up in ``__members__`` (which also lists
        #       aliases); ``name in cls`` would compare against member
        #       *values* and never match a name
        if name in cls.__members__:
            return cls[name]
        lowered = linker.lower()
        if lowered in ("new", "faster"):
            return cls.FASTER
        if "embed" in lowered:
            return cls.EMBEDDING
        if lowered in ("normal", "def", "regular", "reg", "old"):
            return cls.DEFAULT
        raise ValueError(f"Unknown linker type: '{linker}'")

    def set_linker(self, cat: 'CAT') -> None:
        """Switch the linking component of ``cat`` to this linker type.

        The pipe of ``cat`` is recreated so that the change takes effect.
        For the embedding linker the embeddings are created as well.

        Args:
            cat: The model pack to reconfigure (changed in place).

        Raises:
            ValueError: If no setup is defined for this linker type.
        """
        cmp_cnf = cat.config.components
        if self is LinkerType.DEFAULT:
            # make change just in case this is a re-run / subsequent change
            cmp_cnf.linking.comp_name = "default"
        elif self is LinkerType.FASTER:
            cmp_cnf.linking.comp_name = "primary_name_only_linker"
        elif self is LinkerType.EMBEDDING:
            from medcat.config.config import EmbeddingLinking
            from medcat.components.linking.embedding_linker import (
                Linker as ELinker)
            cmp_cnf.linking = EmbeddingLinking()
            # NOTE: should fix on the lib side
            cmp_cnf.linking.comp_name = "medcat2_embedding_linker"
            # need to recreate and create embeddings
            cat._recreate_pipe()
            linker: ELinker = cat.pipe.get_component(CoreComponentType.linking)
            print("Creating embeddings...")
            linker.create_embeddings()
            # NOTE: returning without another pipe recreation
            return
        else:
            raise ValueError(f"Not defined for linker: {self}")
        cat._recreate_pipe()
def _main(
        linker_type_str: str = USE_LINKER,
        regex_tokenizer_raw: bool | str = USE_REGEX_TOKENIZER,
        model_path: str = EXAMPLE_MODEL_PATH,
        data_path: str = EXAMPLE_DATASET):
    """Evaluate a model with the chosen linker and tokenizer on a dataset.

    Progress, timing and throughput are reported through ``print``; the
    "Linker:", "Tokenizer:", "Took" and "Throughput rate" lines keep their
    exact wording because they are meant to be parsed from the output.

    Args:
        linker_type_str: Name of the linker, resolved via
            ``LinkerType.get_type``.
        regex_tokenizer_raw: Whether to use the regex tokenizer. A string
            counts as true when it is "regex", "yes" or "true"
            (case-insensitive).
        model_path: Path of the model pack to load.
        data_path: Path of the trainer export (JSON) to evaluate on.
    """
    linker_type = LinkerType.get_type(linker_type_str)
    if isinstance(regex_tokenizer_raw, str):
        regex_tokenizer = regex_tokenizer_raw.lower() in (
            "regex", "yes", "true")
    else:
        regex_tokenizer = regex_tokenizer_raw
    print(f"Setup:\n Linker:{linker_type.name}"
          f"\n Tokenizer:{'regex' if regex_tokenizer else 'spacy'}")
    print("Loading model", model_path, "...")
    cat = CAT.load_model_pack(model_path)
    # NOTE: prep subnames
    cat.cdb.has_subname("")
    if linker_type != LinkerType.DEFAULT:
        print("USING non-default LINKER", linker_type)
        linker_type.set_linker(cat)
    else:
        print("Using DEFAULT linker...")
    if regex_tokenizer:
        print("USING REGEX BASED TOKENIZER")
        cat.config.general.nlp.provider = "regex"
        cat._recreate_pipe()
    else:
        print("Using regular (spacy) tokenizer")
    print("Loading data", data_path)
    with open(data_path) as f:
        data = json.load(f)
    print("setting up CUI filter")
    setup_cui_filter(data)
    print("Running metrics...")
    start = time.perf_counter()
    if DO_PROFILING:
        print("PROFILING")
        profile = Profile()
        profile.enable()
    fps, _, tps, *_ = get_stats(cat, data, use_project_filters=True)
    if DO_PROFILING:
        profile.disable()
    end = time.perf_counter()
    time_taken = end - start
    print("Took", time_taken)
    ents_found = sum(fps.values()) + sum(tps.values())
    print("Throughput rate", ents_found / time_taken)
    if DO_PROFILING:
        # NOTE: ``print_stats`` writes the table itself and returns the
        #       Stats object, so its result must not be printed again
        print("Profile stats (CUMtime)")
        Stats(profile).sort_stats("cumtime").print_stats(50)
        print("Profile stats (TOTtime)")
        Stats(profile).sort_stats("tottime").print_stats(50)
$SCRIPT old spacy $MODEL_PATH $DATASET $EXTRA +#"With faster linker" +python $SCRIPT new spacy $MODEL_PATH $DATASET $EXTRA +#"With regex tokenizer" +python $SCRIPT old regex $MODEL_PATH $DATASET $EXTRA +# "With regex tokenizer AND faster linker" +python $SCRIPT new regex $MODEL_PATH $DATASET $EXTRA + +# with embedding linker +# convert embedding model once +EMBED_MODEL_PATH=`python scripts/variance/convert_to_embed_linker.py $MODEL_PATH data/embed_model_converted | tail -n 1` + +# "With spacy tokenizer + embed lnker" +python $SCRIPT embed spacy $EMBED_MODEL_PATH $DATASET $EXTRA +# "With regex tokenizer + embed linker" +python $SCRIPT embed regex $EMBED_MODEL_PATH $DATASET $EXTRA + +# other dataset +# "==Linking Challenge==" +DATASET="data/supervised/linking_challenge/mct_export.json" + +# "NORMAL" +python $SCRIPT old spacy $MODEL_PATH $DATASET $EXTRA +# "With faster linker" +python $SCRIPT new spacy $MODEL_PATH $DATASET $EXTRA +# "With regex tokenizer" +python $SCRIPT old regex $MODEL_PATH $DATASET $EXTRA +# "With regex tokenizer AND faster linker" +python $SCRIPT new regex $MODEL_PATH $DATASET $EXTRA + +# with embedding linker + +# "With spacy tokenizer + embed lnker" +python $SCRIPT embed spacy $EMBED_MODEL_PATH $DATASET $EXTRA +# "With regex tokenizer + embed linker" +python $SCRIPT embed regex $EMBED_MODEL_PATH $DATASET $EXTRA diff --git a/medcat-v2/paper/scripts/variance/plot_variance.py b/medcat-v2/paper/scripts/variance/plot_variance.py new file mode 100644 index 000000000..7d8e2cbc4 --- /dev/null +++ b/medcat-v2/paper/scripts/variance/plot_variance.py @@ -0,0 +1,146 @@ +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +# import numpy as np +import pandas as pd +import seaborn as sns + + +data_lc = { + 'Config': [ + 'Spacy / Vector Context Linker', 'Spacy / Faster Linker', 'Spacy / Embedding Linker', + 'Regex / Vector Context Linker', 'Regex / Faster Linker', 'Regex / Embedding Linker',], + 'Runtime': [68.16, 51.64, 321.37, 30.54, 6.21, 
348.79], + 'F1': [0.6072, 0.5804, 0.5932, 0.5693, 0.5681, 0.5852] +} +data_cometa = { + 'Config': [ + 'Spacy / Vector Context Linker', 'Spacy / Faster Linker', 'Spacy / Embedding Linker', + 'Regex / Vector Context Linker', 'Regex / Faster Linker', 'Regex / Embedding Linker',], + 'Runtime': [75.40, 48.05, 511.19, 117.55, 82.61, 248.02], + 'F1': [0.4112, 0.3871, 0.4215, 0.3722, 0.3664, 0.3847] +} + + +def draw(is_lc: bool): +# 1. Set style for high-visibility poster presentation + sns.set_theme(style="whitegrid") + plt.rcParams.update({ + 'font.size': 14, + 'axes.labelsize': 16, + 'axes.titlesize': 18, + 'xtick.labelsize': 14, + 'ytick.labelsize': 14, + 'legend.fontsize': 10, + 'figure.titlesize': 20 + }) + if is_lc: + data = data_lc + else: + data = data_cometa + df = pd.DataFrame(data) + + # Sort by runtime to properly compute the trade-off frontier + df = df.sort_values(by='Runtime').reset_index(drop=True) + + # Identify different types + to_marker_map = { + "Spacy": "*", + "Regex": 'o', + } + df['marker'] = '' + for part, marker in to_marker_map.items(): + df.loc[df['Config'].str.contains(part), 'marker'] = marker + to_colour_map = { + "Vector Context Linker": "blue", + "Faster Linker": "red", + "Embedding Linker": "green", + } + df['colour'] = '' + for part, colour in to_colour_map.items(): + df.loc[df['Config'].str.contains(part), 'colour'] = colour + + fig, ax = plt.subplots(figsize=(7, 4)) + + for marker_type in set(to_marker_map.values()): + for cur_colour in set(to_colour_map.values()): + cur_df = df[(df['colour'] == cur_colour) & (df['marker'] == marker_type)] + ax.scatter(cur_df['Runtime'], cur_df['F1'], + marker=marker_type, + color=cur_colour, s=120, alpha=0.6, + edgecolor='k', zorder=3) + + for i, row in df.iterrows(): + # Subtle offsets to prevent text overlaying directly on top of the dots + xytext = (12, -5) + if is_lc: + if "Regex" in row['Config'] and "Faster Linker" in row['Config']: + xytext = (12, -15) + if "Embedding" in row['Config']: + 
xytext = (-150, -5) + if "Regex" in row['Config']: + xytext = (-140, -15) + + ax.annotate( + row['Config'], + xy=(row['Runtime'], row['F1']), + xytext=xytext, + textcoords='offset points', + fontsize=12, + color='black', + ) + + # 6. Formatting Axis and Labels + ax.set_xlabel('Runtime (seconds)', labelpad=10, weight='bold') + ax.set_ylabel('$F_1$ Score', labelpad=10, weight='bold') + ax.set_title( + f'Speed vs performance for {"Linking Challenge" if is_lc else "COMETA"}', + pad=15, weight='bold') + + # Adjust limits appropriately to give padding for text labels + ax.set_ylim(df['F1'].min() - 0.005, df['F1'].max() + 0.005) + ax.set_xlim(0, df['Runtime'].max() + 50) + + # Build custom handles for the Marker legend + marker_handles = [ + Line2D([0], [0], marker=m, color='gray', linestyle='None', markersize=10, label=label) + for label, m in to_marker_map.items() + ] + + # Build custom handles for the Color legend + # (We use a generic square 's' or circle 'o' marker just to showcase the color) + color_handles = [ + Line2D([0], [0], marker='s', color='None', markerfacecolor=c, markeredgecolor=c, markersize=10, label=label) + for label, c in to_colour_map.items() + ] + + # Create and add the legends to the axis + # First legend (Markers) - placed normally + if is_lc: + leg1 = ax.legend(handles=marker_handles, title="Shape Meaning", loc='upper right', frameon=True) + else: + leg1 = ax.legend(handles=marker_handles, title="Shape Meaning", loc='center right', frameon=True) + + # Second legend (Colors) - added manually so it doesn't overwrite the first one + leg2 = ax.legend(handles=color_handles, title="Color Meaning", loc='lower right', frameon=True) + ax.add_artist(leg1) # CRITICAL: This prevents leg2 from deleting leg1 + + + plt.tight_layout() + if is_lc: + plt.savefig('tradeoff_lc.png', dpi=300, transparent=True) + else: + plt.savefig('tradeoff_cometa.png', dpi=300, transparent=True) + + +if __name__ == "__main__": + from sys import argv + if len(argv) < 2: + 
print("Assuming linking challenge") + is_lc = True + else: + is_lc = ( + "lc" in argv[1].lower() or + "linking" in argv[1].lower() or + "challenge" in argv[1].lower()) + print("Doing dataset", "Linking Challenge" if is_lc else "COMETA") + draw(is_lc)