-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_domain.py
48 lines (43 loc) · 1.98 KB
/
train_domain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gpl

# Base checkpoint to domain-adapt. Earlier candidates, kept for reference:
#   'Snowflake/snowflake-arctic-embed-xs'
#   'GPL/msmarco-distilbert-margin-mse'
#   'distilbert-base-uncased'
model_name = 'bert-base-multilingual-uncased'

batch_size = 64      # GPL training batch size
gpl_steps = 140000   # total GPL training steps
# NOTE(review): "enbeddrus" here vs. "embeddrus" in the generated-data path
# below — confirm the differing spellings are intentional.
output_dir = './output/enbeddrus_domain'
evaluation_output = f"{output_dir}_evaluation"

# Collect every gpl.train keyword argument up front so the call site stays flat.
train_kwargs = dict(
    # Directory holding the pre-generated queries / pseudo-labels.
    path_to_generated_data="generated/embeddrus",
    # Starting checkpoint (matches the experiments in the GPL paper).
    base_ckpt=model_name,
    # GPL trains with MarginMSE loss, which works with dot-product scoring.
    gpl_score_function="dot",
    batch_size_gpl=batch_size,
    gpl_steps=gpl_steps,
    # Corpus resizing: None keeps the full corpus; -1 chooses automatically —
    # full size while QPP * |corpus| <= 250K, otherwise QPP is set to 3 and
    # |corpus| to 250K / 3.
    new_size=-1,
    # Queries Per Passage (QPP) for the query-generation step; -1 would
    # auto-select (250K / |corpus| while the product stays <= 250K, else 3).
    queries_per_passage=25,
    output_dir=output_dir,
    evaluation_data="./datasets",
    evaluation_output=evaluation_output,
    # T5 model used to synthesize queries for the corpus passages.
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    # Both dense retrievers below score with cosine similarity.
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    retriever_score_functions=["cos_sim", "cos_sim"],
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    # Prefix for query-generation artifacts, e.g. "qgen-qrels/" and
    # "qgen-queries.jsonl".
    qgen_prefix="qgen",
    do_evaluation=True,
    # use_amp=True would enable efficient float16 mixed precision.
)

gpl.train(**train_kwargs)