Skip to content

Commit

Permalink
Merge branch 'main' into 1983-new_reco_pages
Browse files Browse the repository at this point in the history
  • Loading branch information
GresilleSiffle committed Oct 31, 2024
2 parents 9ae779a + 01b9f2c commit 0c37963
Show file tree
Hide file tree
Showing 59 changed files with 1,363 additions and 881 deletions.
30 changes: 29 additions & 1 deletion backend/ml/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import pandas as pd
from django.db.models import Case, F, QuerySet, When
from django.db.models import Case, F, Q, QuerySet, When
from django.db.models.expressions import RawSQL
from solidago.pipeline import TournesolInput

Expand All @@ -14,6 +14,7 @@
ContributorScaling,
Entity,
)
from vouch.models import Voucher


class MlInputFromDb(TournesolInput):
Expand Down Expand Up @@ -189,3 +190,30 @@ def get_individual_scores(

dtf = pd.DataFrame(values)
return dtf[["user_id", "entity", "criteria", "raw_score"]]

def get_vouches(self):
values = Voucher.objects.filter(
by__is_active=True,
to__is_active=True,
).values(
voucher=F("by__id"),
vouchee=F("to__id"),
vouch=F("value"),
)
return pd.DataFrame(values, columns=["voucher", "vouchee", "vouch"])

def get_users(self):
values = (
User.objects
.filter(is_active=True)
.annotate(is_pretrusted=Q(pk__in=User.with_trusted_email()))
.values(
"is_pretrusted",
"trust_score",
user_id=F("id"),
)
)
return pd.DataFrame(
data=values,
columns=["user_id", "is_pretrusted", "trust_score"],
).set_index("user_id")
172 changes: 139 additions & 33 deletions backend/ml/management/commands/ml_train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,66 @@
import os
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

from django import db
from django.conf import settings
from django.core.management.base import BaseCommand
from solidago.aggregation import EntitywiseQrQuantile
from solidago.pipeline import Pipeline
from solidago.post_process.squash import Squash
from solidago.preference_learning import UniformGBT
from solidago.scaling import Mehestan, QuantileShift, ScalingCompose, Standardize
from solidago.trust_propagation import LipschiTrust, NoopTrust
from solidago.voting_rights import AffineOvertrust

from ml.inputs import MlInputFromDb
from ml.mehestan.run import MehestanParameters, run_mehestan
from ml.outputs import TournesolPollOutput, save_tournesol_scores
from tournesol.models import EntityPollRating, Poll
from tournesol.models.poll import ALGORITHM_LICCHAVI, ALGORITHM_MEHESTAN
from vouch.trust_algo import trust_algo
from tournesol.models.poll import ALGORITHM_MEHESTAN, DEFAULT_POLL_NAME


def get_solidago_pipeline(run_trust_propagation: bool = True):
    """Build the Solidago pipeline used to compute poll scores.

    When `run_trust_propagation` is False, trust scores are left
    untouched (NoopTrust) instead of being recomputed (LipschiTrust).
    """
    aggregation_lipschitz = 0.1

    scaling_steps = ScalingCompose(
        Mehestan(),
        Standardize(
            dev_quantile=0.9,
            lipschitz=0.1,
        ),
        QuantileShift(
            quantile=0.1,
            # target_score is defined to be the recommendability
            # threshold, i.e. the theoretical max score that can be
            # reached by an entity with 2 contributors.
            target_score=2 * aggregation_lipschitz,
            lipschitz=0.1,
            error=1e-5,
        ),
    )

    return Pipeline(
        trust_propagation=LipschiTrust() if run_trust_propagation else NoopTrust(),
        voting_rights=AffineOvertrust(),
        # TODO: use LBFGS (faster) implementation.
        # Currently requires to install Solidago with "torch" extra.
        preference_learning=UniformGBT(
            prior_std_dev=7.0,
            convergence_error=1e-5,
            cumulant_generating_function_error=1e-5,
            high_likelihood_range_threshold=0.25,
            # max_iter=300,
        ),
        scaling=scaling_steps,
        aggregation=EntitywiseQrQuantile(
            quantile=0.5,
            lipschitz=aggregation_lipschitz,
            error=1e-5,
        ),
        post_process=Squash(score_max=100.)
    )


class Command(BaseCommand):
Expand All @@ -17,37 +73,87 @@ def add_arguments(self, parser):
help="Disable trust scores computation and preserve existing trust_score values",
)
parser.add_argument("--main-criterion-only", action="store_true")
parser.add_argument("--alpha", type=float, default=None)
parser.add_argument("-W", type=float, default=None)
parser.add_argument("--score-shift-quantile", type=float, default=None)
parser.add_argument("--score-deviation-quantile", type=float, default=None)

def handle(self, *args, **options):
if not options["no_trust_algo"]:
# Update "trust_score" for all users
trust_algo()

# Update scores for all polls
for poll in Poll.objects.filter(active=True):
ml_input = MlInputFromDb(poll_name=poll.name)

