feat(parallel-exeuction): add asyncio event loop dispatcher with semaphore for parallel graph instances

TODO: still untested
This commit is contained in:
Federico Minutoli 2024-05-11 00:13:27 +02:00
parent 7ae50c035e
commit 627cbeeb20

View File

@ -2,12 +2,18 @@
GraphIterator Module GraphIterator Module
""" """
from typing import List, Optional import asyncio
import copy import copy
from tqdm import tqdm from typing import List, Optional
from tqdm.asyncio import tqdm
from .base_node import BaseNode from .base_node import BaseNode
_default_batchsize = 4
class GraphIteratorNode(BaseNode): class GraphIteratorNode(BaseNode):
""" """
A node responsible for instantiating and running multiple graph instances in parallel. A node responsible for instantiating and running multiple graph instances in parallel.
@ -23,12 +29,20 @@ class GraphIteratorNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Parse". node_name (str): The unique identifier name for the node, defaulting to "Parse".
""" """
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "GraphIterator"): def __init__(
self,
input: str,
output: List[str],
node_config: Optional[dict] = None,
node_name: str = "GraphIterator",
):
super().__init__(node_name, "node", input, output, 2, node_config) super().__init__(node_name, "node", input, output, 2, node_config)
self.verbose = False if node_config is None else node_config.get("verbose", False) self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
""" """
Executes the node's logic to instantiate and run multiple graph instances in parallel. Executes the node's logic to instantiate and run multiple graph instances in parallel.
@ -43,37 +57,78 @@ class GraphIteratorNode(BaseNode):
KeyError: If the input keys are not found in the state, indicating that the KeyError: If the input keys are not found in the state, indicating that the
necessary information for running the graph instances is missing. necessary information for running the graph instances is missing.
""" """
batchsize = self.node_config.get("batchsize", _default_batchsize)
if self.verbose: if self.verbose:
print(f"--- Executing {self.node_name} Node ---") print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")
# Interpret input keys based on the provided input expression try:
eventloop = asyncio.get_event_loop()
except RuntimeError:
eventloop = None
if eventloop and eventloop.is_running():
state = eventloop.run_until_complete(self._async_execute(state, batchsize))
else:
state = asyncio.run(self._async_execute(state, batchsize))
return state
async def _async_execute(self, state: dict, batchsize: int) -> dict:
"""asynchronously executes the node's logic with multiple graph instances
running in parallel, using a semaphore of some size for concurrency regulation
Args:
state: The current state of the graph.
batchsize: The maximum number of concurrent instances allowed.
Returns:
The updated state with the output key containing the results
aggregated out of all parallel graph instances.
Raises:
KeyError: If the input keys are not found in the state.
"""
# interprets input keys based on the provided input expression
input_keys = self.get_input_keys(state) input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys # fetches data from the state based on the input keys
input_data = [state[key] for key in input_keys] input_data = [state[key] for key in input_keys]
user_prompt = input_data[0] user_prompt = input_data[0]
urls = input_data[1] urls = input_data[1]
graph_instance = self.node_config.get("graph_instance", None) graph_instance = self.node_config.get("graph_instance", None)
if graph_instance is None: if graph_instance is None:
raise ValueError("Graph instance is required for graph iteration.") raise ValueError("graph instance is required for concurrent execution")
# set the prompt and source for each url # sets the prompt for the graph instance
graph_instance.prompt = user_prompt graph_instance.prompt = user_prompt
graphs_instances = []
participants = []
# semaphore to limit the number of concurrent tasks
semaphore = asyncio.Semaphore(batchsize)
async def _async_run(graph):
async with semaphore:
return await asyncio.to_thread(graph.run)
# creates a deepcopy of the graph instance for each endpoint
for url in urls: for url in urls:
# make a copy of the graph instance instance = copy.deepcopy(graph_instance)
copy_graph_instance = copy.copy(graph_instance) instance.source = url
copy_graph_instance.source = url
graphs_instances.append(copy_graph_instance)
# run the graph for each url and use tqdm for progress bar participants.append(instance)
graphs_answers = []
for graph in tqdm(graphs_instances, desc="Processing Graph Instances", disable=not self.verbose): futures = [_async_run(graph) for graph in participants]
result = graph.run()
graphs_answers.append(result) answers = await tqdm.gather(
*futures, desc="processing graph instances", disable=not self.verbose
)
state.update({self.output[0]: answers})
state.update({self.output[0]: graphs_answers})
return state return state