feat: add async call
Some checks are pending
/ build (push) Waiting to run

This commit is contained in:
Marco Vinciguerra 2024-08-19 11:22:40 +02:00
parent 0bf79b5926
commit f60aa3acde

View File

@ -1,9 +1,7 @@
"""
generate answer from image module
"""
import base64
import asyncio
from typing import List, Optional
import requests
import aiohttp
from .base_node import BaseNode
from ..utils.logging import get_logger
@ -22,10 +20,46 @@ class GenerateAnswerFromImageNode(BaseNode):
):
super().__init__(node_name, "node", input, output, 2, node_config)
def execute(self, state: dict) -> dict:
async def process_image(self, session, api_key, image_data, user_prompt):
# Convert image data to base64
base64_image = base64.b64encode(image_data).decode('utf-8')
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
payload = {
"model": self.node_config["config"]["llm"]["model"],
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}
async with session.post("https://api.openai.com/v1/chat/completions",
headers=headers, json=payload) as response:
result = await response.json()
return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
async def execute_async(self, state: dict) -> dict:
"""
Processes images from the state, generates answers,
consolidates the results, and updates the state.
consolidates the results, and updates the state asynchronously.
"""
self.logger.info(f"--- Executing {self.node_name} Node ---")
@ -39,54 +73,27 @@ class GenerateAnswerFromImageNode(BaseNode):
is not supported. Supported models are:
{', '.join(supported_models)}.""")
if self.node_config["config"]["llm"]["model"].startswith("gpt"):
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
for image_data in images:
base64_image = base64.b64encode(image_data).decode('utf-8')
async with aiohttp.ClientSession() as session:
tasks = [
self.process_image(session, api_key, image_data,
state.get("user_prompt", "Extract information from the image"))
for image_data in images
]
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
analyses = await asyncio.gather(*tasks)
payload = {
"model": self.node_config["config"]["llm"]["model"],
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": state.get("user_prompt",
"Extract information from the image")
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}
consolidated_analysis = " ".join(analyses)
response = requests.post("https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
timeout=10)
result = response.json()
state['answer'] = {
"consolidated_analysis": consolidated_analysis
}
response_text = result.get('choices',
[{}])[0].get('message', {}).get('content', 'No response')
analyses.append(response_text)
return state
consolidated_analysis = " ".join(analyses)
state['answer'] = {
"consolidated_analysis": consolidated_analysis
}
return state
def execute(self, state: dict) -> dict:
"""
Wrapper to run the asynchronous execute_async function in a synchronous context.
"""
return asyncio.run(self.execute_async(state))