13
13
import instructor
14
14
from instructor .exceptions import InstructorRetryException , IncompleteOutputException
15
15
import traceback
16
+ from adala .runtimes .base import CostEstimate
16
17
from adala .utils .exceptions import ConstrainedGenerationError
17
18
from adala .utils .internal_data import InternalDataFrame
18
19
from adala .utils .parse import (
@@ -122,7 +123,6 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
122
123
123
124
124
125
class InstructorClientMixin :
125
-
126
126
def _from_litellm(self, **kwargs):
    """Build a synchronous instructor client wrapping ``litellm.completion``."""
    sync_completion_fn = litellm.completion
    return instructor.from_litellm(sync_completion_fn, **kwargs)
128
128
@@ -139,7 +139,6 @@ def is_custom_openai_endpoint(self) -> bool:
139
139
140
140
141
141
class InstructorAsyncClientMixin (InstructorClientMixin ):
142
-
143
142
def _from_litellm(self, **kwargs):
    """Build an asynchronous instructor client wrapping ``litellm.acompletion``."""
    async_completion_fn = litellm.acompletion
    return instructor.from_litellm(async_completion_fn, **kwargs)
145
144
@@ -527,6 +526,73 @@ async def record_to_record(
527
526
# Extract the single row from the output DataFrame and convert it to a dictionary
528
527
return output_df .iloc [0 ].to_dict ()
529
528
529
@staticmethod
def _get_prompt_tokens(string: str, model: str, output_fields: Optional[List[str]]) -> int:
    """Estimate the number of prompt tokens for a single completion request.

    Args:
        string: The fully-substituted user prompt text.
        model: Model name understood by litellm's tokenizer.
        output_fields: Names of the structured output fields, or None when
            the caller does not know them (see ``_estimate_cost``).

    Returns:
        Estimated prompt token count: tokenized user text plus a fixed
        system/function-call overhead.
    """
    user_tokens = litellm.token_counter(model=model, text=string)
    # FIXME surprisingly difficult to get function call tokens, and doesn't add a ton of value, so hard-coding until something like litellm supports doing this for us.
    # currently seems like we'd need to scrape the instructor logs to get the function call info, then use (at best) an openai-specific 3rd party lib to get a token estimate from that.
    # Callers pass output_fields=None (the parameter is Optional upstream);
    # treat None like a single output field — mirroring _get_completion_tokens —
    # instead of crashing on len(None). An explicit empty list still counts as 0.
    n_outputs = len(output_fields) if output_fields is not None else 1
    system_tokens = 56 + (6 * n_outputs)
    return user_tokens + system_tokens
536
+
537
@staticmethod
def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
    """Estimate the number of completion tokens for a single request.

    Args:
        model: Model name; looked up in litellm's model-info registry
            (queried with the "openai" provider).
        output_fields: Names of the structured output fields, or None.

    Returns:
        A rough completion-token estimate, capped at the model's max_tokens.

    Raises:
        ValueError: If litellm has no ``max_tokens`` metadata for the model.
    """
    max_tokens = litellm.get_model_info(
        model=model, custom_llm_provider="openai"
    ).get("max_tokens", None)
    if not max_tokens:
        # Name the model so the failure is actionable; a bare ValueError
        # gives the caller's error handler nothing to report.
        raise ValueError(f"Model {model} has no max_tokens metadata in litellm")
    # extremely rough heuristic, from testing on some anecdotal examples
    n_outputs = len(output_fields) if output_fields else 1
    return min(max_tokens, 4 * n_outputs)
547
+
548
@classmethod
def _estimate_cost(
    cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
):
    """Estimate the USD cost of one request against ``model``.

    Returns:
        A ``(prompt_cost, completion_cost, total_cost)`` triple of floats.
    """
    n_prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
    n_completion_tokens = cls._get_completion_tokens(model, output_fields)
    cost_in, cost_out = litellm.cost_per_token(
        model=model,
        prompt_tokens=n_prompt_tokens,
        completion_tokens=n_completion_tokens,
    )
    return cost_in, cost_out, cost_in + cost_out
562
+
563
def get_cost_estimate(
    self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
) -> CostEstimate:
    """Estimate the total USD cost of running ``prompt`` over all substitutions.

    Args:
        prompt: A format-string template rendered once per substitution dict.
        substitutions: One dict of template variables per planned request.
        output_fields: Names of the structured output fields, or None.

    Returns:
        A CostEstimate with cumulative prompt/completion/total costs, or an
        error-flagged CostEstimate if rendering or estimation fails.
    """
    try:
        # Render every prompt up front so a bad template fails fast.
        rendered_prompts = [prompt.format(**sub) for sub in substitutions]
        per_prompt_costs = [
            self._estimate_cost(
                user_prompt=text,
                model=self.model,
                output_fields=output_fields,
            )
            for text in rendered_prompts
        ]
        return CostEstimate(
            prompt_cost_usd=sum(c[0] for c in per_prompt_costs),
            completion_cost_usd=sum(c[1] for c in per_prompt_costs),
            total_cost_usd=sum(c[2] for c in per_prompt_costs),
        )
    except Exception as e:
        # Cost estimation is best-effort: report the failure in-band
        # rather than propagating it to the caller.
        logger.error("Failed to estimate cost: %s", e)
        return CostEstimate(
            is_error=True,
            error_type=type(e).__name__,
            error_message=str(e),
        )
595
+
530
596
531
597
class LiteLLMVisionRuntime (LiteLLMChatRuntime ):
532
598
"""
0 commit comments