Skip to content

Commit

Permalink
optimize post_save() signal.
Browse files Browse the repository at this point in the history
  • Loading branch information
rostyslavhereha committed Oct 31, 2023
1 parent a5f23cd commit 8fb0f9e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
14 changes: 8 additions & 6 deletions label_studio/tasks/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,14 +559,16 @@ def create(self, request, *args, **kwargs):
predictions_array = request_data.get("result", [])
tasks_array = request_data.get("task", [])
predictions_to_create = []
tasks = Task.objects.select_related('project').in_bulk(tasks_array)
data = request_data.copy()
for prediction, task_id in zip(predictions_array, tasks_array):
data = request_data.copy()
data["result"] = prediction
data["task"] = Task.objects.get(id=task_id)
predictions_to_create.append(Prediction(**data))
predictions = Prediction.objects.bulk_create(predictions_to_create)
for prediction in predictions:
post_save.send(sender=Prediction, instance=prediction, creared=True)
data["task"] = tasks.get(task_id)
pred = Prediction(**data)
predictions_to_create.append(pred)
post_save.send(sender=Prediction, instance=pred, creared=True)

Prediction.objects.bulk_create(predictions_to_create)
return Response(data={'success': 'Predictions created successfully.'}, status=status.HTTP_201_CREATED)

@method_decorator(name='get', decorator=swagger_auto_schema(auto_schema=None))
Expand Down
5 changes: 3 additions & 2 deletions label_studio/tasks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,9 @@ def remove_predictions_from_project(sender, instance, **kwargs):
@receiver(post_save, sender=Prediction)
def save_predictions_to_project(sender, instance, **kwargs):
"""Add predictions counters"""
instance.task.total_predictions = instance.task.predictions.all().count()
instance.task.save(update_fields=['total_predictions'])
task_id = instance.task_id
predictions_count = Prediction.objects.filter(task_id=task_id).count()
Task.objects.filter(id=task_id).update(total_predictions=predictions_count + 1)
logger.debug(f"Updated total_predictions for {instance.task.id}.")

# =========== END OF PROJECT SUMMARY UPDATES ===========
Expand Down

0 comments on commit 8fb0f9e

Please sign in to comment.