diff --git a/examples/openai/screenshot_scraper.py b/examples/openai/screenshot_scraper.py index 795dea9d..826dcc50 100644 --- a/examples/openai/screenshot_scraper.py +++ b/examples/openai/screenshot_scraper.py @@ -29,8 +29,8 @@ graph_config = { # ************************************************ smart_scraper_graph = ScreenshotScraperGraph( - prompt="List me the email of the company", - source="https://scrapegraphai.com/", + prompt="List me all the projects", + source="https://perinim.github.io/projects/", config=graph_config ) diff --git a/scrapegraphai/nodes/fetch_screen_node.py b/scrapegraphai/nodes/fetch_screen_node.py index c869966b..1477f4e4 100644 --- a/scrapegraphai/nodes/fetch_screen_node.py +++ b/scrapegraphai/nodes/fetch_screen_node.py @@ -1,3 +1,6 @@ +""" +fetch_screen_node module +""" from typing import List, Optional from playwright.sync_api import sync_playwright from .base_node import BaseNode @@ -18,8 +21,10 @@ class FetchScreenNode(BaseNode): self.url = node_config.get("link") def execute(self, state: dict) -> dict: - """Captures screenshots from the input URL and stores them in the state dictionary as bytes.""" - + """ + Captures screenshots from the input URL and stores them in the state dictionary as bytes. + """ + screenshots = [] with sync_playwright() as p: @@ -29,28 +34,25 @@ class FetchScreenNode(BaseNode): viewport_height = page.viewport_size["height"] - # Initialize screenshot counter screenshot_counter = 1 - # List to keep track of screenshot data screenshot_data_list = [] - # Function to capture screenshots def capture_screenshot(scroll_position, counter): page.evaluate(f"window.scrollTo(0, {scroll_position});") screenshot_data = page.screenshot() screenshot_data_list.append(screenshot_data) - # Capture screenshots - capture_screenshot(0, screenshot_counter) # First screenshot + capture_screenshot(0, screenshot_counter) screenshot_counter += 1 - capture_screenshot(viewport_height, screenshot_counter) # Second screenshot + capture_screenshot(viewport_height, screenshot_counter) browser.close() - # Store screenshot data as bytes in the state dictionary for screenshot_data in screenshot_data_list: screenshots.append(screenshot_data) + state["link"] = self.url state['screenshots'] = screenshots + return state diff --git a/scrapegraphai/nodes/generate_answer_from_image_node.py b/scrapegraphai/nodes/generate_answer_from_image_node.py index 8844990b..a1d2769b 100644 --- a/scrapegraphai/nodes/generate_answer_from_image_node.py +++ b/scrapegraphai/nodes/generate_answer_from_image_node.py @@ -6,7 +6,7 @@ import requests class GenerateAnswerFromImageNode(BaseNode): """ GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API - and updates the state with the generated answers. + and updates the state with the consolidated answers. """ def __init__( @@ -19,20 +19,28 @@ class GenerateAnswerFromImageNode(BaseNode): super().__init__(node_name, "node", input, output, 2, node_config) def execute(self, state: dict) -> dict: - """Processes images from the state, generates answers, and updates the state.""" - # Retrieve the image data from the state dictionary + """ + Processes images from the state, generates answers, + consolidates the results, and updates the state. + """ images = state.get('screenshots', []) - results = [] + analyses = [] + + api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "") + + supported_models = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"] + + if self.node_config["config"]["llm"]["model"] not in supported_models: + raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}' + is not supported. Supported models are: + {', '.join(supported_models)}.""") - # OpenAI API Key for image_data in images: - # Encode the image data to base64 base64_image = base64.b64encode(image_data).decode('utf-8') - # Prepare API request headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.node_config.get("config").get("llm").get("api_key")}" + "Authorization": f"Bearer {api_key}" } payload = { @@ -43,7 +51,8 @@ class GenerateAnswerFromImageNode(BaseNode): "content": [ { "type": "text", - "text": state.get("user_prompt", "Extract information from the image") + "text": state.get("user_prompt", + "Extract information from the image") }, { "type": "image_url", @@ -57,18 +66,20 @@ class GenerateAnswerFromImageNode(BaseNode): "max_tokens": 300 } - # Make the API request - response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) + response = requests.post("https://api.openai.com/v1/chat/completions", + headers=headers, + json=payload, + timeout=10 ) result = response.json() - # Extract the response text - response_text = result.get('choices', [{}])[0].get('message', {}).get('content', 'No response') + response_text = result.get('choices', + [{}])[0].get('message', {}).get('content', 'No response') + analyses.append(response_text) - # Append the result to the results list - results.append({ - "analysis": response_text - }) + consolidated_analysis = " ".join(analyses) + + state['answer'] = { + "consolidated_analysis": consolidated_analysis + } - # Update the state dictionary with the results - state['answer'] = results return state diff --git a/tests/graphs/scrape_json_ollama.py b/tests/graphs/scrape_json_ollama_test.py similarity index 100% rename from tests/graphs/scrape_json_ollama.py rename to tests/graphs/scrape_json_ollama_test.py diff --git a/tests/graphs/screenshot_scraper_test.py b/tests/graphs/screenshot_scraper_test.py new file mode 100644 index 00000000..c4f436d2 --- /dev/null +++ b/tests/graphs/screenshot_scraper_test.py @@ -0,0 +1,39 @@ +import os +import pytest +import json +from scrapegraphai.graphs import ScreenshotScraperGraph +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Define a fixture for the graph configuration +@pytest.fixture +def graph_config(): + """ + Creation of the graph + """ + return { + "llm": { + "api_key": os.getenv("OPENAI_API_KEY"), + "model": "gpt-4o", + }, + "verbose": True, + "headless": False, + } + +def test_screenshot_scraper_graph(graph_config): + """ + test + """ + smart_scraper_graph = ScreenshotScraperGraph( + prompt="List me all the projects", + source="https://perinim.github.io/projects/", + config=graph_config + ) + + result = smart_scraper_graph.run() + + assert result is not None, "The result should not be None" + + print(json.dumps(result, indent=4))