feat: refactoring of the code

This commit is contained in:
Marco Vinciguerra 2024-08-18 20:53:35 +02:00
parent 8e3d5deaaa
commit 5eb3cff64f
5 changed files with 82 additions and 30 deletions

View File

@ -29,8 +29,8 @@ graph_config = {
# ************************************************
smart_scraper_graph = ScreenshotScraperGraph(
prompt="List me the email of the company",
source="https://scrapegraphai.com/",
prompt="List me all the projects",
source="https://perinim.github.io/projects/",
config=graph_config
)

View File

@ -1,3 +1,6 @@
"""
fetch_screen_node module
"""
from typing import List, Optional
from playwright.sync_api import sync_playwright
from .base_node import BaseNode
@ -18,8 +21,10 @@ class FetchScreenNode(BaseNode):
self.url = node_config.get("link")
def execute(self, state: dict) -> dict:
"""Captures screenshots from the input URL and stores them in the state dictionary as bytes."""
"""
Captures screenshots from the input URL and stores them in the state dictionary as bytes.
"""
screenshots = []
with sync_playwright() as p:
@ -29,28 +34,25 @@ class FetchScreenNode(BaseNode):
viewport_height = page.viewport_size["height"]
# Initialize screenshot counter
screenshot_counter = 1
# List to keep track of screenshot data
screenshot_data_list = []
# Function to capture screenshots
def capture_screenshot(scroll_position, counter):
page.evaluate(f"window.scrollTo(0, {scroll_position});")
screenshot_data = page.screenshot()
screenshot_data_list.append(screenshot_data)
# Capture screenshots
capture_screenshot(0, screenshot_counter) # First screenshot
capture_screenshot(0, screenshot_counter)
screenshot_counter += 1
capture_screenshot(viewport_height, screenshot_counter) # Second screenshot
capture_screenshot(viewport_height, screenshot_counter)
browser.close()
# Store screenshot data as bytes in the state dictionary
for screenshot_data in screenshot_data_list:
screenshots.append(screenshot_data)
state["link"] = self.url
state['screenshots'] = screenshots
return state

View File

@ -6,7 +6,7 @@ import requests
class GenerateAnswerFromImageNode(BaseNode):
"""
GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
and updates the state with the generated answers.
and updates the state with the consolidated answers.
"""
def __init__(
@ -19,20 +19,28 @@ class GenerateAnswerFromImageNode(BaseNode):
super().__init__(node_name, "node", input, output, 2, node_config)
def execute(self, state: dict) -> dict:
"""Processes images from the state, generates answers, and updates the state."""
# Retrieve the image data from the state dictionary
"""
Processes images from the state, generates answers,
consolidates the results, and updates the state.
"""
images = state.get('screenshots', [])
results = []
analyses = []
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
supported_models = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
if self.node_config["config"]["llm"]["model"] not in supported_models:
raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
is not supported. Supported models are:
{', '.join(supported_models)}.""")
# OpenAI API Key
for image_data in images:
# Encode the image data to base64
base64_image = base64.b64encode(image_data).decode('utf-8')
# Prepare API request
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.node_config.get("config").get("llm").get("api_key")}"
"Authorization": f"Bearer {api_key}"
}
payload = {
@ -43,7 +51,8 @@ class GenerateAnswerFromImageNode(BaseNode):
"content": [
{
"type": "text",
"text": state.get("user_prompt", "Extract information from the image")
"text": state.get("user_prompt",
"Extract information from the image")
},
{
"type": "image_url",
@ -57,18 +66,20 @@ class GenerateAnswerFromImageNode(BaseNode):
"max_tokens": 300
}
# Make the API request
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
response = requests.post("https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
timeout=10 )
result = response.json()
# Extract the response text
response_text = result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
response_text = result.get('choices',
[{}])[0].get('message', {}).get('content', 'No response')
analyses.append(response_text)
# Append the result to the results list
results.append({
"analysis": response_text
})
consolidated_analysis = " ".join(analyses)
state['answer'] = {
"consolidated_analysis": consolidated_analysis
}
# Update the state dictionary with the results
state['answer'] = results
return state

View File

@ -0,0 +1,39 @@
import os
import pytest
import json
from scrapegraphai.graphs import ScreenshotScraperGraph
from dotenv import load_dotenv
# Load environment variables at import time, before the fixture below reads
# OPENAI_API_KEY via os.getenv (presumably sourced from a local .env file).
load_dotenv()
# Define a fixture for the graph configuration
@pytest.fixture
def graph_config():
    """
    Build the configuration dictionary passed to ScreenshotScraperGraph.

    The OpenAI API key is read from the OPENAI_API_KEY environment variable.
    When the variable is unset or empty, the requesting test is skipped
    instead of running the graph with a missing key.

    Returns:
        dict: configuration with the "llm", "verbose" and "headless" keys.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # A missing key would otherwise be forwarded as None to the graph.
        pytest.skip("OPENAI_API_KEY is not set; skipping live scraper test")

    return {
        "llm": {
            "api_key": api_key,
            "model": "gpt-4o",
        },
        "verbose": True,
        "headless": False,
    }
def test_screenshot_scraper_graph(graph_config):
    """
    Run ScreenshotScraperGraph on the projects page with the fixture
    configuration and check that a result is returned.
    """
    scraper = ScreenshotScraperGraph(
        config=graph_config,
        source="https://perinim.github.io/projects/",
        prompt="List me all the projects",
    )
    outcome = scraper.run()

    assert outcome is not None, "The result should not be None"

    # Dump the result so it is visible when pytest runs with -s.
    print(json.dumps(outcome, indent=4))