From f60aa3acde3c9bead2250e81eb8fc77d2e1e450c Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Mon, 19 Aug 2024 11:22:40 +0200 Subject: [PATCH] feat: add async call --- .../nodes/generate_answer_from_image_node.py | 109 ++++++++++-------- 1 file changed, 58 insertions(+), 51 deletions(-) diff --git a/scrapegraphai/nodes/generate_answer_from_image_node.py b/scrapegraphai/nodes/generate_answer_from_image_node.py index 7f4aa687..7d145f0e 100644 --- a/scrapegraphai/nodes/generate_answer_from_image_node.py +++ b/scrapegraphai/nodes/generate_answer_from_image_node.py @@ -1,9 +1,7 @@ -""" -generate answer from image module -""" import base64 +import asyncio from typing import List, Optional -import requests +import aiohttp from .base_node import BaseNode from ..utils.logging import get_logger @@ -22,10 +20,46 @@ class GenerateAnswerFromImageNode(BaseNode): ): super().__init__(node_name, "node", input, output, 2, node_config) - def execute(self, state: dict) -> dict: + async def process_image(self, session, api_key, image_data, user_prompt): + # Convert image data to base64 + base64_image = base64.b64encode(image_data).decode('utf-8') + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + + payload = { + "model": self.node_config["config"]["llm"]["model"], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + } + ] + } + ], + "max_tokens": 300 + } + + async with session.post("https://api.openai.com/v1/chat/completions", + headers=headers, json=payload) as response: + result = await response.json() + return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response') + + async def execute_async(self, state: dict) -> dict: """ Processes images from the state, generates answers, - consolidates the results, and updates the state. + consolidates the results, and updates the state asynchronously. """ self.logger.info(f"--- Executing {self.node_name} Node ---") @@ -39,54 +73,27 @@ class GenerateAnswerFromImageNode(BaseNode): is not supported. Supported models are: {', '.join(supported_models)}.""") - if self.node_config["config"]["llm"]["model"].startswith("gpt"): - api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "") + api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "") - for image_data in images: - base64_image = base64.b64encode(image_data).decode('utf-8') + async with aiohttp.ClientSession() as session: + tasks = [ + self.process_image(session, api_key, image_data, + state.get("user_prompt", "Extract information from the image")) + for image_data in images + ] - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } + analyses = await asyncio.gather(*tasks) - payload = { - "model": self.node_config["config"]["llm"]["model"], - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": state.get("user_prompt", - "Extract information from the image") - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - } - ] - } - ], - "max_tokens": 300 - } + consolidated_analysis = " ".join(analyses) - response = requests.post("https://api.openai.com/v1/chat/completions", - headers=headers, - json=payload, - timeout=10) - result = response.json() + state['answer'] = { + "consolidated_analysis": consolidated_analysis + } - response_text = result.get('choices', - [{}])[0].get('message', {}).get('content', 'No response') - analyses.append(response_text) + return state - consolidated_analysis = " ".join(analyses) - - state['answer'] = { - "consolidated_analysis": consolidated_analysis - } - - return state + def execute(self, state: dict) -> dict: + """ + Wrapper to run the asynchronous execute_async function in a synchronous context. + """ + return asyncio.run(self.execute_async(state))