Skip to content

Commit a22f3dd

Browse files
matt-bernstein, nik, Nikita Belonogov
authored
feat: DIA-1402: V1-Submit Prompt auto-refinement job (#214)
Co-authored-by: nik <[email protected]> Co-authored-by: Nikita Belonogov <[email protected]>
1 parent cb6e4d9 commit a22f3dd

File tree

8 files changed

+615
-39
lines changed

8 files changed

+615
-39
lines changed

adala/agents/base.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import traceback
23
from pydantic import (
34
BaseModel,
45
Field,
@@ -7,17 +8,19 @@
78
SerializeAsAny,
89
)
910
from abc import ABC
10-
from typing import Optional, Dict, Union, Tuple
11+
from typing import Optional, Dict, Union, Tuple, List
1112
from rich import print
1213
import yaml
1314

1415
from adala.environments.base import Environment, AsyncEnvironment, EnvironmentFeedback
1516
from adala.environments.static_env import StaticEnvironment
1617
from adala.runtimes.base import Runtime, AsyncRuntime
1718
from adala.runtimes._openai import OpenAIChatRuntime
18-
from adala.skills._base import Skill
19+
from adala.skills._base import Skill, TransformSkill
1920
from adala.memories.base import Memory
2021
from adala.skills.skillset import SkillSet, LinearSkillSet
22+
from adala.skills.collection.prompt_improvement import ImprovedPromptResponse
23+
2124
from adala.utils.logs import (
2225
print_dataframe,
2326
print_text,
@@ -26,7 +29,7 @@
2629
is_running_in_jupyter,
2730
)
2831
from adala.utils.internal_data import InternalDataFrame
29-
32+
from adala.utils.types import BatchData
3033
logger = logging.getLogger(__name__)
3134

3235

@@ -61,7 +64,7 @@ class Agent(BaseModel, ABC):
6164
default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
6265
)
6366
default_runtime: str = "default"
64-
teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
67+
teacher_runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
6568
default_factory=lambda: {"default": None}
6669
)
6770
default_teacher_runtime: str = "default"
@@ -118,7 +121,7 @@ def skills_validator(cls, v) -> SkillSet:
118121
f"skills must be of type SkillSet or Skill, but received type {type(v)}"
119122
)
120123

121-
@field_validator("runtimes", mode="before")
124+
@field_validator("runtimes", "teacher_runtimes", mode="before")
122125
def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
123126
"""
124127
Validates and creates runtimes
@@ -393,6 +396,48 @@ def learn(
393396

394397
print_text("Train is done!")
395398

399+
async def arefine_skill(
    self,
    skill_name: str,
    input_variables: List[str],
    batch_data: Optional[BatchData] = None,
) -> ImprovedPromptResponse:
    """
    beta v2 of Agent.learn() that is:
    - compatible with the newer LiteLLM runtimes
    - compatible with the newer response_model output formats for skills
    - returns chain of thought reasoning in a legible format

    Limitations so far:
    - single skill at a time
    - only returns the improved input_template, doesn't modify the skill in place
    - doesn't use examples/feedback
    - no iterations/variable cost

    Args:
        skill_name: Key into ``self.skills``; the named skill must be a
            ``TransformSkill``, otherwise ``ValueError`` is raised.
        input_variables: Variable names the improved prompt should target;
            forwarded to ``skill.aimprove`` as ``target_input_variables``.
        batch_data: Optional batch of records. When ``None``, no predictions
            are generated and the skill is improved without example outputs.

    Returns:
        ImprovedPromptResponse: The improvement result produced by
        ``skill.aimprove`` (improved prompt plus reasoning/usage fields).

    Raises:
        ValueError: If the named skill is not a ``TransformSkill``.
    """

    skill = self.skills[skill_name]
    if not isinstance(skill, TransformSkill):
        raise ValueError(f"Skill {skill_name} is not a TransformSkill")

    # get default runtimes
    runtime = self.get_runtime()
    teacher_runtime = self.get_teacher_runtime()

    # get inputs
    # TODO: replace it with async environment.get_data_batch()
    if batch_data is None:
        predictions = None
    else:
        # batch_data is non-None on this branch, so the `or []` only guards
        # against an empty/falsy batch; it never changes non-empty input.
        inputs = InternalDataFrame.from_records(batch_data or [])
        predictions = await self.skills.aapply(inputs, runtime=runtime)

    response = await skill.aimprove(
        predictions=predictions,
        teacher_runtime=teacher_runtime,
        target_input_variables=input_variables,
    )
    return response
440+
396441

397442
def create_agent_from_dict(json_dict: Dict):
398443
"""

