Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-1748: 'Generate labels with AI' button in Project, in SC #6789

Open
wants to merge 32 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
40d44bb
feat: DIA-1748: 'Generate labels with AI' button in Project, in SC
Dec 13, 2024
470f2c4
Add signals.py
Dec 13, 2024
77387b1
Add infobanner
Dec 13, 2024
dd8d1d7
Revert data import
Dec 13, 2024
45c9fd0
Revert toolbar
Dec 13, 2024
de6cb0a
Fix contextlog
Dec 16, 2024
1995e79
Fix order of project.save()
Dec 17, 2024
6dc7363
Merge remote-tracking branch 'origin/develop' into fb-dia-1748
Dec 17, 2024
e683be9
Add UserTour model
Dec 17, 2024
ed93c9f
Update states
Dec 17, 2024
ef3c5ee
Merge remote-tracking branch 'origin/develop' into fb-dia-1748
Dec 17, 2024
4b54b09
update data model, add url
Dec 17, 2024
db0aeb9
Merge remote-tracking branch 'origin/develop' into fb-dia-1748
Dec 18, 2024
842c951
Optimize /api/templates
Dec 18, 2024
8e2e90b
Fix tests
Dec 18, 2024
6e21afb
Fix url
Dec 18, 2024
1160bc0
Fix tests
Dec 18, 2024
0a38941
Remove infobanner
Dec 18, 2024
9d7a403
Update content
Dec 19, 2024
e5a5a87
Fix serialiser 'name' error
Dec 19, 2024
89eb6b7
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
a9c9914
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
c4edaac
[submodules] Bump HumanSignal/label-studio-sdk version
matt-bernstein Dec 20, 2024
9706ac1
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
b93490b
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 20, 2024
8bc6cf4
Address comments
Dec 22, 2024
409bfaf
Update joyride content
Dec 22, 2024
fc3eab3
Ignore migrations in linter
Dec 23, 2024
c294c62
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 25, 2024
154e820
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 25, 2024
58b4cb4
[submodules] Bump HumanSignal/label-studio-sdk version
fern-api[bot] Dec 25, 2024
3697f03
Fix target selector
Dec 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions label_studio/core/all_urls.json
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,12 @@
"name": "current-user-whoami",
"decorators": ""
},
{
"url": "/api/current-user/product-tour",
"module": "users.product_tours.api.ProductTourAPI",
"name": "product-tour",
"decorators": ""
},
{
"url": "/data/avatars/<path>",
"module": "django.views.static.serve",
Expand Down
2 changes: 1 addition & 1 deletion label_studio/core/utils/contextlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def create_payload(self, request, response, body):
else:
namespace = request.resolver_match.namespace if request.resolver_match else None
status_code = response.status_code
content_type = response.content_type
content_type = response.content_type if hasattr(response, 'content_type') else None
response_content = self._get_response_content(response)

payload = {
Expand Down
7 changes: 7 additions & 0 deletions label_studio/ml_models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@
'organizations.Organization', on_delete=models.CASCADE, related_name='third_party_model_versions', null=True
)

@property
def project(self):
# TODO: can it be just a property of the model version?
if self.parent_model and self.parent_model.associated_projects.exists():
return self.parent_model.associated_projects.first()
return None

Check warning on line 123 in label_studio/ml_models/models.py

View check run for this annotation

Codecov / codecov/patch

label_studio/ml_models/models.py#L121-L123

Added lines #L121 - L123 were not covered by tests

def has_permission(self, user):
return user.active_organization == self.organization

Expand Down
38 changes: 22 additions & 16 deletions label_studio/projects/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,28 +788,34 @@
return instance


def read_templates_and_groups():
annotation_templates_dir = find_dir('annotation_templates')
configs = []
for config_file in pathlib.Path(annotation_templates_dir).glob('**/*.yml'):
config = read_yaml(config_file)
if settings.VERSION_EDITION == 'Community':
if settings.VERSION_EDITION.lower() != config.get('type', 'community'):
continue

Check warning on line 798 in label_studio/projects/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/projects/api.py#L798

Added line #L798 was not covered by tests
if config.get('image', '').startswith('/static') and settings.HOSTNAME:
# if hostname set manually, create full image urls
config['image'] = settings.HOSTNAME + config['image']

Check warning on line 801 in label_studio/projects/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/projects/api.py#L801

Added line #L801 was not covered by tests
configs.append(config)
template_groups_file = find_file(os.path.join('annotation_templates', 'groups.txt'))
with open(template_groups_file, encoding='utf-8') as f:
groups = f.read().splitlines()
logger.debug(f'{len(configs)} templates found.')
return {'templates': configs, 'groups': groups}


class TemplateListAPI(generics.ListAPIView):
parser_classes = (JSONParser, FormParser, MultiPartParser)
permission_required = all_permissions.projects_view
swagger_schema = None
# load this once in memory for performance
templates_and_groups = read_templates_and_groups()

def list(self, request, *args, **kwargs):
annotation_templates_dir = find_dir('annotation_templates')
configs = []
for config_file in pathlib.Path(annotation_templates_dir).glob('**/*.yml'):
config = read_yaml(config_file)
if settings.VERSION_EDITION == 'Community':
if settings.VERSION_EDITION.lower() != config.get('type', 'community'):
continue
if config.get('image', '').startswith('/static') and settings.HOSTNAME:
# if hostname set manually, create full image urls
config['image'] = settings.HOSTNAME + config['image']
configs.append(config)
template_groups_file = find_file(os.path.join('annotation_templates', 'groups.txt'))
with open(template_groups_file, encoding='utf-8') as f:
groups = f.read().splitlines()
logger.debug(f'{len(configs)} templates found.')
return Response({'templates': configs, 'groups': groups})
return Response(self.templates_and_groups)

Check warning on line 818 in label_studio/projects/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/projects/api.py#L818

Added line #L818 was not covered by tests


class ProjectSampleTask(generics.RetrieveAPIView):
Expand Down
22 changes: 18 additions & 4 deletions label_studio/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
annotate_useful_annotation_number,
)
from projects.functions.utils import make_queryset_from_iterable
from projects.signals import ProjectSignals
from tasks.models import (
Annotation,
AnnotationDraft,
Expand Down Expand Up @@ -659,6 +660,10 @@
def _label_config_has_changed(self):
return self.label_config != self.__original_label_config

@property
def label_config_is_not_default(self):
return self.label_config != Project._meta.get_field('label_config').default

Check warning on line 665 in label_studio/projects/models.py

View check run for this annotation

Codecov / codecov/patch

label_studio/projects/models.py#L665

Added line #L665 was not covered by tests

def should_none_model_version(self, model_version):
"""
Returns True if the model version provided matches the object's model version,
Expand Down Expand Up @@ -728,7 +733,12 @@
exists = True if self.pk else False
project_with_config_just_created = not exists and self.label_config

if self._label_config_has_changed() or project_with_config_just_created:
label_config_has_changed = self._label_config_has_changed()
logger.debug(
f'Label config has changed: {label_config_has_changed}, original: {self.__original_label_config}, new: {self.label_config}'
)

if label_config_has_changed or project_with_config_just_created:
self.data_types = extract_data_types(self.label_config)
self.parsed_label_config = parse_config(self.label_config)
self.label_config_hash = hash(str(self.parsed_label_config))
Expand All @@ -740,11 +750,15 @@
if update_fields is not None:
update_fields = {'control_weights'}.union(update_fields)

if self._label_config_has_changed():
self.__original_label_config = self.label_config

super(Project, self).save(*args, update_fields=update_fields, **kwargs)

if label_config_has_changed:
# save the new label config for future comparison
self.__original_label_config = self.label_config
# if tasks are already imported, emit signal that project is configured and ready for labeling
if self.num_tasks > 0:
ProjectSignals.post_label_config_and_import_tasks.send(sender=Project, project=self)
triklozoid marked this conversation as resolved.
Show resolved Hide resolved

if not exists:
steps = ProjectOnboardingSteps.objects.all()
objs = [ProjectOnboarding(project=self, step=step) for step in steps]
Expand Down
18 changes: 18 additions & 0 deletions label_studio/projects/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from django.dispatch import Signal


class ProjectSignals:
"""
Signals for project: implements observer pattern for custom signals.
Example:

# publisher
ProjectSignals.my_signal.send(sender=self, project=project)

# observer
@receiver(ProjectSignals.my_signal)
def my_observer(sender, **kwargs):
...
"""

post_label_config_and_import_tasks = Signal()
5 changes: 5 additions & 0 deletions label_studio/tasks/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def create(self, validated_data):

self.post_process_annotations(user, db_annotations, 'imported')
self.post_process_tasks(self.project.id, [t.id for t in self.db_tasks])
self.post_process_custom_callback(self.project.id, user)

if flag_set('fflag_feat_back_lsdv_5307_import_reviews_drafts_29062023_short', user=ff_user):
with transaction.atomic():
Expand Down Expand Up @@ -586,6 +587,10 @@ def post_process_tasks(user, db_tasks):
def add_annotation_fields(body, user, action):
return body

@staticmethod
def post_process_custom_callback(project_id, user):
pass

class Meta:
model = Task
fields = '__all__'
Expand Down
27 changes: 27 additions & 0 deletions label_studio/users/migrations/0010_userproducttour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 4.2.15 on 2024-12-17 18:56

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
('users', '0009_auto_20231201_0001'),
]

operations = [
migrations.CreateModel(
name='UserProductTour',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='Unique identifier for the product tour. Name must match the config name.', max_length=256)),
('state', models.CharField(choices=[('ready', 'ready'), ('completed', 'completed'), ('skipped', 'skipped')], default='ready', help_text='Current state of the tour for this user: "ready" when tour is initiated, "completed" when user finishes the tour, "skipped" when user cancels the tour.', max_length=32)),
('interaction_data', models.JSONField(blank=True, default=dict, help_text='Additional data about user interaction with the tour')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='When this tour record was created')),
('updated_at', models.DateTimeField(auto_now=True, help_text='When this tour record was last updated')),
('user', models.ForeignKey(help_text='User who interacted with the tour', on_delete=django.db.models.deletion.CASCADE, related_name='tours', to=settings.AUTH_USER_MODEL)),
],
),
]
44 changes: 44 additions & 0 deletions label_studio/users/product_tours/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging

from rest_framework import generics
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
from users.product_tours.models import UserProductTour

from .serializers import UserProductTourSerializer

logger = logging.getLogger(__name__)


class ProductTourAPI(generics.RetrieveUpdateAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = UserProductTourSerializer
swagger_schema = None

def get_tour_name(self):
name = self.request.query_params.get('name')
if not name:
raise ValidationError('Name is required')
# normalize name for subsequent checks
return name.replace('-', '_').lower()

Check warning on line 23 in label_studio/users/product_tours/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/api.py#L23

Added line #L23 was not covered by tests

def get_serializer_context(self):
context = super().get_serializer_context()
context['name'] = self.get_tour_name()
return context

Check warning on line 28 in label_studio/users/product_tours/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/api.py#L26-L28

Added lines #L26 - L28 were not covered by tests

def get_object(self):
name = self.get_tour_name()

# TODO: add additional checks, e.g. user agent, role, etc.

tour = UserProductTour.objects.filter(user=self.request.user, name=name).first()
if not tour:
logger.info(f'Product tour {name} not found for user {self.request.user.id}. Creating new tour.')
tour_serializer = self.get_serializer(data={'user': self.request.user.id, 'name': name})
tour_serializer.is_valid(raise_exception=True)
tour = tour_serializer.save()

Check warning on line 40 in label_studio/users/product_tours/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/api.py#L35-L40

Added lines #L35 - L40 were not covered by tests
else:
logger.info(f'Product tour {name} requested for user {self.request.user.id}.')

Check warning on line 42 in label_studio/users/product_tours/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/api.py#L42

Added line #L42 was not covered by tests
triklozoid marked this conversation as resolved.
Show resolved Hide resolved

return tour

Check warning on line 44 in label_studio/users/product_tours/api.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/api.py#L44

Added line #L44 was not covered by tests
23 changes: 23 additions & 0 deletions label_studio/users/product_tours/configs/create_prompt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
steps:
- target: body
placement: center
title: '<div><img src="/static/images/prompts-tour-banner.svg" style="width: 100%; margin-bottom: 16px;"/><div style="font-size: 24px; color: var(--grape_500)">Auto-Label with AI</div></div>'
content: '<div style="width: 100%">We’ve set up an initial prompt and supplied some OpenAI credits to help you get started with auto-labeling your project.<br/><br/>Let us show you around!</div>'

- target: .cm-editor
placement: right
title: '<div style="color: var(--sand_600)">Step 1 of 3</div>'
content: '<div style="width: 100%;">We’ve gone ahead and generated a prompt for you based on your project’s labeling configuration.<br/><br/> Feel free to adjust it!</div>'

- target: '[data-testid="evaluate-model-button"]'
placement: top
title: '<div style="color: var(--sand_600)">Step 2 of 3</div>'
content: '<div style="width: 100%;">That''s it! After saving, just click on <b>Evaluate</b> to start getting predictions for your tasks.</div>'

- target: body
placement: center
title: '<div>🎉 <br/><b> That’s it! </b></div>'
content: 'Now you may convert (or correct) the predictions to annotations by opening the tasks on your project.'
locale:
last: 'Go to Project'

9 changes: 9 additions & 0 deletions label_studio/users/product_tours/configs/prompts_page.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
steps:
- target: body
placement: center
title: '<div style="font-size: 24px; color: var(--grape_500)">Welcome to Prompts!</div>'
content: '<div style="width: 100%">Set up Prompts to help you rapidly pre-label projects.</div>'

- target: ".lsf-data-table__body-row[data-index='0'] .lsf-models-list__model-name"
content: '<div style="width: 100%">Click on this sample Prompt to get you started.</div>'
isFixed: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
steps:
- target: .lsf-auto-labeling-button
placement: bottom
title: 'Great news!'
content: 'You can now rapidly label this project using Label Studio Prompts with the power of LLMs.<br/><br/>Click “Auto-Label Tasks” to automatically label a sample of 20 tasks.'
disableBeacon: true
53 changes: 53 additions & 0 deletions label_studio/users/product_tours/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from enum import Enum
from typing import Any, Dict, Optional

from django.db import models
from pydantic import BaseModel, Field


class ProductTourState(str, Enum):
triklozoid marked this conversation as resolved.
Show resolved Hide resolved
READY = 'ready'
COMPLETED = 'completed'
SKIPPED = 'skipped'


class ProductTourInteractionData(BaseModel):
"""Pydantic model for validating tour interaction data"""

index: Optional[int] = Field(None, description='Step number where tour was completed')
action: Optional[str] = Field(None, description='Action taken during the tour')
type: Optional[str] = Field(None, description='Type of interaction')
status: Optional[str] = Field(None, description='Status of the interaction')
additional_data: Optional[Dict[str, Any]] = Field(
default_factory=dict, description='Extensible field for additional interaction data'
)


class UserProductTour(models.Model):
"""Stores product tour state and interaction data for users"""

user = models.ForeignKey(
'User', on_delete=models.CASCADE, related_name='tours', help_text='User who interacted with the tour'
)

name = models.CharField(
max_length=256, help_text='Unique identifier for the product tour. Name must match the config name.'
)

state = models.CharField(
max_length=32,
choices=[(state.value, state.value) for state in ProductTourState],
default=ProductTourState.READY.value,
help_text='Current state of the tour for this user: "ready" when tour is initiated, "completed" when user finishes the tour, "skipped" when user cancels the tour.',
)

interaction_data = models.JSONField(
default=dict, blank=True, help_text='Additional data about user interaction with the tour'
)

created_at = models.DateTimeField(auto_now_add=True, help_text='When this tour record was created')

updated_at = models.DateTimeField(auto_now=True, help_text='When this tour record was last updated')

def __str__(self):
return f'{self.user.email} - {self.name} ({self.state})'

Check warning on line 53 in label_studio/users/product_tours/models.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/models.py#L53

Added line #L53 was not covered by tests
48 changes: 48 additions & 0 deletions label_studio/users/product_tours/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pathlib
from functools import cached_property

import yaml
from rest_framework import serializers

from .models import ProductTourInteractionData, UserProductTour

PRODUCT_TOURS_CONFIGS_DIR = pathlib.Path(__file__).parent / 'configs'


class UserProductTourSerializer(serializers.ModelSerializer):
steps = serializers.SerializerMethodField(read_only=True)

class Meta:
model = UserProductTour
fields = '__all__'

@cached_property
def available_tours(self):
return {pathlib.Path(f).stem for f in PRODUCT_TOURS_CONFIGS_DIR.iterdir()}

Check warning on line 21 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L21

Added line #L21 was not covered by tests

def validate_name(self, value):

if value not in self.available_tours:
raise serializers.ValidationError(

Check warning on line 26 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L25-L26

Added lines #L25 - L26 were not covered by tests
f'Product tour {value} not found. Available tours: {self.available_tours}'
)

return value

Check warning on line 30 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L30

Added line #L30 was not covered by tests

def load_tour_config(self):
# TODO: get product tour from yaml file. Later we move it to remote storage, e.g. S3
filepath = PRODUCT_TOURS_CONFIGS_DIR / f'{self.context["name"]}.yml'
with open(filepath, 'r') as f:
return yaml.safe_load(f)

Check warning on line 36 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L34-L36

Added lines #L34 - L36 were not covered by tests

def get_steps(self, obj):
config = self.load_tour_config()
return config.get('steps', [])

Check warning on line 40 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L39-L40

Added lines #L39 - L40 were not covered by tests

def validate_interaction_data(self, value):
try:

Check warning on line 43 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L43

Added line #L43 was not covered by tests
# Validate interaction data using pydantic model
ProductTourInteractionData(**value)
return value
except Exception as e:
raise serializers.ValidationError(f'Invalid interaction data format: {str(e)}')

Check warning on line 48 in label_studio/users/product_tours/serializers.py

View check run for this annotation

Codecov / codecov/patch

label_studio/users/product_tours/serializers.py#L45-L48

Added lines #L45 - L48 were not covered by tests
Loading
Loading