Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENHANCEMENT] argilla: link user responses and suggestions to record #5518

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,17 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
fields=model.fields,
metadata={meta.name: meta.value for meta in model.metadata},
vectors={vector.name: vector.vector_values for vector in model.vectors},
# Responses and their models are not aligned 1-1.
responses=[
response
for response_model in model.responses
for response in UserResponse.from_model(response_model, dataset=dataset)
],
suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions],
_dataset=dataset,
responses=[],
suggestions=[],
)

# set private attributes
instance._dataset = dataset
instance._model.id = model.id
instance._model.status = model.status
# Responses and suggestions are computed separately based on the record model
instance.responses.from_models(model.responses)
instance.suggestions.from_models(model.suggestions)

return instance

Expand Down Expand Up @@ -349,11 +347,10 @@ class RecordResponses(Iterable[Response]):
def __init__(self, responses: List[Response], record: Record) -> None:
self.record = record
self.__responses_by_question_name = defaultdict(list)
self.__responses = []

self.__responses = responses or []
for response in self.__responses:
response.record = self.record
self.__responses_by_question_name[response.question_name].append(response)
for response in responses or []:
self.add(response)

def __iter__(self):
return iter(self.__responses)
Expand Down Expand Up @@ -409,6 +406,11 @@ def _check_response_already_exists(self, response: Response) -> None:
f"already found. The responses for the same question name do not support more than one user"
)

def from_models(self, responses: List[UserResponseModel]) -> None:
for response_model in responses:
for response in UserResponse.from_model(response_model, record=self.record):
self.add(response)


class RecordSuggestions(Iterable[Suggestion]):
"""This is a container class for the suggestions of a Record.
Expand Down Expand Up @@ -461,3 +463,7 @@ def add(self, suggestion: Suggestion) -> None:
"""
suggestion.record = self.record
self._suggestion_by_question_name[suggestion.question_name] = suggestion

def from_models(self, suggestions: List[SuggestionModel]) -> None:
for suggestion_model in suggestions:
self.add(Suggestion.from_model(suggestion_model, record=self.record))
40 changes: 32 additions & 8 deletions argilla/src/argilla/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from argilla.settings import RankingQuestion

if TYPE_CHECKING:
from argilla import Argilla, Dataset, Record
from argilla import Argilla, Record

__all__ = ["Response", "UserResponse", "ResponseStatus"]

Expand Down Expand Up @@ -71,12 +71,22 @@ def __init__(
if isinstance(status, str):
status = ResponseStatus(status)

self.record = _record
self._record = _record
self.question_name = question_name
self.value = value
self.user_id = user_id
self.status = status

@property
def record(self) -> "Record":
"""Returns the record associated with the response"""
return self._record

@record.setter
def record(self, record: "Record") -> None:
"""Sets the record associated with the response"""
self._record = record

def serialize(self) -> dict[str, Any]:
"""Serializes the Response to a dictionary. This is principally used for sending the response to the API, \
but can be used for data wrangling or manual export.
Expand Down Expand Up @@ -138,6 +148,9 @@ def __init__(
user_id=self._compute_user_id_from_responses(responses),
)

for response in responses:
response.record = _record

def __iter__(self) -> Iterable[Response]:
return iter(self.responses)

Expand All @@ -164,19 +177,29 @@ def user_id(self, user_id: UUID) -> None:
@property
def responses(self) -> List[Response]:
"""Returns the list of responses"""
return self.__model_as_responses_list(self._model)
return self.__model_as_responses_list(self._model, record=self._record)

@property
def record(self) -> "Record":
"""Returns the record associated with the response"""
return self._record

@record.setter
def record(self, record: "Record") -> None:
"""Sets the record associated with the response"""
self._record = record

@classmethod
def from_model(cls, model: UserResponseModel, dataset: "Dataset") -> "UserResponse":
def from_model(cls, model: UserResponseModel, record: "Record") -> "UserResponse":
"""Creates a UserResponse from a ResponseModel"""
responses = cls.__model_as_responses_list(model)
responses = cls.__model_as_responses_list(model, record=record)
for response in responses:
question = dataset.settings.questions[response.question_name]
question = record.dataset.settings.questions[response.question_name]
# We need to adapt the ranking question value to the expected format
if isinstance(question, RankingQuestion):
response.value = cls.__ranking_from_model_value(response.value) # type: ignore

return cls(responses=responses)
return cls(responses=responses, _record=record)

def api_model(self):
"""Returns the model that is used to interact with the API"""
Expand Down Expand Up @@ -223,7 +246,7 @@ def __responses_as_model_values(responses: List[Response]) -> Dict[str, Dict[str
return {answer.question_name: {"value": answer.value} for answer in responses}

@classmethod
def __model_as_responses_list(cls, model: UserResponseModel) -> List[Response]:
def __model_as_responses_list(cls, model: UserResponseModel, record: "Record") -> List[Response]:
"""Creates a list of Responses from a UserResponseModel without changing the format of the values"""

return [
Expand All @@ -232,6 +255,7 @@ def __model_as_responses_list(cls, model: UserResponseModel) -> List[Response]:
value=value["value"],
user_id=model.user_id,
status=model.status,
_record=record,
)
for question_name, value in model.values.items()
]
Expand Down
19 changes: 14 additions & 5 deletions argilla/src/argilla/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from argilla.settings import RankingQuestion

if TYPE_CHECKING:
from argilla import Dataset, QuestionType, Record
from argilla import QuestionType, Record

__all__ = ["Suggestion"]

Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
if value is None:
raise ValueError("value is required")

self.record = _record
self._record = _record
self._model = SuggestionModel(
question_name=question_name,
value=value,
Expand Down Expand Up @@ -104,13 +104,22 @@ def agent(self) -> Optional[str]:
def agent(self, value: str) -> None:
self._model.agent = value

@property
def record(self) -> Optional["Record"]:
"""The record that the suggestion is for."""
return self._record

@record.setter
def record(self, value: "Record") -> None:
self._record = value

@classmethod
def from_model(cls, model: SuggestionModel, dataset: "Dataset") -> "Suggestion":
question = dataset.settings.questions[model.question_id]
def from_model(cls, model: SuggestionModel, record: "Record") -> "Suggestion":
question = record.dataset.settings.questions[model.question_id]
model.question_name = question.name
model.value = cls.__from_model_value(model.value, question)

instance = cls(question.name, model.value)
instance = cls(question.name, model.value, _record=record)
instance._model = model

return instance
Expand Down
11 changes: 8 additions & 3 deletions argilla/tests/unit/test_resources/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest

from argilla import UserResponse, Response, Dataset, Workspace
from argilla import UserResponse, Response, Dataset, Workspace, Record
from argilla._models import UserResponseModel, ResponseStatus


Expand Down Expand Up @@ -89,9 +89,14 @@ def test_create_user_response_with_multiple_user_id(self):

def test_create_user_response_from_draft_response_model_without_values(self):
model = UserResponseModel(values={}, status=ResponseStatus.draft, user=uuid.uuid4())
response = UserResponse.from_model(
model=model, dataset=Dataset(name="burr", workspace=Workspace(name="test", id=uuid.uuid4()))

record = Record(
fields={"question": "answer"},
_dataset=Dataset(name="burr", workspace=Workspace(name="test", id=uuid.uuid4())),
)

response = UserResponse.from_model(model=model, record=record)

assert len(response.responses) == 0
assert response.user_id is None
assert response.status == ResponseStatus.draft