diff --git a/examples/openai/screenshot_scraper.py b/examples/openai/screenshot_scraper.py new file mode 100644 index 00000000..826dcc50 --- /dev/null +++ b/examples/openai/screenshot_scraper.py @@ -0,0 +1,38 @@ +""" +Basic example of scraping pipeline using SmartScraper +""" + +import os +import json +from dotenv import load_dotenv +from scrapegraphai.graphs import ScreenshotScraperGraph +from scrapegraphai.utils import prettify_exec_info + +load_dotenv() + +# ************************************************ +# Define the configuration for the graph +# ************************************************ + + +graph_config = { + "llm": { + "api_key": os.getenv("OPENAI_API_KEY"), + "model": "gpt-4o", + }, + "verbose": True, + "headless": False, +} + +# ************************************************ +# Create the ScreenshotScraperGraph instance and run it +# ************************************************ + +smart_scraper_graph = ScreenshotScraperGraph( + prompt="List me all the projects", + source="https://perinim.github.io/projects/", + config=graph_config +) + +result = smart_scraper_graph.run() +print(json.dumps(result, indent=4)) diff --git a/examples/openai/smart_scraper_openai.py b/examples/openai/smart_scraper_openai.py index 4299ec29..119f67e5 100644 --- a/examples/openai/smart_scraper_openai.py +++ b/examples/openai/smart_scraper_openai.py @@ -4,9 +4,10 @@ Basic example of scraping pipeline using SmartScraper import os import json +from dotenv import load_dotenv from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info -from dotenv import load_dotenv + load_dotenv() # ************************************************ @@ -17,7 +18,7 @@ load_dotenv() graph_config = { "llm": { "api_key": os.getenv("OPENAI_API_KEY"), - "model": "gpt-3.5-turbo", + "model": "gpt-4o", }, "verbose": True, "headless": False, diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py index 26a0b9e1..6dda222d 100644 --- a/scrapegraphai/graphs/__init__.py +++ b/scrapegraphai/graphs/__init__.py @@ -24,3 +24,4 @@ from .script_creator_multi_graph import ScriptCreatorMultiGraph from .markdown_scraper_graph import MDScraperGraph from .markdown_scraper_multi_graph import MDScraperMultiGraph from .search_link_graph import SearchLinkGraph +from .screenshot_scraper_graph import ScreenshotScraperGraph diff --git a/scrapegraphai/graphs/screenshot_scraper_graph.py b/scrapegraphai/graphs/screenshot_scraper_graph.py new file mode 100644 index 00000000..f3ce608d --- /dev/null +++ b/scrapegraphai/graphs/screenshot_scraper_graph.py @@ -0,0 +1,82 @@ +""" +ScreenshotScraperGraph Module +""" +from typing import Optional +import logging +from pydantic import BaseModel +from .base_graph import BaseGraph +from .abstract_graph import AbstractGraph +from ..nodes import ( FetchScreenNode, GenerateAnswerFromImageNode, ) + +class ScreenshotScraperGraph(AbstractGraph): + """ + A graph instance representing the web scraping workflow for images. + + Attributes: + prompt (str): The input text to be scraped. + config (dict): Configuration parameters for the graph. + source (str): The source URL or image link to scrape from. + + Methods: + __init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None) + Initializes the ScreenshotScraperGraph instance with the given prompt, + source, and configuration parameters. + + _create_graph() + Creates a graph of nodes representing the web scraping workflow for images. + + run() + Executes the scraping process and returns the answer to the prompt. + """ + + def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): + super().__init__(prompt, config, source, schema) + + + def _create_graph(self) -> BaseGraph: + """ + Creates the graph of nodes representing the workflow for web scraping with images. + + Returns: + BaseGraph: A graph instance representing the web scraping workflow for images. + """ + fetch_screen_node = FetchScreenNode( + input="url", + output=["screenshots"], + node_config={ + "link": self.source + } + ) + generate_answer_from_image_node = GenerateAnswerFromImageNode( + input="screenshots", + output=["answer"], + node_config={ + "config": self.config + } + ) + + return BaseGraph( + nodes=[ + fetch_screen_node, + generate_answer_from_image_node, + ], + edges=[ + (fetch_screen_node, generate_answer_from_image_node), + ], + entry_point=fetch_screen_node, + graph_name=self.__class__.__name__ + ) + + def run(self) -> str: + """ + Executes the scraping process and returns the answer to the prompt. + + Returns: + str: The answer to the prompt. + """ + + inputs = {"user_prompt": self.prompt} + self.final_state, self.execution_info = self.graph.execute(inputs) + + return self.final_state.get("answer", "No answer found.") + \ No newline at end of file diff --git a/scrapegraphai/nodes/__init__.py b/scrapegraphai/nodes/__init__.py index 856438cd..dd1c3fcc 100644 --- a/scrapegraphai/nodes/__init__.py +++ b/scrapegraphai/nodes/__init__.py @@ -19,4 +19,6 @@ from .generate_answer_pdf_node import GenerateAnswerPDFNode from .graph_iterator_node import GraphIteratorNode from .merge_answers_node import MergeAnswersNode from .generate_answer_omni_node import GenerateAnswerOmniNode -from .merge_generated_scripts import MergeGeneratedScriptsNode +from .merge_generated_scripts import MergeGeneratedScriptsNode +from .fetch_screen_node import FetchScreenNode +from .generate_answer_from_image_node import GenerateAnswerFromImageNode diff --git a/scrapegraphai/nodes/fetch_screen_node.py b/scrapegraphai/nodes/fetch_screen_node.py new file mode 100644 index 00000000..0bb71c37 --- /dev/null +++ b/scrapegraphai/nodes/fetch_screen_node.py @@ -0,0 +1,55 @@ +""" +fetch_screen_node module +""" +from typing import List, Optional +from playwright.sync_api import sync_playwright +from .base_node import BaseNode +from ..utils.logging import get_logger + +class FetchScreenNode(BaseNode): + """ + FetchScreenNode captures screenshots from a given URL and stores the image data as bytes. + """ + + def __init__( + self, + input: str, + output: List[str], + node_config: Optional[dict] = None, + node_name: str = "FetchScreenNode", + ): + super().__init__(node_name, "node", input, output, 2, node_config) + self.url = node_config.get("link") + + def execute(self, state: dict) -> dict: + """ + Captures screenshots from the input URL and stores them in the state dictionary as bytes. + """ + self.logger.info(f"--- Executing {self.node_name} Node ---") + + with sync_playwright() as p: + browser = p.chromium.launch() + page = browser.new_page() + page.goto(self.url) + + viewport_height = page.viewport_size["height"] + + screenshot_counter = 1 + + screenshot_data_list = [] + + def capture_screenshot(scroll_position, counter): + page.evaluate(f"window.scrollTo(0, {scroll_position});") + screenshot_data = page.screenshot() + screenshot_data_list.append(screenshot_data) + + capture_screenshot(0, screenshot_counter) + screenshot_counter += 1 + capture_screenshot(viewport_height, screenshot_counter) + + browser.close() + + state["link"] = self.url + state['screenshots'] = screenshot_data_list + + return state diff --git a/scrapegraphai/nodes/generate_answer_from_image_node.py b/scrapegraphai/nodes/generate_answer_from_image_node.py new file mode 100644 index 00000000..4cc93d18 --- /dev/null +++ b/scrapegraphai/nodes/generate_answer_from_image_node.py @@ -0,0 +1,115 @@ +""" +GenerateAnswerFromImageNode Module +""" +import base64 +import asyncio +from typing import List, Optional +import aiohttp +from .base_node import BaseNode +from ..utils.logging import get_logger + +class GenerateAnswerFromImageNode(BaseNode): + """ + GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API + and updates the state with the consolidated answers. + """ + + def __init__( + self, + input: str, + output: List[str], + node_config: Optional[dict] = None, + node_name: str = "GenerateAnswerFromImageNode", + ): + super().__init__(node_name, "node", input, output, 2, node_config) + + async def process_image(self, session, api_key, image_data, user_prompt): + """ + async process image + """ + base64_image = base64.b64encode(image_data).decode('utf-8') + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + + payload = { + "model": self.node_config["config"]["llm"]["model"], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + } + ] + } + ], + "max_tokens": 300 + } + + async with session.post("https://api.openai.com/v1/chat/completions", + headers=headers, json=payload) as response: + result = await response.json() + return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response') + + async def execute_async(self, state: dict) -> dict: + """ + Processes images from the state, generates answers, + consolidates the results, and updates the state asynchronously. + """ + self.logger.info(f"--- Executing {self.node_name} Node ---") + + images = state.get('screenshots', []) + analyses = [] + + supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo") + + if self.node_config["config"]["llm"]["model"] not in supported_models: + raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}' + is not supported. Supported models are: + {', '.join(supported_models)}.""") + + api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "") + + async with aiohttp.ClientSession() as session: + tasks = [ + self.process_image(session, api_key, image_data, + state.get("user_prompt", "Extract information from the image")) + for image_data in images + ] + + analyses = await asyncio.gather(*tasks) + + consolidated_analysis = " ".join(analyses) + + state['answer'] = { + "consolidated_analysis": consolidated_analysis + } + + return state + + def execute(self, state: dict) -> dict: + """ + Wrapper to run the asynchronous execute_async function in a synchronous context. + """ + try: + eventloop = asyncio.get_event_loop() + except RuntimeError: + eventloop = None + + if eventloop and eventloop.is_running(): + task = eventloop.create_task(self.execute_async(state)) + state = eventloop.run_until_complete(asyncio.gather(task))[0] + else: + state = asyncio.run(self.execute_async(state)) + + return state diff --git a/tests/graphs/scrape_json_ollama.py b/tests/graphs/scrape_json_ollama_test.py similarity index 100% rename from tests/graphs/scrape_json_ollama.py rename to tests/graphs/scrape_json_ollama_test.py diff --git a/tests/graphs/screenshot_scraper_test.py b/tests/graphs/screenshot_scraper_test.py new file mode 100644 index 00000000..c4f436d2 --- /dev/null +++ b/tests/graphs/screenshot_scraper_test.py @@ -0,0 +1,39 @@ +import os +import pytest +import json +from scrapegraphai.graphs import ScreenshotScraperGraph +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Define a fixture for the graph configuration +@pytest.fixture +def graph_config(): + """ + Creation of the graph + """ + return { + "llm": { + "api_key": os.getenv("OPENAI_API_KEY"), + "model": "gpt-4o", + }, + "verbose": True, + "headless": False, + } + +def test_screenshot_scraper_graph(graph_config): + """ + test + """ + smart_scraper_graph = ScreenshotScraperGraph( + prompt="List me all the projects", + source="https://perinim.github.io/projects/", + config=graph_config + ) + + result = smart_scraper_graph.run() + + assert result is not None, "The result should not be None" + + print(json.dumps(result, indent=4))