Skip to content

Commit

Permalink
feat: chat system role support (#790)
Browse files Browse the repository at this point in the history
  • Loading branch information
Elliott authored Nov 7, 2023
1 parent 00a59d0 commit c2bc78d
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 15 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


__version__ = "1.2.0"
__version__ = "1.2.1"

import sys
from typing import Any, List, Optional
Expand Down
10 changes: 9 additions & 1 deletion dataquality/integrations/seq2seq/formatters/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from dataquality.integrations.seq2seq.formatters.base import BaseFormatter
from dataquality.schemas.seq2seq import Seq2SeqInputCols as S2SIC

# HF tokenizers don't support newlines, so we use a token to represent them
# Example of tokenizer without the NEWLINE token:
Expand All @@ -31,6 +32,7 @@ class ChatFormatter(BaseFormatter):
# Chat roles
user: str = "User"
assistant: str = "Chatbot"
system: str = "System"

def format_sample(
self, sample: Dict[str, Any], idx: Optional[int] = None
Expand Down Expand Up @@ -74,9 +76,14 @@ def format_sample(
turn_data: Dict[str, Any] = {"chat_id": None, "turn_id": None}
turn_id = 1
turn_default_cols = [self.role_col, self.content_col]
system_prompts: Dict[str, str] = {}
for turn in turns:
role = turn[self.role_col]
content = turn[self.content_col]
if role == self.system:
system_prompts[self.system] = content
continue

# Add metadata to each turn
turn_meta = {
f"{role}_{col}": turn[col]
Expand All @@ -94,6 +101,7 @@ def format_sample(
turn_data[self.target_col] = content
turn_data["turn_id"] = turn_id
turn_data["chat_id"] = idx
turn_data[S2SIC.system_prompts] = system_prompts
# Add sample level metadata
# NOTE: When we drop Python 3.8 support we can use 'turn_data |= metadata'
turn_data.update(metadata)
Expand Down Expand Up @@ -195,6 +203,6 @@ def format_sample(
]
# If both are -1, we just take the last max_input_tokens tokens
start_index = min(non_negative) if non_negative else -self.max_input_tokens
user_inputs[i] = parsed_history[start_index:]
user_inputs[i] = f"{parsed_history[start_index:]}"

return formatted_sample
1 change: 1 addition & 0 deletions dataquality/loggers/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class BaseLoggerAttributes(str, Enum):
generated_output = "generated_output"
input_cutoff = "input_cutoff"
target_cutoff = "target_cutoff"
system_prompts = "system_prompts"

@staticmethod
def get_valid() -> List[str]:
Expand Down
26 changes: 15 additions & 11 deletions dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
seq2seq_logger_config,
)
from dataquality.schemas.dataframe import BaseLoggerDataFrames
from dataquality.schemas.seq2seq import Seq2SeqInputCols as C
from dataquality.schemas.seq2seq import Seq2SeqInputCols as S2SIC
from dataquality.schemas.split import Split
from dataquality.utils.seq2seq.generation import (
add_generated_output_to_df,
Expand Down Expand Up @@ -124,17 +124,21 @@ def validate_and_format(self) -> None:
self.logger_config.id_to_tokens[self.token_map_key].update(id_to_tokens)

def _get_input_df(self) -> DataFrame:
return vaex.from_dict(
data = vaex.from_dict(
{
C.id.value: self.ids,
C.text.value: self.texts,
C.label.value: self.labels,
C.split_.value: [self.split] * len(self.ids),
C.token_label_positions.value: pa.array(self.token_label_positions),
C.token_label_offsets.value: pa.array(self.token_label_offsets),
S2SIC.id.value: self.ids,
S2SIC.text.value: self.texts,
S2SIC.label.value: self.labels,
S2SIC.split_.value: [self.split] * len(self.ids),
S2SIC.token_label_positions.value: pa.array(self.token_label_positions),
S2SIC.token_label_offsets.value: pa.array(self.token_label_offsets),
**self.meta,
}
)
if S2SIC.system_prompts in self.meta:
# We must store nested dicts as pyarrow arrays to support vaex export
data[S2SIC.system_prompts.value] = pa.array(self.meta[S2SIC.system_prompts])
return data

def _log_df(
self,
Expand Down Expand Up @@ -199,7 +203,7 @@ def get_valid_attributes() -> List[str]:
Returns a list of valid attributes for this Logger class
:return: List[str]
"""
return list(map(lambda x: x.value, C))
return list(map(lambda x: x.value, S2SIC))

@classmethod
def _get_prob_cols(cls) -> List[str]:
Expand Down Expand Up @@ -292,7 +296,7 @@ def separate_dataframe(
other_cols += ["id"]

emb = df_copy[emb_cols]
data_df = C.set_cols(df_copy[other_cols])
data_df = S2SIC.set_cols(df_copy[other_cols])
return BaseLoggerDataFrames(prob=prob, emb=emb, data=data_df)

@classmethod
Expand All @@ -315,7 +319,7 @@ def calculate_cutoffs(cls, df: DataFrame) -> DataFrame:
max_input_length = cls.logger_config.max_input_tokens
df = add_input_cutoff_to_df(df, tokenizer, max_tokens=max_input_length)

target_offsets_colname = C.token_label_offsets
target_offsets_colname = S2SIC.token_label_offsets
if target_offsets_colname in df.get_column_names():
df = add_target_cutoff_to_df(df, target_offsets_colname)

Expand Down
6 changes: 4 additions & 2 deletions dataquality/schemas/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ class Seq2SeqInputCols(str, Enum):
generated_output = "generated_output"
split_ = "split"
tokenized_label = "tokenized_label"
token_label_positions = "token_label_positions"
token_label_offsets = "token_label_offsets"
input_cutoff = "input_cutoff"
target_cutoff = "target_cutoff"
# Columns saved as pyarrow arrays
token_label_positions = "token_label_positions"
token_label_offsets = "token_label_offsets"
system_prompts = "system_prompts"

@classmethod
def set_cols(cls, df: DataFrame) -> DataFrame:
Expand Down
80 changes: 80 additions & 0 deletions docs/notebooks/Seq2Seq-Auto-Chat.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ['GALILEO_CONSOLE_URL']=\"\"\n",
"os.environ[\"GALILEO_USERNAME\"]=\"\"\n",
"os.environ[\"GALILEO_PASSWORD\"]=\"\"\n",
"\n",
"import dataquality as dq\n",
"dq.configure()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataquality.integrations.seq2seq.auto import auto\n",
"from dataquality.integrations.seq2seq.formatters.chat import ChatHistoryFormatter\n",
"from dataquality.integrations.seq2seq.schema import Seq2SeqDatasetConfig, Seq2SeqGenerationConfig, Seq2SeqTrainingConfig\n",
"\n",
"chf = ChatHistoryFormatter(\n",
" assistant=\"Chatbot\"\n",
")\n",
"\n",
"dataset_config = Seq2SeqDatasetConfig(\n",
" train_path=\"./chats_with_system_role.jsonl\",\n",
" input_col=\"input\",\n",
" target_col=\"target\",\n",
" formatter=chf,\n",
")\n",
"gen_config = Seq2SeqGenerationConfig(\n",
" generation_splits=[]\n",
")\n",
"tr_config = Seq2SeqTrainingConfig(\n",
" max_target_tokens=256,\n",
" epochs=0\n",
")\n",
"\n",
"auto(project_name=\"auto_s2s\", run_name=\"my_run_name\", dataset_config=dataset_config, generation_config=gen_config, training_config=tr_config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit c2bc78d

Please sign in to comment.