Retrain before retagging #16

Open · wants to merge 106 commits into base: main
Changes from all commits (106 commits)
4e42226
Prototype for retagging using spacy
Sep 8, 2023
b056178
Prototype for retagging using spacy
Sep 8, 2023
b6c1444
Prototype for retagging using spacy
Sep 8, 2023
d847882
Prototype for retagging using spacy
Sep 8, 2023
348ad3b
Prototype for retagging using spacy
Sep 8, 2023
b9b1e9f
Prototype for retagging using spacy
Sep 8, 2023
06840fd
Prototype for retagging using spacy
Sep 8, 2023
b442043
Prototype for retagging using spacy
Sep 8, 2023
66156a8
Prototype for retagging using spacy
Sep 8, 2023
f3f8c23
Changes to cNN
Sep 8, 2023
f6c770e
Changes to cNN
Sep 8, 2023
ec2aeb6
Adds cpu and gpu
Sep 8, 2023
4289f4f
Adds large model
Sep 8, 2023
96356ae
Adds large model
Sep 8, 2023
e57e891
Adds sparknlp
Sep 8, 2023
e38c668
Adds sparknlp
Sep 8, 2023
da10221
Different config of sparknlp
Sep 8, 2023
848fd7a
Fixes pipeline
Sep 8, 2023
bf06cc8
Prototypes retagging after training
Sep 8, 2023
8db1322
Prototypes retagging after training
Sep 8, 2023
5440a99
Prototypes retagging after training
Sep 8, 2023
98cc4e3
Saves corrections
Sep 8, 2023
35652e0
Saves corrections
Sep 8, 2023
7a2b0ed
Saves corrections
Sep 8, 2023
2a64c86
Saves corrections
Sep 8, 2023
47690b0
Saves corrections
Sep 8, 2023
d7dd1cb
Saves corrections
Sep 8, 2023
2951428
500 rows for better accuracy
Sep 8, 2023
42b3b82
Adds batching and refactors prediction
Sep 9, 2023
c94a5a2
Adds batching and refactors prediction
Sep 9, 2023
430cc8a
Adds batching and refactors prediction
Sep 9, 2023
c93b945
Adds batching and refactors prediction
Sep 9, 2023
d8e318b
Adds batching and refactors prediction
Sep 9, 2023
0af5a85
Adds batching and refactors prediction
Sep 9, 2023
e936285
Adds batching and refactors prediction
Sep 9, 2023
44a790b
Adds some logging, refactors
Sep 9, 2023
d948f3a
Adds pure spark
Sep 10, 2023
5850b22
Adds pure spark
Sep 11, 2023
f884a25
Adds spark.repartition
Sep 11, 2023
7ebabba
Adds spark.repartition
Sep 11, 2023
e86a3fe
Adds parquet optimization
Sep 11, 2023
6083b22
Adds parquet optimization
Sep 11, 2023
4c3933d
Adds parquet optimization
Sep 11, 2023
06b5446
Adds spark context config
josejuanmartinez Sep 11, 2023
629e6b4
Spark config
Sep 11, 2023
127c8b4
Spark config
Sep 11, 2023
2dd3243
Spark config
Sep 11, 2023
f486608
Spark config
Sep 11, 2023
6be482b
Spark config
Sep 11, 2023
dca895b
Spark config
Sep 11, 2023
c8584bd
Spark config
Sep 11, 2023
13dbf67
Spark config
Sep 11, 2023
c5667be
Spark config
Sep 11, 2023
d71acf2
Remove duplicates and add reference number
nsorros Sep 12, 2023
f7d70b0
Switch to save in excel format
nsorros Sep 12, 2023
5cb92d4
Run pre-commit
Sep 12, 2023
902eda0
Adds documentation, configurable tags, years, memory
Sep 12, 2023
4501342
Add openpyxl to support excel
Sep 12, 2023
7b78843
Run dvc
Sep 12, 2023
22fad69
Total refactor: XLinear
Sep 12, 2023
2bc3e21
Total refactor finished for retagging: XLinear
Sep 13, 2023
f4b1f32
Black
Sep 13, 2023
4ecec1c
Adds retagging tests and updates documentation
Sep 13, 2023
8bc7b4d
Adds retagging tests and updates documentation
Sep 13, 2023
58f39fa
Black
Sep 13, 2023
ad84583
Ruff
Sep 13, 2023
5d41bfe
Ruff and black
Sep 13, 2023
96c88ce
black
Sep 13, 2023
c70bdad
Removes spacy
Sep 13, 2023
75a952c
Merge pull request #16 from MantisAI/fix-wellcome-active-sample
nsorros Sep 14, 2023
487f8cf
Adds tests for augment
Sep 14, 2023
efacba7
Black
Sep 14, 2023
6e01ee0
Better error management. Black
josejuanmartinez Sep 14, 2023
3e91ec6
"Better error management. Black"
Sep 14, 2023
10c4468
Ruff
Sep 14, 2023
5882aa0
Merge branch 'main' into retagging
Sep 14, 2023
7206885
Feedback from PR comments. Merging latest main into this branch.
Sep 14, 2023
867fddb
Black
Sep 14, 2023
b1f1b40
Adds retagging dvc data
Sep 14, 2023
88fdd2a
Adds retagging dvc data
Sep 14, 2023
f98e598
Adds retagging dvc data
Sep 14, 2023
71eebbf
Changes reference fail from original to retagged
Sep 14, 2023
2660b7a
Changes reference fail from original to retagged
Sep 14, 2023
a8f86a6
Adds mkdir to yaml
josejuanmartinez Sep 14, 2023
0e8aa1c
Adds two steps to the dvc pipeline
Sep 14, 2023
d79d80b
Adds two steps to the dvc pipeline
Sep 14, 2023
0adb58e
Adds two steps to the dvc pipeline
Sep 14, 2023
43046ed
Adds two steps to the dvc pipeline
Sep 14, 2023
d8a30b9
Adds two steps to the dvc pipeline
Sep 14, 2023
441ee98
Adds two steps to the dvc pipeline
Sep 14, 2023
b1736da
Adds two steps to the dvc pipeline
Sep 14, 2023
5859bbe
Adds two steps to the dvc pipeline
Sep 14, 2023
25aa0e6
Adds two steps to the dvc pipeline
Sep 14, 2023
cf64bb1
Adds two steps to the dvc pipeline
Sep 14, 2023
f5c8da3
Default wandb
Sep 14, 2023
a45d982
Documentation about wandb and dvc repro
Sep 14, 2023
823bc64
12 epochs
Sep 14, 2023
ac928f1
Fixes bug with the batch calculation
Sep 15, 2023
e41dcbd
black
josejuanmartinez Sep 15, 2023
b1dc41f
Merge pull request #17 from MantisAI/retagging
josejuanmartinez Sep 15, 2023
c021da7
Adds last trained model
josejuanmartinez Sep 15, 2023
ae92b78
Adds last metrics bertmesh
josejuanmartinez Sep 16, 2023
f823fd4
tag all active grants
nsorros Sep 19, 2023
97da452
Merge pull request #18 from MantisAI/active-grants
nsorros Oct 17, 2023
f66cbbb
:wrench: change to steps saving / eval strategy
agombert Oct 24, 2023
58ef63e
retrained evaluation
agombert Nov 1, 2023
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ cython_debug/
# Folder where training outputs are stored
bertmesh_outs/
wandb/
/bertmesh_before_retagging
/preprocessed_results
223 changes: 201 additions & 22 deletions README.md

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions bertmesh_before_retagging.dvc
@@ -0,0 +1,5 @@
outs:
- md5: 4964c2e8f83f071bcb7c467a859726a6.dir
size: 2593104471
nfiles: 5
path: bertmesh_before_retagging
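
This new `.dvc` pointer tracks the pre-retagging BERTMesh model directory (roughly 2.6 GB across 5 files) through DVC instead of committing it to git. Assuming a DVC remote is configured for this repository, the directory can be fetched before reproducing the pipeline; a minimal sketch:

```python
# Minimal sketch, assuming a DVC remote is configured for this repository.
# Fetches the directory tracked by bertmesh_before_retagging.dvc.
import subprocess

subprocess.run(["dvc", "pull", "bertmesh_before_retagging.dvc"], check=True)
```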
1 change: 1 addition & 0 deletions data/grants_comparison/.gitignore
@@ -1,2 +1,3 @@
/meshterms_list.txt
/comparison.csv
/comparison.xlsx
1 change: 1 addition & 0 deletions data/raw/.gitignore
100644 → 100755
@@ -3,3 +3,4 @@
/desc2021.xml
/disease_tags_validation_grants.xlsx
/active_grants_last_5_years.csv
/retagging
5 changes: 5 additions & 0 deletions data/raw/retagging.dvc
@@ -0,0 +1,5 @@
outs:
- md5: 1a64ed7c09ef3bc49b1bfcc17f5d7e1f.dir
size: 5546175163
nfiles: 6
path: retagging
8 changes: 5 additions & 3 deletions examples/augment.sh
@@ -1,3 +1,5 @@
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \
--min-examples 25 \
--concurrent-calls 25
# Augments data using a file with 1 label per line and years
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FILE] \
--tags "Mathematics" \
--examples 25 \
--concurrent-calls 1
5 changes: 0 additions & 5 deletions examples/augment_specific_tags.sh

This file was deleted.

3 changes: 2 additions & 1 deletion examples/preprocess_and_train_by_epochs.sh
@@ -1,7 +1,8 @@
# Run on g5.12xlarge instance

# Without saving (on-the-fly)
SOURCE="data/raw/allMeSH_2021.jsonl"
#SOURCE="data/raw/allMeSH_2021.jsonl"
SOURCE="data/raw/retagging/allMeSH_2021.2016-2021.jsonl"

grants-tagger train bertmesh \
"" \
3 changes: 2 additions & 1 deletion examples/preprocess_and_train_by_steps.sh
@@ -1,7 +1,8 @@
# Run on g5.12xlarge instance

# Without saving (on-the-fly)
SOURCE="data/raw/allMeSH_2021.jsonl"
# SOURCE="data/raw/allMeSH_2021.jsonl"
SOURCE="data/raw/retagging/allMeSH_2021.2016-2021.jsonl"

grants-tagger train bertmesh \
"" \
4 changes: 2 additions & 2 deletions examples/preprocess_splitting_by_fract.sh
@@ -1,2 +1,2 @@
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 0.05
grants-tagger preprocess mesh data/raw/retagging/allMeSH_2021.2016-2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 0.05
4 changes: 2 additions & 2 deletions examples/preprocess_splitting_by_rows.sh
@@ -1,2 +1,2 @@
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 25000
grants-tagger preprocess mesh data/raw/retagging/allMeSH_2021.2016-2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 25000
4 changes: 2 additions & 2 deletions examples/preprocess_splitting_by_years.sh
@@ -1,4 +1,4 @@
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
grants-tagger preprocess mesh data/raw/retagging/allMeSH_2021.2016-2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 25000 \
--train-years 2016,2017,2018,2019 \
--test-years 2020,2021
--test-years 2020,2021
2 changes: 1 addition & 1 deletion examples/resume_train_by_epoch.sh
@@ -34,4 +34,4 @@ grants-tagger train bertmesh \
--save_strategy epoch \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
--wandb_api_key ${WANDB_API_KEY}
2 changes: 1 addition & 1 deletion examples/resume_train_by_steps.sh
@@ -36,4 +36,4 @@ grants-tagger train bertmesh \
--save_steps 10000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
--wandb_api_key ${WANDB_API_KEY}
6 changes: 6 additions & 0 deletions examples/retag.sh
@@ -0,0 +1,6 @@
grants-tagger retag mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FILE_HERE] \
--tags "Artificial Intelligence,HIV,Data Collection,Mathematics,Geography" \
--years 2016,2017,2018,2019,2020,2021 \
--train-examples 100 \
--batch-size 10000 \
--supervised
50 changes: 40 additions & 10 deletions grants_tagger_light/augmentation/augment.py
@@ -14,6 +14,7 @@
from grants_tagger_light.augmentation.parallel_augment_openai import (
ParallelAugmentOpenAI,
)
from grants_tagger_light.utils.years_tags_parser import parse_tags

augment_app = typer.Typer()

@@ -50,6 +51,7 @@ def augment(
prompt_template: str = "grants_tagger_light/augmentation/prompt.template",
concurrent_calls: int = os.cpu_count() * 2,
temperature: float = 1.5,
tags: list = None,
tags_file_path: str = None,
):
if model_key.strip().lower() not in ["gpt-3.5-turbo", "text-davinci", "gpt-4"]:
@@ -60,6 +62,7 @@
dset = load_from_disk(os.path.join(data_path, "dataset"))
if "train" in dset:
dset = dset["train"]

logger.info("Obtaining count values from the labels...")
pool = multiprocessing.Pool(processes=num_proc)
element_counts_list = pool.map(_count_elements_in_sublist, dset["meshMajor"])
@@ -71,16 +74,22 @@
merged_element_counts.items(), key=lambda x: x[1], reverse=True
)
sorted_merged_element_counts_dict = dict(sorted_merged_element_counts)

print(f"Tags: {tags}")
if tags is None:
tags = []
if tags_file_path is not None:
with open(tags_file_path, "r") as f:
tags = f.read().split("\n")
tags.extend([x.strip() for x in f.readlines()])
logger.info(
f"Tags file path found. Filtering {len(tags)} tags "
f"(examples found: {tags[:15]}...)"
)
sorted_merged_element_counts_dict = {
k: v for k, v in sorted_merged_element_counts_dict.items() if k in tags
}
if len(tags) > 0:
sorted_merged_element_counts_dict = {
k: v for k, v in sorted_merged_element_counts_dict.items() if k in tags
}
logger.info(f"Tags count dictionary: {sorted_merged_element_counts_dict}")

if min_examples is not None:
sorted_merged_element_counts_dict = {
@@ -89,11 +98,28 @@
if v < min_examples
}

if len(sorted_merged_element_counts_dict.keys()) < 1:
logger.error(
"I did not find any examples for your tags "
"in your preprocessed folder. Try:\n"
"- Other train/set split in `preprocess`;\n"
"- Other years;\n"
"- Other tags;"
)
exit(-1)

with open(f"{save_to_path}.count", "w") as f:
f.write(json.dumps(sorted_merged_element_counts_dict, indent=2))

tags_to_augment = list(sorted_merged_element_counts_dict.keys())

if len(tags_to_augment) < concurrent_calls:
logger.error(
"Found less tags than concurrent calls to OpenAI."
f" Overwritting `concurrent-calls` to {len(tags_to_augment)}"
)
concurrent_calls = len(tags_to_augment)

biggest_tags_to_augment = [
f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5]
]
@@ -156,10 +182,8 @@ def augment(

@augment_app.command()
def augment_cli(
data_path: str = typer.Argument(..., help="Path to mesh.jsonl"),
save_to_path: str = typer.Argument(
..., help="Path to save the serialized PyArrow dataset after preprocessing"
),
data_path: str = typer.Argument(..., help="Path to folder after `preprocess`"),
save_to_path: str = typer.Argument(..., help="Path to save the new jsonl data"),
model_key: str = typer.Option(
"gpt-3.5-turbo",
help="LLM to use data augmentation. By now, only `openai` is supported",
@@ -193,6 +217,7 @@ def augment_cli(
max=2,
help="A value between 0 and 2. The bigger - the more creative.",
),
tags: str = typer.Option(None, help="Comma separated list of tags to retag"),
tags_file_path: str = typer.Option(
None,
help="Text file containing one line per tag to be considered. "
@@ -206,13 +231,17 @@
)
exit(-1)

if tags_file_path is None and min_examples is None:
if tags_file_path is None and tags is None and min_examples is None:
logger.error(
"To understand which tags need to be augmented, "
"set either --min-examples or --tags-file-path"
"set either --min-examples or --tags-file-path or --tags"
)
exit(-1)

if tags_file_path is not None and not os.path.isfile(tags_file_path):
logger.error(f"{tags_file_path} not found")
exit(-1)

if float(temperature) > 2.0 or float(temperature) < -2.0:
logger.error("Temperature should be in the range [-2, 2]")
exit(-1)
@@ -228,5 +257,6 @@
prompt_template=prompt_template,
concurrent_calls=concurrent_calls,
temperature=temperature,
tags=parse_tags(tags),
tags_file_path=tags_file_path,
)
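
The new `--tags` option reaches `augment()` through a `parse_tags` helper imported from `grants_tagger_light.utils.years_tags_parser`; its implementation is not shown in this diff. A minimal sketch consistent with how it is called here (a comma-separated string or `None` in, a list or `None` out) might look like this:

```python
# Hypothetical sketch of a parse_tags-style helper; the real
# grants_tagger_light/utils/years_tags_parser.py is not part of this diff.
from typing import Optional


def parse_tags(tags: Optional[str]) -> Optional[list]:
    """Split a comma-separated --tags value into a clean list of tag names."""
    if tags is None:
        return None
    # e.g. "Artificial Intelligence, HIV" -> ["Artificial Intelligence", "HIV"]
    return [tag.strip() for tag in tags.split(",") if tag.strip()]
```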
2 changes: 1 addition & 1 deletion grants_tagger_light/augmentation/prompt.template
@@ -9,4 +9,4 @@ ABSTRACT:
{ABSTRACT}

TOPIC:
{TOPIC}
{TOPIC}
2 changes: 2 additions & 0 deletions grants_tagger_light/cli.py
@@ -5,6 +5,7 @@
from grants_tagger_light.augmentation import augment_app
from grants_tagger_light.download_epmc import download_epmc_cli
from grants_tagger_light.evaluation import evaluate_app
from grants_tagger_light.retagging import retag_app
from grants_tagger_light.predict import predict_cli
from grants_tagger_light.preprocessing import preprocess_app
from grants_tagger_light.tune_threshold import tune_threshold_cli
@@ -18,6 +19,7 @@
app.add_typer(preprocess_app, name="preprocess")
app.add_typer(augment_app, name="augment")
app.add_typer(evaluate_app, name="evaluate")
app.add_typer(retag_app, name="retag")


app.command("predict")(predict_cli)
21 changes: 15 additions & 6 deletions grants_tagger_light/evaluation/evaluate_model.py
@@ -6,6 +6,7 @@
from typing import Optional
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from tqdm.auto import tqdm

import scipy.sparse as sp
import typer
@@ -33,13 +34,20 @@ def evaluate_model(
model = BertMesh.from_pretrained(model_path)

label_binarizer = MultiLabelBinarizer()
label_binarizer.fit([list(model.id2label.values())])
id2labels = [0 for i in range(model.config.num_labels)]
for k, v in model.id2label.items():
id2labels[k] = v
label_binarizer.fit([id2labels])

pipe = pipeline(
"grants-tagging",
model=model,
tokenizer="Wellcome/WellcomeBertMesh",
device=0,
)
def data():
for x in X_test:
yield x

if split_data:
print(
@@ -48,13 +56,14 @@
)
_, X_test, _, Y_test = load_train_test_data(data_path, label_binarizer)
else:
X_test, Y_test, _ = load_data(data_path, label_binarizer)

Y_pred_proba = pipe(X_test, return_labels=False)

X_test, Y_test, _ = load_data(data_path, label_binarizer, model_id2labels=model.id2label)

Y_pred_proba = []
for out in tqdm(pipe(data(), return_labels=False)):
Y_pred_proba.append(out)
Y_pred_proba = torch.vstack(Y_pred_proba)

Y_pred_proba = sp.csr_matrix(Y_pred_proba)
#Y_pred_proba = sp.csr_matrix(Y_pred_proba)

if not isinstance(threshold, list):
threshold = [threshold]
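
The evaluation path now feeds texts to the Hugging Face pipeline through a generator and stacks the per-example probability tensors with `torch.vstack`, rather than passing the full list at once. A self-contained sketch of that generator-based pattern, shown with a stock text-classification pipeline instead of the custom `grants-tagging` task registered by this project:

```python
# Sketch of the generator-based pipeline pattern used above, illustrated with a
# stock text-classification pipeline rather than the custom "grants-tagging" task.
import torch
from tqdm.auto import tqdm
from transformers import pipeline

texts = [
    "Malaria vaccine efficacy in young children",
    "Deep learning methods for protein structure prediction",
]

pipe = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=-1,  # CPU; use device=0 for the first GPU as in the diff above
)


def data():
    # Yielding items lets the pipeline consume inputs lazily, which avoids
    # materialising predictions for a large X_test up front.
    for text in texts:
        yield text


scores = []
for out in tqdm(pipe(data()), total=len(texts)):
    # For this stock task each `out` is a dict like {"label": ..., "score": ...};
    # the PR's custom task yields probability tensors that are then vstack-ed.
    scores.append(out["score"])

print(torch.tensor(scores))  # one score per input text
```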
13 changes: 12 additions & 1 deletion grants_tagger_light/models/xlinear/model.py
@@ -68,6 +68,13 @@ def __init__(
# Those are MeshXLinear params
self.threshold = threshold

self.model_path = None
self.xlinear_model_ = None
self.vectorizer_ = None

self.label_binarizer_path = label_binarizer_path
self.label_binarizer_ = None

if label_binarizer_path is not None:
self.load_label_binarizer(label_binarizer_path)

@@ -167,7 +174,6 @@ def predict_tags(
"""
X: list or numpy array of texts
model_path: path to trained model
label_binarizer_path: path to trained label_binarizer
probabilities: bool, default False. When true probabilities
are returned along with tags
threshold: float, default 0.5. Probability threshold to be used to assign tags.
@@ -217,6 +223,9 @@ def load(self, model_path, is_predict_only=True):
with open(params_path, "r") as f:
self.__dict__.update(json.load(f))

self.load_label_binarizer(self.label_binarizer_path)
self.model_path = model_path

if self.vectorizer_library == "sklearn":
self.vectorizer_ = load_pickle(vectorizer_path)
else:
@@ -229,6 +238,8 @@
model_path, is_predict_only=is_predict_only
)

return self

def load_label_binarizer(self, label_binarizer_path):
with open(label_binarizer_path, "rb") as f:
self.label_binarizer_ = pickle.loads(f.read())
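
Since `load()` now restores the label binarizer from the saved params and returns `self`, loading can be chained directly onto construction. A hedged usage sketch; the import path follows the file location, and everything else (model directory, default constructor arguments) is assumed:

```python
# Hypothetical usage of the updated load(); the model path and the default
# constructor arguments are assumptions, not taken from this diff.
from grants_tagger_light.models.xlinear.model import MeshXLinear

model = MeshXLinear().load("path/to/trained/xlinear_model", is_predict_only=True)

# load() now also calls load_label_binarizer(), so the binarizer is ready to use.
assert model.label_binarizer_ is not None
```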
3 changes: 1 addition & 2 deletions grants_tagger_light/preprocessing/preprocess_mesh.py
@@ -118,7 +118,6 @@ def preprocess_mesh(
num_proc=num_proc,
desc="Tokenizing",
fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"},
load_from_cache_file=False,
)
logger.info("Time taken to tokenize: {}".format(time.time() - t1))

@@ -261,7 +260,7 @@ def preprocess_mesh_cli(
if not data_path.endswith("jsonl"):
logger.error(
"It seems your input MeSH data is not in `jsonl` format. "
"Please, run first `scripts/mesh_json_to_jsonlpy.`"
"Please, run first `scripts/mesh_json_to_jsonl.py.`"
)
exit(-1)

8 changes: 8 additions & 0 deletions grants_tagger_light/retagging/__init__.py
@@ -0,0 +1,8 @@
import typer
from .retagging import retag_cli

retag_app = typer.Typer()
retag_app.command(
"mesh",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)(retag_cli)
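
The new Typer sub-app exposes `retag_cli` as the `mesh` command (wired into the main CLI as `grants-tagger retag mesh` in cli.py above). A hedged sketch of exercising it with Typer's test runner, using the option names from examples/retag.sh; the input and output paths are placeholders:

```python
# Hedged sketch: invoking the new "retag mesh" command through Typer's test
# runner. Option names come from examples/retag.sh; paths are placeholders.
from typer.testing import CliRunner

from grants_tagger_light.retagging import retag_app

runner = CliRunner()
result = runner.invoke(
    retag_app,
    [
        "mesh",
        "data/raw/allMeSH_2021.jsonl",
        "output/retagged",
        "--tags", "Artificial Intelligence,HIV",
        "--years", "2016,2017,2018,2019,2020,2021",
        "--train-examples", "100",
        "--batch-size", "10000",
        "--supervised",
    ],
)
print(result.exit_code)
```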