1
1
import logging
2
2
import string
3
+ import traceback
3
4
from pydantic import (
4
5
BaseModel ,
5
6
Field ,
@@ -479,6 +480,50 @@ def improve(
479
480
self .instructions = new_prompt
480
481
481
482
483
+ async def aimprove (self , teacher_runtime : AsyncRuntime , target_input_variables : List [str ], predictions : Optional [InternalDataFrame ] = None ):
484
+ """
485
+ Improves the skill.
486
+ """
487
+
488
+ from adala .skills .collection .prompt_improvement import PromptImprovementSkill , ImprovedPromptResponse , ErrorResponseModel , PromptImprovementSkillResponseModel
489
+ response_dct = {}
490
+ try :
491
+ prompt_improvement_skill = PromptImprovementSkill (
492
+ skill_to_improve = self ,
493
+ input_variables = target_input_variables ,
494
+ )
495
+ if predictions is None :
496
+ input_df = InternalDataFrame ()
497
+ else :
498
+ input_df = predictions
499
+ response_df = await prompt_improvement_skill .aapply (
500
+ input = input_df ,
501
+ runtime = teacher_runtime ,
502
+ )
503
+
504
+ # awkward to go from response model -> dict -> df -> dict -> response model
505
+ response_dct = response_df .iloc [0 ].to_dict ()
506
+
507
+ # unflatten the response
508
+ if response_dct .pop ("_adala_error" , False ):
509
+ output = ErrorResponseModel (** response_dct )
510
+ else :
511
+ output = PromptImprovementSkillResponseModel (** response_dct )
512
+
513
+ except Exception as e :
514
+ logger .error (f"Error improving skill: { e } . Traceback: { traceback .format_exc ()} " )
515
+ output = ErrorResponseModel (
516
+ _adala_message = str (e ),
517
+ _adala_details = traceback .format_exc (),
518
+ )
519
+
520
+ # get tokens and token cost
521
+ resp = ImprovedPromptResponse (output = output , ** response_dct )
522
+ logger .debug (f"resp: { resp } " )
523
+
524
+ return resp
525
+
526
+
482
527
class SampleTransformSkill (TransformSkill ):
483
528
sample_size : int
484
529
@@ -548,30 +593,22 @@ class AnalysisSkill(Skill):
548
593
Analysis skill that analyzes a dataframe and returns a record (e.g. for data analysis purposes).
549
594
See base class Skill for more information about the attributes.
550
595
"""
551
-
596
+ input_prefix : str = ""
552
597
input_separator : str = "\n "
553
598
chunk_size : Optional [int ] = None
554
599
555
- def apply (
556
- self ,
557
- input : Union [InternalDataFrame , InternalSeries , Dict ],
558
- runtime : Runtime ,
559
- ) -> InternalDataFrame :
560
- """
561
- Applies the skill to a dataframe and returns a record.
562
-
563
- Args:
564
- input (InternalDataFrame): The input data to be processed.
565
- runtime (Runtime): The runtime instance to be used for processing.
600
+ def _iter_over_chunks (self , input : InternalDataFrame , chunk_size : Optional [int ] = None ):
566
601
567
- Returns:
568
- InternalSeries: The record containing the analysis results.
569
- """
602
+ if input .empty :
603
+ yield ""
604
+ return
605
+
570
606
if isinstance (input , InternalSeries ):
571
607
input = input .to_frame ()
572
608
elif isinstance (input , dict ):
573
609
input = InternalDataFrame ([input ])
574
610
611
+
575
612
extra_fields = self ._get_extra_fields ()
576
613
577
614
# if chunk_size is specified, split the input into chunks and process each chunk separately
@@ -582,25 +619,65 @@ def apply(
582
619
)
583
620
else :
584
621
chunks = [input ]
585
- outputs = []
622
+
586
623
total = input .shape [0 ] // self .chunk_size if self .chunk_size is not None else 1
587
624
for chunk in tqdm (chunks , desc = "Processing chunks" , total = total ):
588
- agg_chunk = (
589
- chunk .reset_index ()
625
+ agg_chunk = chunk \
626
+ .reset_index ()\
590
627
.apply (
591
628
lambda row : self .input_template .format (
592
629
** row , ** extra_fields , i = int (row .name ) + 1
593
630
),
594
631
axis = 1 ,
595
- )
596
- .str .cat (sep = self .input_separator )
597
- )
632
+ ).str .cat (sep = self .input_separator )
633
+
634
+ yield agg_chunk
635
+
636
+ def apply (
637
+ self ,
638
+ input : Union [InternalDataFrame , InternalSeries , Dict ],
639
+ runtime : Runtime ,
640
+ ) -> InternalDataFrame :
641
+ """
642
+ Applies the skill to a dataframe and returns a record.
643
+
644
+ Args:
645
+ input (InternalDataFrame): The input data to be processed.
646
+ runtime (Runtime): The runtime instance to be used for processing.
647
+
648
+ Returns:
649
+ InternalSeries: The record containing the analysis results.
650
+ """
651
+ outputs = []
652
+ for agg_chunk in self ._iter_over_chunks (input ):
598
653
output = runtime .record_to_record (
599
- {"input" : agg_chunk },
654
+ {"input" : f"{ self .input_prefix } { agg_chunk } " },
655
+ input_template = "{input}" ,
656
+ output_template = self .output_template ,
657
+ instructions_template = self .instructions ,
658
+ instructions_first = self .instructions_first ,
659
+ response_model = self .response_model ,
660
+ )
661
+ outputs .append (InternalSeries (output ))
662
+ output = InternalDataFrame (outputs )
663
+
664
+ return output
665
+
666
+ async def aapply (
667
+ self ,
668
+ input : Union [InternalDataFrame , InternalSeries , Dict ],
669
+ runtime : AsyncRuntime ,
670
+ ) -> InternalDataFrame :
671
+ """
672
+ Applies the skill to a dataframe and returns a record.
673
+ """
674
+ outputs = []
675
+ for agg_chunk in self ._iter_over_chunks (input ):
676
+ output = await runtime .record_to_record (
677
+ {"input" : f"{ self .input_prefix } { agg_chunk } " },
600
678
input_template = "{input}" ,
601
679
output_template = self .output_template ,
602
680
instructions_template = self .instructions ,
603
- extra_fields = extra_fields ,
604
681
instructions_first = self .instructions_first ,
605
682
response_model = self .response_model ,
606
683
)
0 commit comments