if poll.algorithm == ALGORITHM_MEHESTAN:
kwargs = {
param: options[param]
for param in ["alpha", "W", "score_shift_quantile", "score_deviation_quantile"]
if options[param] is not None
}
parameters = MehestanParameters(**kwargs)
run_mehestan(
ml_input=ml_input,
poll=poll,
parameters=parameters,
main_criterion_only=options["main_criterion_only"],
if poll.algorithm != ALGORITHM_MEHESTAN:
raise ValueError(f"Unknown algorithm {poll.algorithm!r}")

is_default_poll = (poll.name == DEFAULT_POLL_NAME)
self.run_poll_pipeline(
poll=poll,
update_trust_scores=(not options["no_trust_algo"] and is_default_poll),
main_criterion_only=options["main_criterion_only"],
)

def run_poll_pipeline(
self,
poll: Poll,
update_trust_scores: bool,
main_criterion_only: bool,
):
pipeline = get_solidago_pipeline(
run_trust_propagation=update_trust_scores
)
criteria_list = poll.criterias_list
criteria_to_run = [poll.main_criteria]
if not main_criterion_only:
criteria_to_run.extend(
c for c in criteria_list if c != poll.main_criteria
)

if settings.MEHESTAN_MULTIPROCESSING:
# compute each criterion in parallel
cpu_count = os.cpu_count() or 1
cpu_count -= settings.MEHESTAN_KEEP_N_FREE_CPU
os.register_at_fork(before=db.connections.close_all)
executor = ProcessPoolExecutor(max_workers=max(1, cpu_count))
else:
# In tests, we might prefer to use a single thread to reduce overhead
# of multiple processes, db connections, and redundant numba compilation
executor = ThreadPoolExecutor(max_workers=1)

with executor:
futures = []
for crit in criteria_to_run:
pipeline_input = MlInputFromDb(poll_name=poll.name)
pipeline_output = TournesolPollOutput(
poll_name=poll.name,
criterion=crit,
save_trust_scores_enabled=(update_trust_scores and crit == poll.main_criteria)
)
elif poll.algorithm == ALGORITHM_LICCHAVI:
raise NotImplementedError("Licchavi is no longer supported")
else:
raise ValueError(f"unknown algorithm {repr(poll.algorithm)}'")
self.stdout.write(f"Starting bulk update of sum_trust_score for poll {poll.name}")
EntityPollRating.bulk_update_sum_trust_scores(poll)
self.stdout.write(f"Finished bulk update of sum_trust_score for poll {poll.name}")

futures.append(
executor.submit(
self.run_pipeline_and_close_db,
pipeline=pipeline,
pipeline_input=pipeline_input,
pipeline_output=pipeline_output,
criterion=crit,
)
)

for fut in as_completed(futures):
# reraise potential exception
fut.result()

save_tournesol_scores(poll)
EntityPollRating.bulk_update_sum_trust_scores(poll)

self.stdout.write(f"Pipeline for poll {poll.name}: Done")

@staticmethod
def run_pipeline_and_close_db(
pipeline: Pipeline,
pipeline_input: MlInputFromDb,
pipeline_output: TournesolPollOutput,
criterion: str
):
pipeline.run(
input=pipeline_input,
criterion=criterion,
output=pipeline_output,
)
# Closing the connection fixes a warning in tests
# about open connections to the database.
db.connection.close()
4 changes: 3 additions & 1 deletion backend/ml/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ def save_entity_scores(
scores: pd.DataFrame,
score_mode="default",
):
scores_iterator = scores[["entity_id", "score", "uncertainty"]].itertuples(index=False)
if len(scores) == 0:
return

scores_iterator = scores[["entity_id", "score", "uncertainty"]].itertuples(index=False)
with transaction.atomic():
EntityCriteriaScore.objects.filter(
poll=self.poll,
Expand Down
7 changes: 6 additions & 1 deletion backend/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@

TWITTERBOT_CREDENTIALS = server_settings.get("TWITTERBOT_CREDENTIALS", {})

# The log level can be overridden from the environment; otherwise fall
# back to the server settings file, defaulting to "INFO".
try:
    DJANGO_LOG_LEVEL = os.environ["DJANGO_LOG_LEVEL"]
except KeyError:
    DJANGO_LOG_LEVEL = server_settings.get("DJANGO_LOG_LEVEL", "INFO")

LOGGING = {
"version": 1,
"disable_existing_loggers": False,
Expand All @@ -374,7 +379,7 @@
},
"root": {
"handlers": ["console"],
"level": os.environ.get("DJANGO_LOG_LEVEL", "DEBUG"),
"level": DJANGO_LOG_LEVEL,
},
"loggers": {
"factory": {
Expand Down
2 changes: 0 additions & 2 deletions backend/tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ pylint-django==2.5.3
pylint-json2html==0.4.0

# Unit tests tools
faker==13.15.1
pytest==7.1.3
pytest-html==3.1.1
pytest-mock==3.8.2

# Pytest for django
pytest-django==4.5.2
Expand Down
10 changes: 6 additions & 4 deletions backend/tournesol/lib/public_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def write_comparisons_file(
"criteria",
"score",
"score_max",
"week_date"
"week_date",
]
writer = csv.DictWriter(write_target, fieldnames=fieldnames)
writer.writeheader()
Expand Down Expand Up @@ -413,7 +413,9 @@ def write_vouchers_file(write_target):
"to_username": voucher.to.username,
"value": voucher.value,
}
for voucher in Voucher.objects.filter(is_public=True)
.select_related("by", "to")
.order_by("by__username", "to__username")
for voucher in (
Voucher.objects.filter(is_public=True, by__is_active=True, to__is_active=True)
.select_related("by", "to")
.order_by("by__username", "to__username")
)
)
8 changes: 4 additions & 4 deletions backend/tournesol/management/commands/load_public_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ def handle(self, *args, **options):
entity_1=videos[entity_a],
entity_2=videos[entity_b],
)
for _, values in rows.iterrows():
for values in rows.itertuples(index=False):
ComparisonCriteriaScore.objects.create(
comparison=comparison,
criteria=values["criteria"],
score=values["score"],
score_max=values["score_max"],
criteria=values.criteria,
score=values.score,
score_max=values.score_max,
)
nb_comparisons += 1
print(f"Created {nb_comparisons} comparisons")
Expand Down
Loading

0 comments on commit 0c37963

Please sign in to comment.