Skip to content

Commit 22fcc4b

Browse files
committed
fix: Lints
1 parent 8b11072 commit 22fcc4b

File tree

1 file changed

+92
-29
lines changed

1 file changed

+92
-29
lines changed

haystack_experimental/core/pipeline/async_pipeline.py

+92-29
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def __init__(
6060

6161
# We only need one thread as we'll immediately block after launching it.
6262
self.executor = (
63-
ThreadPoolExecutor(thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1)
63+
ThreadPoolExecutor(
64+
thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1
65+
)
6466
if async_executor is None
6567
else async_executor
6668
)
@@ -88,17 +90,27 @@ async def _run_component(
8890
tags={
8991
"haystack.component.name": name,
9092
"haystack.component.type": instance.__class__.__name__,
91-
"haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()},
93+
"haystack.component.input_types": {
94+
k: type(v).__name__ for k, v in inputs.items()
95+
},
9296
"haystack.component.input_spec": {
9397
key: {
94-
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
98+
"type": (
99+
value.type.__name__
100+
if isinstance(value.type, type)
101+
else str(value.type)
102+
),
95103
"senders": value.senders,
96104
}
97105
for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore
98106
},
99107
"haystack.component.output_spec": {
100108
key: {
101-
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
109+
"type": (
110+
value.type.__name__
111+
if isinstance(value.type, type)
112+
else str(value.type)
113+
),
102114
"receivers": value.receivers,
103115
}
104116
for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore
@@ -113,14 +125,18 @@ async def _run_component(
113125

114126
res: Dict[str, Any]
115127
if instance.__haystack_supports_async__: # type: ignore
116-
logger.info("Running async component {component_name}", component_name=name)
128+
logger.info(
129+
"Running async component {component_name}", component_name=name
130+
)
117131
res = await instance.run_async(**inputs) # type: ignore
118132
else:
119133
logger.info(
120134
"Running sync component {component_name} on executor",
121135
component_name=name,
122136
)
123-
res = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: instance.run(**inputs))
137+
res = await asyncio.get_event_loop().run_in_executor(
138+
self.executor, lambda: instance.run(**inputs)
139+
)
124140
self.graph.nodes[name]["visits"] += 1
125141

126142
# After a Component that has variadic inputs is run, we need to reset the variadic inputs that were consumed
@@ -187,7 +203,9 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
187203
while not cycle_received_inputs:
188204
# Here we run the Components
189205
name, comp = run_queue.pop(0)
190-
if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue):
206+
if _is_lazy_variadic(comp) and not all(
207+
_is_lazy_variadic(comp) for _, comp in run_queue
208+
):
191209
# We run Components with lazy variadic inputs only if there only Components with
192210
# lazy variadic inputs left to run
193211
_enqueue_waiting_component((name, comp), waiting_queue)
@@ -199,7 +217,9 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
199217
msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'"
200218
raise PipelineMaxComponentRuns(msg)
201219

202-
res: Dict[str, Any] = await self._run_component(name, components_inputs[name])
220+
res: Dict[str, Any] = await self._run_component(
221+
name, components_inputs[name]
222+
)
203223
yield {name: deepcopy(res)}, False
204224

205225
# Delete the inputs that were consumed by the Component and are not received from
@@ -238,12 +258,18 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
238258
# We manage to run this component that was in the waiting list, we can remove it.
239259
# This happens when a component was put in the waiting list but we reached it from another edge.
240260
_dequeue_waiting_component((name, comp), waiting_queue)
241-
for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
261+
for pair in self._find_components_that_will_receive_no_input(
262+
name, res, components_inputs
263+
):
242264
_dequeue_component(pair, run_queue, waiting_queue)
243265

244-
receivers = [item for item in self._find_receivers_from(name) if item[0] in cycle]
266+
receivers = [
267+
item for item in self._find_receivers_from(name) if item[0] in cycle
268+
]
245269

246-
res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue)
270+
res = self._distribute_output(
271+
receivers, res, components_inputs, run_queue, waiting_queue
272+
)
247273

248274
# We treat a cycle as a completely independent graph, so we keep track of output
249275
# that is not sent inside the cycle.
@@ -274,21 +300,31 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912
274300
warn(RuntimeWarning(msg))
275301
break
276302

277-
(name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue)
303+
(name, comp) = (
304+
self._find_next_runnable_lazy_variadic_or_default_component(
305+
waiting_queue
306+
)
307+
)
278308
_add_missing_input_defaults(name, comp, components_inputs)
279309
_enqueue_component((name, comp), run_queue, waiting_queue)
280310
continue
281311

282-
before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None
312+
before_last_waiting_queue = (
313+
last_waiting_queue.copy()
314+
if last_waiting_queue is not None
315+
else None
316+
)
283317
last_waiting_queue = {item[0] for item in waiting_queue}
284318

285-
(name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue)
319+
(name, comp) = self._find_next_runnable_component(
320+
components_inputs, waiting_queue
321+
)
286322
_add_missing_input_defaults(name, comp, components_inputs)
287323
_enqueue_component((name, comp), run_queue, waiting_queue)
288324

289325
yield subgraph_outputs, True
290326

291-
async def run( # noqa: PLR0915
327+
async def run( # noqa: PLR0915, PLR0912
292328
self,
293329
data: Dict[str, Any],
294330
) -> AsyncIterator[Dict[str, Any]]:
@@ -368,7 +404,9 @@ def run(self, word: str):
368404
self._validate_input(data)
369405

370406
# Normalize the input data
371-
components_inputs: Dict[str, Dict[str, Any]] = self._normalize_varidiac_input_data(data)
407+
components_inputs: Dict[str, Dict[str, Any]] = (
408+
self._normalize_varidiac_input_data(data)
409+
)
372410

373411
# These variables are used to detect when we're stuck in a loop.
374412
# Stuck loops can happen when one or more components are waiting for input but
@@ -391,7 +429,9 @@ def run(self, word: str):
391429

392430
# Break cycles in case there are, this is a noop if no cycle is found.
393431
# This will raise if a cycle can't be broken.
394-
graph_without_cycles, components_in_cycles = self._break_supported_cycles_in_graph()
432+
graph_without_cycles, components_in_cycles = (
433+
self._break_supported_cycles_in_graph()
434+
)
395435

396436
run_queue: List[Tuple[str, Component]] = []
397437
for node in nx.topological_sort(graph_without_cycles):
@@ -426,14 +466,16 @@ def run(self, word: str):
426466
while len(run_queue) > 0:
427467
name, comp = run_queue.pop(0)
428468

429-
if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue):
469+
if _is_lazy_variadic(comp) and not all(
470+
_is_lazy_variadic(comp) for _, comp in run_queue
471+
):
430472
# We run Components with lazy variadic inputs only if there only Components with
431473
# lazy variadic inputs left to run
432474
_enqueue_waiting_component((name, comp), waiting_queue)
433475
continue
434-
if self._component_has_enough_inputs_to_run(name, components_inputs) and components_in_cycles.get(
435-
name, []
436-
):
476+
if self._component_has_enough_inputs_to_run(
477+
name, components_inputs
478+
) and components_in_cycles.get(name, []):
437479
cycles = components_in_cycles.get(name, [])
438480

439481
# This component is part of one or more cycles, let's get the first one and run it.
@@ -474,13 +516,17 @@ def run(self, word: str):
474516
msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'"
475517
raise PipelineMaxComponentRuns(msg)
476518

477-
res: Dict[str, Any] = await self._run_component(name, components_inputs[name], parent_span=span)
519+
res: Dict[str, Any] = await self._run_component(
520+
name, components_inputs[name], parent_span=span
521+
)
478522
yield {name: deepcopy(res)}
479523

480524
# Delete the inputs that were consumed by the Component and are not received from the user
481525
sockets = list(components_inputs[name].keys())
482526
for socket_name in sockets:
483-
senders = comp.__haystack_input__._sockets_dict[socket_name].senders
527+
senders = comp.__haystack_input__._sockets_dict[
528+
socket_name
529+
].senders
484530
if senders:
485531
# Delete all inputs that are received from other Components
486532
del components_inputs[name][socket_name]
@@ -494,10 +540,14 @@ def run(self, word: str):
494540
# This happens when a component was put in the waiting list but we reached it from another edge.
495541
_dequeue_waiting_component((name, comp), waiting_queue)
496542

497-
for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
543+
for pair in self._find_components_that_will_receive_no_input(
544+
name, res, components_inputs
545+
):
498546
_dequeue_component(pair, run_queue, waiting_queue)
499547
receivers = self._find_receivers_from(name)
500-
res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue)
548+
res = self._distribute_output(
549+
receivers, res, components_inputs, run_queue, waiting_queue
550+
)
501551

502552
if len(res) > 0:
503553
final_outputs[name] = res
@@ -523,15 +573,25 @@ def run(self, word: str):
523573
warn(RuntimeWarning(msg))
524574
break
525575

526-
(name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue)
576+
(name, comp) = (
577+
self._find_next_runnable_lazy_variadic_or_default_component(
578+
waiting_queue
579+
)
580+
)
527581
_add_missing_input_defaults(name, comp, components_inputs)
528582
_enqueue_component((name, comp), run_queue, waiting_queue)
529583
continue
530584

531-
before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None
585+
before_last_waiting_queue = (
586+
last_waiting_queue.copy()
587+
if last_waiting_queue is not None
588+
else None
589+
)
532590
last_waiting_queue = {item[0] for item in waiting_queue}
533591

534-
(name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue)
592+
(name, comp) = self._find_next_runnable_component(
593+
components_inputs, waiting_queue
594+
)
535595
_add_missing_input_defaults(name, comp, components_inputs)
536596
_enqueue_component((name, comp), run_queue, waiting_queue)
537597

@@ -567,7 +627,10 @@ async def run_async_pipeline(
567627
outputs = [x async for x in pipeline.run(data)]
568628

569629
intermediate_outputs = {
570-
k: v for d in outputs[:-1] for k, v in d.items() if include_outputs_from is None or k in include_outputs_from
630+
k: v
631+
for d in outputs[:-1]
632+
for k, v in d.items()
633+
if include_outputs_from is None or k in include_outputs_from
571634
}
572635
final_output = outputs[-1]
573636

0 commit comments

Comments
 (0)