adala/runtimes/_litellm.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,21 @@ def record_to_record(
277277
usage = completion.usage
278278
dct = to_jsonable_python(response)
279279
except IncompleteOutputException as e:
280+
logger.error(f"Incomplete output error: {str(e)}")
281+
logger.error(f"Traceback:\n{traceback.format_exc()}")
280282
usage = e.total_usage
281283
dct = _log_llm_exception(e)
282284
except InstructorRetryException as e:
285+
logger.error(f"Instructor retry error: {str(e)}")
286+
logger.error(f"Traceback:\n{traceback.format_exc()}")
283287
usage = e.total_usage
284288
# get root cause error from retries
285289
n_attempts = e.n_attempts
286290
e = e.__cause__.last_attempt.exception()
287291
dct = _log_llm_exception(e)
288292
except Exception as e:
293+
logger.error(f"Other error: {str(e)}")
294+
logger.error(f"Traceback:\n{traceback.format_exc()}")
289295
# usage = e.total_usage
290296
# not available here, so have to approximate by hand, assuming the same error occurred each time
291297
n_attempts = retries.stop.max_attempt_number
@@ -485,8 +491,41 @@ async def record_to_record(
485491
extra_fields: Optional[Dict[str, Any]] = None,
486492
field_schema: Optional[Dict] = None,
487493
instructions_first: bool = True,
494+
response_model: Optional[Type[BaseModel]] = None,
488495
) -> Dict[str, str]:
489-
raise NotImplementedError("record_to_record is not implemented")
496+
"""
497+
Execute LiteLLM request given record and templates for input,
498+
instructions and output.
499+
500+
Args:
501+
record: Record to be used for input, instructions and output templates.
502+
input_template: Template for input message.
503+
instructions_template: Template for instructions message.
504+
output_template: Template for output message.
505+
extra_fields: Extra fields to be used in templates.
506+
field_schema: Field jsonschema to be used for parsing templates.
507+
instructions_first: If True, instructions will be sent before input.
508+
509+
Returns:
510+
Dict[str, str]: The processed record.
511+
"""
512+
# Create a single-row DataFrame from the input record
513+
input_df = InternalDataFrame([record])
514+
515+
# Use the batch_to_batch method to process the single-row DataFrame
516+
output_df = await self.batch_to_batch(
517+
input_df,
518+
input_template=input_template,
519+
instructions_template=instructions_template,
520+
output_template=output_template,
521+
extra_fields=extra_fields,
522+
field_schema=field_schema,
523+
instructions_first=instructions_first,
524+
response_model=response_model,
525+
)
526+
527+
# Extract the single row from the output DataFrame and convert it to a dictionary
528+
return output_df.iloc[0].to_dict()
490529

491530

492531
class LiteLLMVisionRuntime(LiteLLMChatRuntime):

adala/skills/_base.py

+100-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import string
3+
import traceback
34
from pydantic import (
45
BaseModel,
56
Field,
@@ -479,6 +480,50 @@ def improve(
479480
self.instructions = new_prompt
480481

481482

483+
async def aimprove(self, teacher_runtime: AsyncRuntime, target_input_variables: List[str], predictions: Optional[InternalDataFrame] = None):
    """
    Improves the skill.

    Builds a ``PromptImprovementSkill`` around this skill, applies it with the
    teacher runtime, and wraps the first result row in an
    ``ImprovedPromptResponse``. Errors (both flagged in the response row via
    ``_adala_error`` and raised exceptions) are returned as an
    ``ErrorResponseModel`` payload rather than propagated.

    Args:
        teacher_runtime: Async runtime used to run the improvement prompt.
        target_input_variables: Variable names the improved prompt must use.
        predictions: Optional prior predictions used as examples; an empty
            dataframe is used when omitted.

    Returns:
        ImprovedPromptResponse: Wraps either a
        ``PromptImprovementSkillResponseModel`` or an ``ErrorResponseModel``.
    """

    # Imported locally — presumably to avoid a circular import between
    # _base and the collection module; TODO confirm.
    from adala.skills.collection.prompt_improvement import PromptImprovementSkill, ImprovedPromptResponse, ErrorResponseModel, PromptImprovementSkillResponseModel
    response_dct = {}
    try:
        prompt_improvement_skill = PromptImprovementSkill(
            skill_to_improve=self,
            input_variables=target_input_variables,
        )
        if predictions is None:
            input_df = InternalDataFrame()
        else:
            input_df = predictions
        response_df = await prompt_improvement_skill.aapply(
            input=input_df,
            runtime=teacher_runtime,
        )

        # awkward to go from response model -> dict -> df -> dict -> response model
        response_dct = response_df.iloc[0].to_dict()

        # unflatten the response
        # pop() both tests the error flag and removes it so it is not passed
        # on to the response model constructors below
        if response_dct.pop("_adala_error", False):
            output = ErrorResponseModel(**response_dct)
        else:
            output = PromptImprovementSkillResponseModel(**response_dct)

    except Exception as e:
        logger.error(f"Error improving skill: {e}. Traceback: {traceback.format_exc()}")
        output = ErrorResponseModel(
            _adala_message=str(e),
            _adala_details=traceback.format_exc(),
        )

    # get tokens and token cost
    # NOTE(review): response_dct still holds all row fields here (model fields
    # plus token/cost fields) — presumably ImprovedPromptResponse accepts or
    # ignores the extras; confirm its model config. On the exception path
    # response_dct may be empty or partially populated.
    resp = ImprovedPromptResponse(output=output, **response_dct)
    logger.debug(f"resp: {resp}")

    return resp
525+
526+
482527
class SampleTransformSkill(TransformSkill):
483528
sample_size: int
484529

@@ -548,30 +593,22 @@ class AnalysisSkill(Skill):
548593
Analysis skill that analyzes a dataframe and returns a record (e.g. for data analysis purposes).
549594
See base class Skill for more information about the attributes.
550595
"""
551-
596+
input_prefix: str = ""
552597
input_separator: str = "\n"
553598
chunk_size: Optional[int] = None
554599

555-
def apply(
556-
self,
557-
input: Union[InternalDataFrame, InternalSeries, Dict],
558-
runtime: Runtime,
559-
) -> InternalDataFrame:
560-
"""
561-
Applies the skill to a dataframe and returns a record.
562-
563-
Args:
564-
input (InternalDataFrame): The input data to be processed.
565-
runtime (Runtime): The runtime instance to be used for processing.
600+
def _iter_over_chunks(self, input: InternalDataFrame, chunk_size: Optional[int] = None):
566601

567-
Returns:
568-
InternalSeries: The record containing the analysis results.
569-
"""
602+
if input.empty:
603+
yield ""
604+
return
605+
570606
if isinstance(input, InternalSeries):
571607
input = input.to_frame()
572608
elif isinstance(input, dict):
573609
input = InternalDataFrame([input])
574610

611+
575612
extra_fields = self._get_extra_fields()
576613

577614
# if chunk_size is specified, split the input into chunks and process each chunk separately
@@ -582,25 +619,65 @@ def apply(
582619
)
583620
else:
584621
chunks = [input]
585-
outputs = []
622+
586623
total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1
587624
for chunk in tqdm(chunks, desc="Processing chunks", total=total):
588-
agg_chunk = (
589-
chunk.reset_index()
625+
agg_chunk = chunk\
626+
.reset_index()\
590627
.apply(
591628
lambda row: self.input_template.format(
592629
**row, **extra_fields, i=int(row.name) + 1
593630
),
594631
axis=1,
595-
)
596-
.str.cat(sep=self.input_separator)
597-
)
632+
).str.cat(sep=self.input_separator)
633+
634+
yield agg_chunk
635+
636+
def apply(
    self,
    input: Union[InternalDataFrame, InternalSeries, Dict],
    runtime: Runtime,
) -> InternalDataFrame:
    """
    Applies the skill to a dataframe and returns a record.

    Each aggregated chunk produced by ``self._iter_over_chunks`` is prefixed
    with ``self.input_prefix`` and sent through the runtime as a single
    ``{"input": ...}`` record; one output row is collected per chunk.

    Args:
        input (InternalDataFrame): The input data to be processed.
        runtime (Runtime): The runtime instance to be used for processing.

    Returns:
        InternalDataFrame: One row per processed chunk with the analysis
        results. (Docstring previously said ``InternalSeries``; the code
        builds and returns an ``InternalDataFrame``.)
    """
    outputs = []
    for agg_chunk in self._iter_over_chunks(input):
        output = runtime.record_to_record(
            {"input": f"{self.input_prefix}{agg_chunk}"},
            input_template="{input}",
            output_template=self.output_template,
            instructions_template=self.instructions,
            instructions_first=self.instructions_first,
            response_model=self.response_model,
        )
        outputs.append(InternalSeries(output))
    output = InternalDataFrame(outputs)

    return output
665+
666+
async def aapply(
667+
self,
668+
input: Union[InternalDataFrame, InternalSeries, Dict],
669+
runtime: AsyncRuntime,
670+
) -> InternalDataFrame:
671+
"""
672+
Applies the skill to a dataframe and returns a record.
673+
"""
674+
outputs = []
675+
for agg_chunk in self._iter_over_chunks(input):
676+
output = await runtime.record_to_record(
677+
{"input": f"{self.input_prefix}{agg_chunk}"},
600678
input_template="{input}",
601679
output_template=self.output_template,
602680
instructions_template=self.instructions,
603-
extra_fields=extra_fields,
604681
instructions_first=self.instructions_first,
605682
response_model=self.response_model,
606683
)

0 commit comments

Comments
 (0)