mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
This commit is contained in:
parent
0bf79b5926
commit
f60aa3acde
@ -1,9 +1,7 @@
|
|||||||
"""
|
|
||||||
generate answer from image module
|
|
||||||
"""
|
|
||||||
import base64
|
import base64
|
||||||
|
import asyncio
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import requests
|
import aiohttp
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
@ -22,10 +20,46 @@ class GenerateAnswerFromImageNode(BaseNode):
|
|||||||
):
|
):
|
||||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||||
|
|
||||||
def execute(self, state: dict) -> dict:
|
async def process_image(self, session, api_key, image_data, user_prompt):
|
||||||
|
# Convert image data to base64
|
||||||
|
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.node_config["config"]["llm"]["model"],
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": user_prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 300
|
||||||
|
}
|
||||||
|
|
||||||
|
async with session.post("https://api.openai.com/v1/chat/completions",
|
||||||
|
headers=headers, json=payload) as response:
|
||||||
|
result = await response.json()
|
||||||
|
return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
|
||||||
|
|
||||||
|
async def execute_async(self, state: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
Processes images from the state, generates answers,
|
Processes images from the state, generates answers,
|
||||||
consolidates the results, and updates the state.
|
consolidates the results, and updates the state asynchronously.
|
||||||
"""
|
"""
|
||||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||||
|
|
||||||
@ -39,54 +73,27 @@ class GenerateAnswerFromImageNode(BaseNode):
|
|||||||
is not supported. Supported models are:
|
is not supported. Supported models are:
|
||||||
{', '.join(supported_models)}.""")
|
{', '.join(supported_models)}.""")
|
||||||
|
|
||||||
if self.node_config["config"]["llm"]["model"].startswith("gpt"):
|
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
|
||||||
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
|
|
||||||
|
|
||||||
for image_data in images:
|
async with aiohttp.ClientSession() as session:
|
||||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
tasks = [
|
||||||
|
self.process_image(session, api_key, image_data,
|
||||||
|
state.get("user_prompt", "Extract information from the image"))
|
||||||
|
for image_data in images
|
||||||
|
]
|
||||||
|
|
||||||
headers = {
|
analyses = await asyncio.gather(*tasks)
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
consolidated_analysis = " ".join(analyses)
|
||||||
"model": self.node_config["config"]["llm"]["model"],
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": state.get("user_prompt",
|
|
||||||
"Extract information from the image")
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 300
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post("https://api.openai.com/v1/chat/completions",
|
state['answer'] = {
|
||||||
headers=headers,
|
"consolidated_analysis": consolidated_analysis
|
||||||
json=payload,
|
}
|
||||||
timeout=10)
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
response_text = result.get('choices',
|
return state
|
||||||
[{}])[0].get('message', {}).get('content', 'No response')
|
|
||||||
analyses.append(response_text)
|
|
||||||
|
|
||||||
consolidated_analysis = " ".join(analyses)
|
def execute(self, state: dict) -> dict:
|
||||||
|
"""
|
||||||
state['answer'] = {
|
Wrapper to run the asynchronous execute_async function in a synchronous context.
|
||||||
"consolidated_analysis": consolidated_analysis
|
"""
|
||||||
}
|
return asyncio.run(self.execute_async(state))
|
||||||
|
|
||||||
return state
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user