Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-1748: 'Generate labels with AI' button in Project, in SC #6789

Open
wants to merge 32 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
40d44bb
feat: DIA-1748: 'Generate labels with AI' button in Project, in SC
Dec 13, 2024
470f2c4
Add signals.py
Dec 13, 2024
77387b1
Add infobanner
Dec 13, 2024
dd8d1d7
Revert data import
Dec 13, 2024
45c9fd0
Revert toolbar
Dec 13, 2024
de6cb0a
Fix contextlog
Dec 16, 2024
1995e79
Fix order of project.save()
Dec 17, 2024
6dc7363
Merge remote-tracking branch 'origin/develop' into fb-dia-1748
Dec 17, 2024
e683be9
Add UserTour model
Dec 17, 2024
ed93c9f
Update states
Dec 17, 2024
ef3c5ee
Merge remote-tracking branch 'origin/develop' into fb-dia-1748
Dec 17, 2024
4b54b09
update data model, add url
Dec 17, 2024
db0aeb9
Merge remote-tracking branch 'origin/develop' into fb-dia-1748
Dec 18, 2024
842c951
Optimize /api/templates
Dec 18, 2024
8e2e90b
Fix tests
Dec 18, 2024
6e21afb
Fix url
Dec 18, 2024
1160bc0
Fix tests
Dec 18, 2024
0a38941
Remove infobanner
Dec 18, 2024
9d7a403
Update content
Dec 19, 2024
e5a5a87
Fix serialiser 'name' error
Dec 19, 2024
89eb6b7
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
a9c9914
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
c4edaac
[submodules] Bump HumanSignal/label-studio-sdk version
matt-bernstein Dec 20, 2024
9706ac1
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
b93490b
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
8bc6cf4
Address comments
Dec 22, 2024
409bfaf
Update joyride content
Dec 22, 2024
fc3eab3
Ignore migrations in linter
Dec 23, 2024
c294c62
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 25, 2024
154e820
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 25, 2024
58b4cb4
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 25, 2024
3697f03
Fix target selector
Dec 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions label_studio/core/all_urls.json
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,12 @@
"name": "current-user-whoami",
"decorators": ""
},
{
"url": "/api/current-user/product-tour",
"module": "users.product_tours.api.ProductTourAPI",
"name": "product-tour",
"decorators": ""
},
{
"url": "/data/avatars/<path>",
"module": "django.views.static.serve",
Expand Down
2 changes: 1 addition & 1 deletion label_studio/core/utils/contextlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def create_payload(self, request, response, body):
else:
namespace = request.resolver_match.namespace if request.resolver_match else None
status_code = response.status_code
content_type = response.content_type
content_type = response.content_type if hasattr(response, 'content_type') else None
response_content = self._get_response_content(response)

payload = {
Expand Down
7 changes: 7 additions & 0 deletions label_studio/ml_models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ class ThirdPartyModelVersion(ModelVersion):
'organizations.Organization', on_delete=models.CASCADE, related_name='third_party_model_versions', null=True
)

@property
def project(self):
# TODO: can it be just a property of the model version?
if self.parent_model and self.parent_model.associated_projects.exists():
return self.parent_model.associated_projects.first()
return None

def has_permission(self, user):
return user.active_organization == self.organization

Expand Down
38 changes: 22 additions & 16 deletions label_studio/projects/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,28 +788,34 @@ def perform_create(self, serializer):
return instance


def read_templates_and_groups():
annotation_templates_dir = find_dir('annotation_templates')
configs = []
for config_file in pathlib.Path(annotation_templates_dir).glob('**/*.yml'):
config = read_yaml(config_file)
if settings.VERSION_EDITION == 'Community':
if settings.VERSION_EDITION.lower() != config.get('type', 'community'):
continue
if config.get('image', '').startswith('/static') and settings.HOSTNAME:
# if hostname set manually, create full image urls
config['image'] = settings.HOSTNAME + config['image']
configs.append(config)
template_groups_file = find_file(os.path.join('annotation_templates', 'groups.txt'))
with open(template_groups_file, encoding='utf-8') as f:
groups = f.read().splitlines()
logger.debug(f'{len(configs)} templates found.')
return {'templates': configs, 'groups': groups}


class TemplateListAPI(generics.ListAPIView):
parser_classes = (JSONParser, FormParser, MultiPartParser)
permission_required = all_permissions.projects_view
swagger_schema = None
# load this once in memory for performance
templates_and_groups = read_templates_and_groups()

def list(self, request, *args, **kwargs):
annotation_templates_dir = find_dir('annotation_templates')
configs = []
for config_file in pathlib.Path(annotation_templates_dir).glob('**/*.yml'):
config = read_yaml(config_file)
if settings.VERSION_EDITION == 'Community':
if settings.VERSION_EDITION.lower() != config.get('type', 'community'):
continue
if config.get('image', '').startswith('/static') and settings.HOSTNAME:
# if hostname set manually, create full image urls
config['image'] = settings.HOSTNAME + config['image']
configs.append(config)
template_groups_file = find_file(os.path.join('annotation_templates', 'groups.txt'))
with open(template_groups_file, encoding='utf-8') as f:
groups = f.read().splitlines()
logger.debug(f'{len(configs)} templates found.')
return Response({'templates': configs, 'groups': groups})
return Response(self.templates_and_groups)


class ProjectSampleTask(generics.RetrieveAPIView):
Expand Down
22 changes: 18 additions & 4 deletions label_studio/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
annotate_useful_annotation_number,
)
from projects.functions.utils import make_queryset_from_iterable
from projects.signals import ProjectSignals
from tasks.models import (
Annotation,
AnnotationDraft,
Expand Down Expand Up @@ -659,6 +660,10 @@ def display_count(count: int, type: str) -> Optional[str]:
def _label_config_has_changed(self):
return self.label_config != self.__original_label_config

@property
def label_config_is_not_default(self):
return self.label_config != Project._meta.get_field('label_config').default

def should_none_model_version(self, model_version):
"""
Returns True if the model version provided matches the object's model version,
Expand Down Expand Up @@ -728,7 +733,12 @@ def save(self, *args, update_fields=None, recalc=True, **kwargs):
exists = True if self.pk else False
project_with_config_just_created = not exists and self.label_config

if self._label_config_has_changed() or project_with_config_just_created:
label_config_has_changed = self._label_config_has_changed()
logger.debug(
f'Label config has changed: {label_config_has_changed}, original: {self.__original_label_config}, new: {self.label_config}'
)

if label_config_has_changed or project_with_config_just_created:
self.data_types = extract_data_types(self.label_config)
self.parsed_label_config = parse_config(self.label_config)
self.label_config_hash = hash(str(self.parsed_label_config))
Expand All @@ -740,11 +750,15 @@ def save(self, *args, update_fields=None, recalc=True, **kwargs):
if update_fields is not None:
update_fields = {'control_weights'}.union(update_fields)

if self._label_config_has_changed():
self.__original_label_config = self.label_config

super(Project, self).save(*args, update_fields=update_fields, **kwargs)

if label_config_has_changed:
# save the new label config for future comparison
self.__original_label_config = self.label_config
# if tasks are already imported, emit signal that project is configured and ready for labeling
if self.num_tasks > 0:
ProjectSignals.post_label_config_and_import_tasks.send(sender=Project, project=self)
triklozoid marked this conversation as resolved.
Show resolved Hide resolved

if not exists:
steps = ProjectOnboardingSteps.objects.all()
objs = [ProjectOnboarding(project=self, step=step) for step in steps]
Expand Down
18 changes: 18 additions & 0 deletions label_studio/projects/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from django.dispatch import Signal


class ProjectSignals:
"""
Signals for project: implements observer pattern for custom signals.
Example:

# publisher
ProjectSignals.my_signal.send(sender=self, project=project)

# observer
@receiver(ProjectSignals.my_signal)
def my_observer(sender, **kwargs):
...
"""

post_label_config_and_import_tasks = Signal()
5 changes: 5 additions & 0 deletions label_studio/tasks/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def create(self, validated_data):

self.post_process_annotations(user, db_annotations, 'imported')
self.post_process_tasks(self.project.id, [t.id for t in self.db_tasks])
self.post_process_custom_callback(self.project.id, user)

if flag_set('fflag_feat_back_lsdv_5307_import_reviews_drafts_29062023_short', user=ff_user):
with transaction.atomic():
Expand Down Expand Up @@ -586,6 +587,10 @@ def post_process_tasks(user, db_tasks):
def add_annotation_fields(body, user, action):
return body

@staticmethod
def post_process_custom_callback(project_id, user):
pass

class Meta:
model = Task
fields = '__all__'
Expand Down
29 changes: 29 additions & 0 deletions label_studio/users/migrations/0010_userproducttour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Generated by Django 4.2.15 on 2024-12-22 09:54

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django_migration_linter as linter


class Migration(migrations.Migration):

dependencies = [
('users', '0009_auto_20231201_0001'),
]

operations = [
linter.IgnoreMigration(),
migrations.CreateModel(
name='UserProductTour',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='Unique identifier for the product tour. Name must match the config name.', max_length=256, verbose_name='Name')),
('state', models.CharField(choices=[('ready', 'Ready'), ('completed', 'Completed'), ('skipped', 'Skipped')], default='ready', help_text='Current state of the tour for this user. Available options: ready (Ready), completed (Completed), skipped (Skipped)', max_length=32, verbose_name='State')),
('interaction_data', models.JSONField(blank=True, default=dict, help_text='Additional data about user interaction with the tour', verbose_name='Interaction Data')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='When this tour record was created')),
('updated_at', models.DateTimeField(auto_now=True, help_text='When this tour record was last updated')),
('user', models.ForeignKey(help_text='User who interacted with the tour', on_delete=django.db.models.deletion.CASCADE, related_name='tours', to=settings.AUTH_USER_MODEL)),
],
),
]
44 changes: 44 additions & 0 deletions label_studio/users/product_tours/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging

from rest_framework import generics
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
from users.product_tours.models import UserProductTour

from .serializers import UserProductTourSerializer

logger = logging.getLogger(__name__)


class ProductTourAPI(generics.RetrieveUpdateAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = UserProductTourSerializer
swagger_schema = None

def get_tour_name(self):
name = self.request.query_params.get('name')
if not name:
raise ValidationError('Name is required')
# normalize name for subsequent checks
return name.replace('-', '_').lower()

def get_serializer_context(self):
context = super().get_serializer_context()
context['name'] = self.get_tour_name()
return context

def get_object(self):
name = self.get_tour_name()

# TODO: add additional checks, e.g. user agent, role, etc.

tour = UserProductTour.objects.filter(user=self.request.user, name=name).first()
if not tour:
logger.debug(f'Product tour {name} not found for user {self.request.user.id}. Creating new tour.')
tour_serializer = self.get_serializer(data={'user': self.request.user.id, 'name': name})
tour_serializer.is_valid(raise_exception=True)
tour = tour_serializer.save()
else:
logger.debug(f'Product tour {name} requested for user {self.request.user.id}.')

return tour
23 changes: 23 additions & 0 deletions label_studio/users/product_tours/configs/create_prompt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
steps:
- target: body
placement: center
title: '<div><img src="/static/images/prompts-tour-banner.svg" style="width: 100%; margin-bottom: 16px;"/><div style="font-size: 24px; color: var(--grape_500)">Auto-Label with AI</div></div>'
content: '<div style="width: 100%">We’ve set up an initial prompt and supplied some OpenAI credits to help you get started with auto-labeling your project.<br/><br/>Let us show you around!</div>'

- target: .cm-editor
placement: right
title: '<div style="color: var(--sand_600)">Step 1 of 3</div>'
content: '<div style="width: 100%;">We’ve gone ahead and generated a prompt for you based on your project’s labeling configuration.<br/><br/> Feel free to adjust it!</div>'

- target: '[data-testid="evaluate-model-button"]'
placement: top
title: '<div style="color: var(--sand_600)">Step 2 of 3</div>'
content: '<div style="width: 100%;">That''s it! After saving, just click on <b>Evaluate</b> to start getting predictions for your tasks.</div>'

- target: body
placement: center
title: '<div>🎉 <br/><b> That’s it! </b></div>'
content: 'Now you may convert (or correct) the predictions to annotations by opening the tasks on your project.'
locale:
last: 'Go to Project'

9 changes: 9 additions & 0 deletions label_studio/users/product_tours/configs/prompts_page.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
steps:
- target: body
placement: center
title: '<div style="font-size: 24px; color: var(--grape_500)">Welcome to Prompts!</div>'
content: '<div style="width: 100%">Set up Prompts to help you rapidly pre-label projects.</div>'

- target: ".lsf-data-table__body-row[data-index='0'] .lsf-models-list__model-name"
content: '<div style="width: 100%">Click on this sample Prompt to get you started.</div>'
isFixed: true
14 changes: 14 additions & 0 deletions label_studio/users/product_tours/configs/show_autolabel_button.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
steps:
- target: '[data-testid="auto-labeling-button"]'
placement: bottom
title: 'Great news!'
content: >
You can now rapidly label this project using Prompts.
<br/><br/>
Click "Auto-Label Tasks" to set up LLM powered labeling in under a minute.
<br/><br/>
We've provided some OpenAI credits to get you started.
disableBeacon: true
locale:
last: "OK"
skip: "Don't show this message again."
57 changes: 57 additions & 0 deletions label_studio/users/product_tours/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Any, Dict, Optional

from django.db import models
from django.utils.translation import gettext_lazy as _
from pydantic import BaseModel, Field


class ProductTourState(models.TextChoices):
READY = 'ready', _('Ready')
COMPLETED = 'completed', _('Completed')
SKIPPED = 'skipped', _('Skipped')


class ProductTourInteractionData(BaseModel):
"""Pydantic model for validating tour interaction data"""

index: Optional[int] = Field(None, description='Step number where tour was completed')
action: Optional[str] = Field(None, description='Action taken during the tour')
type: Optional[str] = Field(None, description='Type of interaction')
status: Optional[str] = Field(None, description='Status of the interaction')
additional_data: Optional[Dict[str, Any]] = Field(
default_factory=dict, description='Extensible field for additional interaction data'
)


class UserProductTour(models.Model):
"""Stores product tour state and interaction data for users"""

user = models.ForeignKey(
'User', on_delete=models.CASCADE, related_name='tours', help_text='User who interacted with the tour'
)

name = models.CharField(
_('Name'), max_length=256, help_text='Unique identifier for the product tour. Name must match the config name.'
)

state = models.CharField(
_('State'),
max_length=32,
choices=ProductTourState.choices,
default=ProductTourState.READY,
help_text=f'Current state of the tour for this user. Available options: {", ".join(f"{k} ({v})" for k,v in ProductTourState.choices)}',
)

interaction_data = models.JSONField(
_('Interaction Data'),
default=dict,
blank=True,
help_text='Additional data about user interaction with the tour',
)

created_at = models.DateTimeField(auto_now_add=True, help_text='When this tour record was created')

updated_at = models.DateTimeField(auto_now=True, help_text='When this tour record was last updated')

def __str__(self):
return f'{self.user.email} - {self.name} ({self.state})'
Loading
Loading