feat: refactoring of the code

This commit is contained in:
Marco Vinciguerra 2024-08-18 20:53:35 +02:00
parent 8e3d5deaaa
commit 5eb3cff64f
5 changed files with 82 additions and 30 deletions

View File

@ -29,8 +29,8 @@ graph_config = {
# ************************************************ # ************************************************
smart_scraper_graph = ScreenshotScraperGraph( smart_scraper_graph = ScreenshotScraperGraph(
prompt="List me the email of the company", prompt="List me all the projects",
source="https://scrapegraphai.com/", source="https://perinim.github.io/projects/",
config=graph_config config=graph_config
) )

View File

@ -1,3 +1,6 @@
"""
fetch_screen_node module
"""
from typing import List, Optional from typing import List, Optional
from playwright.sync_api import sync_playwright from playwright.sync_api import sync_playwright
from .base_node import BaseNode from .base_node import BaseNode
@ -18,8 +21,10 @@ class FetchScreenNode(BaseNode):
self.url = node_config.get("link") self.url = node_config.get("link")
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
"""Captures screenshots from the input URL and stores them in the state dictionary as bytes.""" """
Captures screenshots from the input URL and stores them in the state dictionary as bytes.
"""
screenshots = [] screenshots = []
with sync_playwright() as p: with sync_playwright() as p:
@ -29,28 +34,25 @@ class FetchScreenNode(BaseNode):
viewport_height = page.viewport_size["height"] viewport_height = page.viewport_size["height"]
# Initialize screenshot counter
screenshot_counter = 1 screenshot_counter = 1
# List to keep track of screenshot data
screenshot_data_list = [] screenshot_data_list = []
# Function to capture screenshots
def capture_screenshot(scroll_position, counter): def capture_screenshot(scroll_position, counter):
page.evaluate(f"window.scrollTo(0, {scroll_position});") page.evaluate(f"window.scrollTo(0, {scroll_position});")
screenshot_data = page.screenshot() screenshot_data = page.screenshot()
screenshot_data_list.append(screenshot_data) screenshot_data_list.append(screenshot_data)
# Capture screenshots capture_screenshot(0, screenshot_counter)
capture_screenshot(0, screenshot_counter) # First screenshot
screenshot_counter += 1 screenshot_counter += 1
capture_screenshot(viewport_height, screenshot_counter) # Second screenshot capture_screenshot(viewport_height, screenshot_counter)
browser.close() browser.close()
# Store screenshot data as bytes in the state dictionary
for screenshot_data in screenshot_data_list: for screenshot_data in screenshot_data_list:
screenshots.append(screenshot_data) screenshots.append(screenshot_data)
state["link"] = self.url state["link"] = self.url
state['screenshots'] = screenshots state['screenshots'] = screenshots
return state return state

View File

@ -6,7 +6,7 @@ import requests
class GenerateAnswerFromImageNode(BaseNode): class GenerateAnswerFromImageNode(BaseNode):
""" """
GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
and updates the state with the generated answers. and updates the state with the consolidated answers.
""" """
def __init__( def __init__(
@ -19,20 +19,28 @@ class GenerateAnswerFromImageNode(BaseNode):
super().__init__(node_name, "node", input, output, 2, node_config) super().__init__(node_name, "node", input, output, 2, node_config)
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
"""Processes images from the state, generates answers, and updates the state.""" """
# Retrieve the image data from the state dictionary Processes images from the state, generates answers,
consolidates the results, and updates the state.
"""
images = state.get('screenshots', []) images = state.get('screenshots', [])
results = [] analyses = []
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
supported_models = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
if self.node_config["config"]["llm"]["model"] not in supported_models:
raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
is not supported. Supported models are:
{', '.join(supported_models)}.""")
# OpenAI API Key
for image_data in images: for image_data in images:
# Encode the image data to base64
base64_image = base64.b64encode(image_data).decode('utf-8') base64_image = base64.b64encode(image_data).decode('utf-8')
# Prepare API request
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {self.node_config.get("config").get("llm").get("api_key")}" "Authorization": f"Bearer {api_key}"
} }
payload = { payload = {
@ -43,7 +51,8 @@ class GenerateAnswerFromImageNode(BaseNode):
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": state.get("user_prompt", "Extract information from the image") "text": state.get("user_prompt",
"Extract information from the image")
}, },
{ {
"type": "image_url", "type": "image_url",
@ -57,18 +66,20 @@ class GenerateAnswerFromImageNode(BaseNode):
"max_tokens": 300 "max_tokens": 300
} }
# Make the API request response = requests.post("https://api.openai.com/v1/chat/completions",
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) headers=headers,
json=payload,
timeout=10 )
result = response.json() result = response.json()
# Extract the response text response_text = result.get('choices',
response_text = result.get('choices', [{}])[0].get('message', {}).get('content', 'No response') [{}])[0].get('message', {}).get('content', 'No response')
analyses.append(response_text)
# Append the result to the results list consolidated_analysis = " ".join(analyses)
results.append({
"analysis": response_text state['answer'] = {
}) "consolidated_analysis": consolidated_analysis
}
# Update the state dictionary with the results
state['answer'] = results
return state return state

View File

@ -0,0 +1,39 @@
import os
import pytest
import json
from scrapegraphai.graphs import ScreenshotScraperGraph
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Define a fixture for the graph configuration
@pytest.fixture
def graph_config():
    """
    Provide the configuration dictionary used to build the scraper graph.

    The LLM API key is read from the OPENAI_API_KEY environment variable
    (``os.getenv`` yields None when it is unset), the model is "gpt-4o",
    verbose output is switched on and headless mode is switched off.
    """
    llm_settings = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "model": "gpt-4o",
    }
    return {
        "llm": llm_settings,
        "verbose": True,
        "headless": False,
    }
def test_screenshot_scraper_graph(graph_config):
    """
    Check that running ScreenshotScraperGraph yields a non-None result.

    The graph is built from the ``graph_config`` fixture with a fixed prompt
    and source URL; after the run the result is asserted to be not None and
    then printed as JSON with a four-space indent.
    """
    scraper = ScreenshotScraperGraph(
        config=graph_config,
        source="https://perinim.github.io/projects/",
        prompt="List me all the projects",
    )
    output = scraper.run()

    assert output is not None, "The result should not be None"
    print(json.dumps(output, indent=4))