@@ -60,7 +60,9 @@ def __init__(
60
60
61
61
# We only need one thread as we'll immediately block after launching it.
62
62
self .executor = (
63
- ThreadPoolExecutor (thread_name_prefix = f"async-pipeline-executor-{ id (self )} " , max_workers = 1 )
63
+ ThreadPoolExecutor (
64
+ thread_name_prefix = f"async-pipeline-executor-{ id (self )} " , max_workers = 1
65
+ )
64
66
if async_executor is None
65
67
else async_executor
66
68
)
@@ -88,17 +90,27 @@ async def _run_component(
88
90
tags = {
89
91
"haystack.component.name" : name ,
90
92
"haystack.component.type" : instance .__class__ .__name__ ,
91
- "haystack.component.input_types" : {k : type (v ).__name__ for k , v in inputs .items ()},
93
+ "haystack.component.input_types" : {
94
+ k : type (v ).__name__ for k , v in inputs .items ()
95
+ },
92
96
"haystack.component.input_spec" : {
93
97
key : {
94
- "type" : (value .type .__name__ if isinstance (value .type , type ) else str (value .type )),
98
+ "type" : (
99
+ value .type .__name__
100
+ if isinstance (value .type , type )
101
+ else str (value .type )
102
+ ),
95
103
"senders" : value .senders ,
96
104
}
97
105
for key , value in instance .__haystack_input__ ._sockets_dict .items () # type: ignore
98
106
},
99
107
"haystack.component.output_spec" : {
100
108
key : {
101
- "type" : (value .type .__name__ if isinstance (value .type , type ) else str (value .type )),
109
+ "type" : (
110
+ value .type .__name__
111
+ if isinstance (value .type , type )
112
+ else str (value .type )
113
+ ),
102
114
"receivers" : value .receivers ,
103
115
}
104
116
for key , value in instance .__haystack_output__ ._sockets_dict .items () # type: ignore
@@ -113,14 +125,18 @@ async def _run_component(
113
125
114
126
res : Dict [str , Any ]
115
127
if instance .__haystack_supports_async__ : # type: ignore
116
- logger .info ("Running async component {component_name}" , component_name = name )
128
+ logger .info (
129
+ "Running async component {component_name}" , component_name = name
130
+ )
117
131
res = await instance .run_async (** inputs ) # type: ignore
118
132
else :
119
133
logger .info (
120
134
"Running sync component {component_name} on executor" ,
121
135
component_name = name ,
122
136
)
123
- res = await asyncio .get_event_loop ().run_in_executor (self .executor , lambda : instance .run (** inputs ))
137
+ res = await asyncio .get_event_loop ().run_in_executor (
138
+ self .executor , lambda : instance .run (** inputs )
139
+ )
124
140
self .graph .nodes [name ]["visits" ] += 1
125
141
126
142
# After a Component that has variadic inputs is run, we need to reset the variadic inputs that were consumed
@@ -187,7 +203,9 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
187
203
while not cycle_received_inputs :
188
204
# Here we run the Components
189
205
name , comp = run_queue .pop (0 )
190
- if _is_lazy_variadic (comp ) and not all (_is_lazy_variadic (comp ) for _ , comp in run_queue ):
206
+ if _is_lazy_variadic (comp ) and not all (
207
+ _is_lazy_variadic (comp ) for _ , comp in run_queue
208
+ ):
191
209
# We run Components with lazy variadic inputs only if there are only Components with
192
210
# lazy variadic inputs left to run
193
211
_enqueue_waiting_component ((name , comp ), waiting_queue )
@@ -199,7 +217,9 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
199
217
msg = f"Maximum run count { self ._max_runs_per_component } reached for component '{ name } '"
200
218
raise PipelineMaxComponentRuns (msg )
201
219
202
- res : Dict [str , Any ] = await self ._run_component (name , components_inputs [name ])
220
+ res : Dict [str , Any ] = await self ._run_component (
221
+ name , components_inputs [name ]
222
+ )
203
223
yield {name : deepcopy (res )}, False
204
224
205
225
# Delete the inputs that were consumed by the Component and are not received from
@@ -238,12 +258,18 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
238
258
# We managed to run this component that was in the waiting list, so we can remove it.
239
259
# This happens when a component was put in the waiting list but we reached it from another edge.
240
260
_dequeue_waiting_component ((name , comp ), waiting_queue )
241
- for pair in self ._find_components_that_will_receive_no_input (name , res , components_inputs ):
261
+ for pair in self ._find_components_that_will_receive_no_input (
262
+ name , res , components_inputs
263
+ ):
242
264
_dequeue_component (pair , run_queue , waiting_queue )
243
265
244
- receivers = [item for item in self ._find_receivers_from (name ) if item [0 ] in cycle ]
266
+ receivers = [
267
+ item for item in self ._find_receivers_from (name ) if item [0 ] in cycle
268
+ ]
245
269
246
- res = self ._distribute_output (receivers , res , components_inputs , run_queue , waiting_queue )
270
+ res = self ._distribute_output (
271
+ receivers , res , components_inputs , run_queue , waiting_queue
272
+ )
247
273
248
274
# We treat a cycle as a completely independent graph, so we keep track of output
249
275
# that is not sent inside the cycle.
@@ -274,21 +300,31 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
274
300
warn (RuntimeWarning (msg ))
275
301
break
276
302
277
- (name , comp ) = self ._find_next_runnable_lazy_variadic_or_default_component (waiting_queue )
303
+ (name , comp ) = (
304
+ self ._find_next_runnable_lazy_variadic_or_default_component (
305
+ waiting_queue
306
+ )
307
+ )
278
308
_add_missing_input_defaults (name , comp , components_inputs )
279
309
_enqueue_component ((name , comp ), run_queue , waiting_queue )
280
310
continue
281
311
282
- before_last_waiting_queue = last_waiting_queue .copy () if last_waiting_queue is not None else None
312
+ before_last_waiting_queue = (
313
+ last_waiting_queue .copy ()
314
+ if last_waiting_queue is not None
315
+ else None
316
+ )
283
317
last_waiting_queue = {item [0 ] for item in waiting_queue }
284
318
285
- (name , comp ) = self ._find_next_runnable_component (components_inputs , waiting_queue )
319
+ (name , comp ) = self ._find_next_runnable_component (
320
+ components_inputs , waiting_queue
321
+ )
286
322
_add_missing_input_defaults (name , comp , components_inputs )
287
323
_enqueue_component ((name , comp ), run_queue , waiting_queue )
288
324
289
325
yield subgraph_outputs , True
290
326
291
- async def run ( # noqa: PLR0915
327
+ async def run ( # noqa: PLR0915, PLR0912
292
328
self ,
293
329
data : Dict [str , Any ],
294
330
) -> AsyncIterator [Dict [str , Any ]]:
@@ -368,7 +404,9 @@ def run(self, word: str):
368
404
self ._validate_input (data )
369
405
370
406
# Normalize the input data
371
- components_inputs : Dict [str , Dict [str , Any ]] = self ._normalize_varidiac_input_data (data )
407
+ components_inputs : Dict [str , Dict [str , Any ]] = (
408
+ self ._normalize_varidiac_input_data (data )
409
+ )
372
410
373
411
# These variables are used to detect when we're stuck in a loop.
374
412
# Stuck loops can happen when one or more components are waiting for input but
@@ -391,7 +429,9 @@ def run(self, word: str):
391
429
392
430
# Break cycles if there are any; this is a no-op if no cycle is found.
393
431
# This will raise if a cycle can't be broken.
394
- graph_without_cycles , components_in_cycles = self ._break_supported_cycles_in_graph ()
432
+ graph_without_cycles , components_in_cycles = (
433
+ self ._break_supported_cycles_in_graph ()
434
+ )
395
435
396
436
run_queue : List [Tuple [str , Component ]] = []
397
437
for node in nx .topological_sort (graph_without_cycles ):
@@ -426,14 +466,16 @@ def run(self, word: str):
426
466
while len (run_queue ) > 0 :
427
467
name , comp = run_queue .pop (0 )
428
468
429
- if _is_lazy_variadic (comp ) and not all (_is_lazy_variadic (comp ) for _ , comp in run_queue ):
469
+ if _is_lazy_variadic (comp ) and not all (
470
+ _is_lazy_variadic (comp ) for _ , comp in run_queue
471
+ ):
430
472
# We run Components with lazy variadic inputs only if there are only Components with
431
473
# lazy variadic inputs left to run
432
474
_enqueue_waiting_component ((name , comp ), waiting_queue )
433
475
continue
434
- if self ._component_has_enough_inputs_to_run (name , components_inputs ) and components_in_cycles . get (
435
- name , []
436
- ):
476
+ if self ._component_has_enough_inputs_to_run (
477
+ name , components_inputs
478
+ ) and components_in_cycles . get ( name , []) :
437
479
cycles = components_in_cycles .get (name , [])
438
480
439
481
# This component is part of one or more cycles, let's get the first one and run it.
@@ -474,13 +516,17 @@ def run(self, word: str):
474
516
msg = f"Maximum run count { self ._max_runs_per_component } reached for component '{ name } '"
475
517
raise PipelineMaxComponentRuns (msg )
476
518
477
- res : Dict [str , Any ] = await self ._run_component (name , components_inputs [name ], parent_span = span )
519
+ res : Dict [str , Any ] = await self ._run_component (
520
+ name , components_inputs [name ], parent_span = span
521
+ )
478
522
yield {name : deepcopy (res )}
479
523
480
524
# Delete the inputs that were consumed by the Component and are not received from the user
481
525
sockets = list (components_inputs [name ].keys ())
482
526
for socket_name in sockets :
483
- senders = comp .__haystack_input__ ._sockets_dict [socket_name ].senders
527
+ senders = comp .__haystack_input__ ._sockets_dict [
528
+ socket_name
529
+ ].senders
484
530
if senders :
485
531
# Delete all inputs that are received from other Components
486
532
del components_inputs [name ][socket_name ]
@@ -494,10 +540,14 @@ def run(self, word: str):
494
540
# This happens when a component was put in the waiting list but we reached it from another edge.
495
541
_dequeue_waiting_component ((name , comp ), waiting_queue )
496
542
497
- for pair in self ._find_components_that_will_receive_no_input (name , res , components_inputs ):
543
+ for pair in self ._find_components_that_will_receive_no_input (
544
+ name , res , components_inputs
545
+ ):
498
546
_dequeue_component (pair , run_queue , waiting_queue )
499
547
receivers = self ._find_receivers_from (name )
500
- res = self ._distribute_output (receivers , res , components_inputs , run_queue , waiting_queue )
548
+ res = self ._distribute_output (
549
+ receivers , res , components_inputs , run_queue , waiting_queue
550
+ )
501
551
502
552
if len (res ) > 0 :
503
553
final_outputs [name ] = res
@@ -523,15 +573,25 @@ def run(self, word: str):
523
573
warn (RuntimeWarning (msg ))
524
574
break
525
575
526
- (name , comp ) = self ._find_next_runnable_lazy_variadic_or_default_component (waiting_queue )
576
+ (name , comp ) = (
577
+ self ._find_next_runnable_lazy_variadic_or_default_component (
578
+ waiting_queue
579
+ )
580
+ )
527
581
_add_missing_input_defaults (name , comp , components_inputs )
528
582
_enqueue_component ((name , comp ), run_queue , waiting_queue )
529
583
continue
530
584
531
- before_last_waiting_queue = last_waiting_queue .copy () if last_waiting_queue is not None else None
585
+ before_last_waiting_queue = (
586
+ last_waiting_queue .copy ()
587
+ if last_waiting_queue is not None
588
+ else None
589
+ )
532
590
last_waiting_queue = {item [0 ] for item in waiting_queue }
533
591
534
- (name , comp ) = self ._find_next_runnable_component (components_inputs , waiting_queue )
592
+ (name , comp ) = self ._find_next_runnable_component (
593
+ components_inputs , waiting_queue
594
+ )
535
595
_add_missing_input_defaults (name , comp , components_inputs )
536
596
_enqueue_component ((name , comp ), run_queue , waiting_queue )
537
597
@@ -567,7 +627,10 @@ async def run_async_pipeline(
567
627
outputs = [x async for x in pipeline .run (data )]
568
628
569
629
intermediate_outputs = {
570
- k : v for d in outputs [:- 1 ] for k , v in d .items () if include_outputs_from is None or k in include_outputs_from
630
+ k : v
631
+ for d in outputs [:- 1 ]
632
+ for k , v in d .items ()
633
+ if include_outputs_from is None or k in include_outputs_from
571
634
}
572
635
final_output = outputs [- 1 ]
573
636
0 commit comments