@@ -457,48 +457,64 @@ def __init__(
457
457
458
458
# This is part of the public API.
459
459
@BlockRegistry.register("LLMMessagesBlock")
class LLMMessagesBlock(Block):
    """Block that runs one chat-completion request per entry of a column.

    Each entry of ``input_col`` is passed as the ``messages`` argument of
    ``ctx.client.chat.completions.create`` and the generated text is stored
    in a new ``output_col`` column.
    """

    def __init__(
        self,
        ctx,
        pipe,
        block_name,
        input_col,
        output_col,
        gen_kwargs=None,
    ) -> None:
        """Store the column names and resolve the generation arguments.

        Args:
            ctx: Pipeline context; this class reads ``ctx.model_id`` and
                ``ctx.client`` from it.
            pipe: Forwarded unchanged to ``Block.__init__``.
            block_name: Forwarded to ``Block.__init__``; presumably what
                ``self.block_name`` is set from (confirm in ``Block``).
            input_col: Name of the column holding the chat messages.
            output_col: Name of the column the generations are written to.
            gen_kwargs: Optional overrides for the generation arguments.
                ``None`` (the default) and ``{}`` both mean "no overrides".
        """
        super().__init__(ctx, pipe, block_name)
        self.input_col = input_col
        self.output_col = output_col
        self.gen_kwargs = self._gen_kwargs(
            gen_kwargs,
            model=self.ctx.model_id,
            temperature=0,
            max_tokens=DEFAULT_MAX_NUM_TOKENS,
        )

    def _gen_kwargs(self, gen_kwargs, **defaults):
        """Merge caller overrides over ``defaults`` and normalize temperature.

        Args:
            gen_kwargs: Mapping of overrides, or ``None`` for no overrides.
            **defaults: Values used for any key missing from ``gen_kwargs``.

        Returns:
            A new dict; neither ``gen_kwargs`` nor ``defaults`` is mutated.
        """
        # Tolerate None (e.g. an explicit null in a config) as "no overrides".
        merged = {**defaults, **(gen_kwargs or {})}
        if "temperature" in merged:
            merged["temperature"] = float(merged["temperature"])
        # Asking for several choices with a non-positive temperature is
        # overridden here: the temperature is raised to 0.7 with a warning.
        if merged.get("n", 1) > 1 and merged.get("temperature", 0) <= 0:
            merged["temperature"] = 0.7
            logger.warning(
                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
            )
        return merged

    def _generate(self, samples) -> list:
        """Request a chat completion for every entry of the input column.

        Args:
            samples: Object indexable by column name; ``samples[input_col]``
                must be a sized iterable of ``messages`` payloads.

        Returns:
            One item per input entry, in order: the first choice's content
            when ``n`` is 1 (or unset), otherwise a list with the content of
            every returned choice.
        """
        messages = samples[self.input_col]
        logger.debug("STARTING GENERATION FOR LLMMessagesBlock")
        logger.debug("Generation arguments: %s", self.gen_kwargs)
        n = self.gen_kwargs.get("n", 1)
        results = []
        # One request is issued per message, so the bar is sized by the
        # number of messages and advanced by one per request. The context
        # manager closes the bar even if a request raises.
        with tqdm(
            total=len(messages),
            desc=f"{self.block_name} Chat Completion Generation",
        ) as progress_bar:
            for message in messages:
                logger.debug("CREATING CHAT COMPLETION FOR MESSAGE: %s", message)
                responses = self.ctx.client.chat.completions.create(
                    messages=message, **self.gen_kwargs
                )
                if n > 1:
                    results.append(
                        [choice.message.content for choice in responses.choices]
                    )
                else:
                    results.append(responses.choices[0].message.content)
                progress_bar.update(1)
        return results

    def generate(self, samples: Dataset) -> Dataset:
        """Generate completions and attach them as the ``output_col`` column.

        Args:
            samples: Dataset containing the ``input_col`` column.

        Returns:
            The dataset returned by ``samples.add_column`` with the
            generations stored under ``output_col``.
        """
        outputs = self._generate(samples)
        logger.debug("Generated outputs: %s", outputs)
        return samples.add_column(self.output_col, outputs)
0 commit comments