mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
Merge pull request #558 from ScrapeGraphAI/screenshot_scraper
Screenshot scraper integration
This commit is contained in:
commit
860fde8a2c
38
examples/openai/screenshot_scraper.py
Normal file
38
examples/openai/screenshot_scraper.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
Basic example of scraping pipeline using ScreenshotScraperGraph
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import ScreenshotScraperGraph
|
||||
from scrapegraphai.utils import prettify_exec_info
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ************************************************
# Graph configuration: LLM credentials/model plus runtime flags
# ************************************************

llm_settings = {
    # The API key is read from the environment (populated by load_dotenv()).
    "api_key": os.getenv("OPENAI_API_KEY"),
    "model": "gpt-4o",
}

graph_config = {
    "llm": llm_settings,
    "verbose": True,
    "headless": False,
}

# ************************************************
# Build the ScreenshotScraperGraph, run it and print the result as JSON
# ************************************************

screenshot_scraper_graph = ScreenshotScraperGraph(
    prompt="List me all the projects",
    source="https://perinim.github.io/projects/",
    config=graph_config,
)

result = screenshot_scraper_graph.run()

print(json.dumps(result, indent=4))
|
||||
@ -4,9 +4,10 @@ Basic example of scraping pipeline using SmartScraper
|
||||
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
from scrapegraphai.utils import prettify_exec_info
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ************************************************
|
||||
@ -17,7 +18,7 @@ load_dotenv()
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
"model": "gpt-3.5-turbo",
|
||||
"model": "gpt-4o",
|
||||
},
|
||||
"verbose": True,
|
||||
"headless": False,
|
||||
|
||||
@ -24,3 +24,4 @@ from .script_creator_multi_graph import ScriptCreatorMultiGraph
|
||||
from .markdown_scraper_graph import MDScraperGraph
|
||||
from .markdown_scraper_multi_graph import MDScraperMultiGraph
|
||||
from .search_link_graph import SearchLinkGraph
|
||||
from .screenshot_scraper_graph import ScreenshotScraperGraph
|
||||
|
||||
82
scrapegraphai/graphs/screenshot_scraper_graph.py
Normal file
82
scrapegraphai/graphs/screenshot_scraper_graph.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""
|
||||
ScreenshotScraperGraph Module
|
||||
"""
|
||||
from typing import Optional
|
||||
import logging
|
||||
from pydantic import BaseModel
|
||||
from .base_graph import BaseGraph
|
||||
from .abstract_graph import AbstractGraph
|
||||
from ..nodes import ( FetchScreenNode, GenerateAnswerFromImageNode, )
|
||||
|
||||
class ScreenshotScraperGraph(AbstractGraph):
    """
    Scraping pipeline that answers a prompt from screenshots of a web page.

    The graph chains two nodes:

    1. FetchScreenNode, configured with the source link, whose output key is
       "screenshots".
    2. GenerateAnswerFromImageNode, which takes "screenshots" as input, receives
       the graph configuration, and whose output key is "answer".

    Attributes read by this class (assigned outside this block, presumably by
    AbstractGraph.__init__ -- confirm there):
        prompt (str): The question to answer about the page.
        source (str): The link of the page to capture.
        config (dict): Configuration parameters for the graph.

    Methods:
        __init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None)
            Forwards the arguments to AbstractGraph.

        _create_graph()
            Builds the two-node BaseGraph described above.

        run()
            Executes the graph and returns the value stored under "answer".
    """

    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
        # Note the argument order expected by the parent: config precedes source.
        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Assemble the screenshot-scraping workflow.

        Returns:
            BaseGraph: A graph whose entry point is the screenshot-fetching node,
            connected by a single edge to the answer-generating node.
        """
        # First stage: capture the page at self.source.
        screen_fetcher = FetchScreenNode(
            input="url",
            output=["screenshots"],
            node_config={"link": self.source},
        )

        # Second stage: turn the captured screenshots into an answer.
        answer_generator = GenerateAnswerFromImageNode(
            input="screenshots",
            output=["answer"],
            node_config={"config": self.config},
        )

        pipeline_nodes = [screen_fetcher, answer_generator]
        pipeline_edges = [(screen_fetcher, answer_generator)]

        return BaseGraph(
            nodes=pipeline_nodes,
            edges=pipeline_edges,
            entry_point=screen_fetcher,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """
        Execute the graph with the user prompt as its only input.

        The final state and execution info returned by the graph are stored on
        the instance as ``final_state`` and ``execution_info``.

        Returns:
            The value stored under "answer" in the final state, or the string
            "No answer found." when that key is absent.

            NOTE(review): GenerateAnswerFromImageNode stores a dict
            ({"consolidated_analysis": ...}) under "answer", so the ``str``
            annotation looks too narrow -- confirm before relying on it.
        """
        self.final_state, self.execution_info = self.graph.execute(
            {"user_prompt": self.prompt}
        )

        return self.final_state.get("answer", "No answer found.")
|
||||
|
||||
@ -19,4 +19,6 @@ from .generate_answer_pdf_node import GenerateAnswerPDFNode
|
||||
from .graph_iterator_node import GraphIteratorNode
|
||||
from .merge_answers_node import MergeAnswersNode
|
||||
from .generate_answer_omni_node import GenerateAnswerOmniNode
|
||||
from .merge_generated_scripts import MergeGeneratedScriptsNode
|
||||
from .merge_generated_scripts import MergeGeneratedScriptsNode
|
||||
from .fetch_screen_node import FetchScreenNode
|
||||
from .generate_answer_from_image_node import GenerateAnswerFromImageNode
|
||||
|
||||
55
scrapegraphai/nodes/fetch_screen_node.py
Normal file
55
scrapegraphai/nodes/fetch_screen_node.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""
|
||||
fetch_screen_node module
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from playwright.sync_api import sync_playwright
|
||||
from .base_node import BaseNode
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
class FetchScreenNode(BaseNode):
    """
    FetchScreenNode captures screenshots from a given URL and stores the image data as bytes.

    Two screenshots are taken: one at the top of the page and one after
    scrolling down by one viewport height.

    Args:
        input (str): Input expression forwarded to BaseNode.
        output (List[str]): Output keys forwarded to BaseNode.
        node_config (Optional[dict]): Node configuration; the page to capture
            is read from its "link" entry.
        node_name (str): Name of the node, defaults to "FetchScreenNode".
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "FetchScreenNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)
        # node_config defaults to None, so guard before reading the link
        # instead of failing with AttributeError at construction time.
        self.url = (node_config or {}).get("link")

    def execute(self, state: dict) -> dict:
        """
        Captures screenshots from the configured URL and stores them in the state dictionary as bytes.

        Args:
            state (dict): The current graph state; it is updated in place.

        Returns:
            dict: The same state with "link" set to the captured URL and
            "screenshots" set to the list of screenshot bytes.

        Raises:
            ValueError: If no "link" was provided in the node configuration.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        if not self.url:
            raise ValueError(
                "FetchScreenNode requires a 'link' entry in node_config."
            )

        screenshot_data_list = []

        with sync_playwright() as p:
            browser = p.chromium.launch()
            try:
                page = browser.new_page()
                page.goto(self.url)

                viewport_height = page.viewport_size["height"]

                # Capture the top of the page, then the region one viewport below.
                for scroll_position in (0, viewport_height):
                    page.evaluate(f"window.scrollTo(0, {scroll_position});")
                    screenshot_data_list.append(page.screenshot())
            finally:
                # Always release the browser, even if navigation or capture fails.
                browser.close()

        state["link"] = self.url
        state['screenshots'] = screenshot_data_list

        return state
|
||||
115
scrapegraphai/nodes/generate_answer_from_image_node.py
Normal file
115
scrapegraphai/nodes/generate_answer_from_image_node.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""
|
||||
GenerateAnswerFromImageNode Module
|
||||
"""
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
import aiohttp
|
||||
from .base_node import BaseNode
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
class GenerateAnswerFromImageNode(BaseNode):
    """
    GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
    and updates the state with the consolidated answers.

    Each image found under the "screenshots" state key is sent, together with
    the user prompt, to the OpenAI chat completions endpoint. The individual
    replies are joined with spaces and stored under the "answer" state key as
    ``{"consolidated_analysis": <joined text>}``.

    Args:
        input (str): Input expression forwarded to BaseNode.
        output (List[str]): Output keys forwarded to BaseNode.
        node_config (Optional[dict]): Node configuration; the LLM settings are
            read from ``node_config["config"]["llm"]`` ("model" and "api_key").
        node_name (str): Name of the node, defaults to "GenerateAnswerFromImageNode".
    """

    # First eight bytes of every PNG file; used to pick the data-URL MIME type.
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GenerateAnswerFromImageNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)

    @staticmethod
    def _guess_mime_type(image_data) -> str:
        """
        Return "image/png" for PNG data, otherwise "image/jpeg".

        Anything that does not start with the PNG signature keeps the
        "image/jpeg" label this node has always used.
        """
        if bytes(image_data[:8]) == GenerateAnswerFromImageNode._PNG_SIGNATURE:
            return "image/png"
        return "image/jpeg"

    async def process_image(self, session, api_key, image_data, user_prompt):
        """
        Send one image plus the user prompt to the OpenAI chat completions API.

        Args:
            session: The aiohttp client session used to perform the request.
            api_key: The OpenAI API key sent as a bearer token.
            image_data: The raw image bytes.
            user_prompt: The text prompt that accompanies the image.

        Returns:
            str: The content of the first choice's message, or "No response"
            when the reply contains no usable content.
        """
        base64_image = base64.b64encode(image_data).decode('utf-8')
        mime_type = self._guess_mime_type(image_data)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        payload = {
            "model": self.node_config["config"]["llm"]["model"],
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": user_prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }

        async with session.post("https://api.openai.com/v1/chat/completions",
                                headers=headers, json=payload) as response:
            result = await response.json()

        # "choices" may be missing (error replies) or empty; both fall back to
        # an empty choice instead of raising IndexError.
        choices = result.get('choices') or [{}]
        content = choices[0].get('message', {}).get('content')

        if content is None:
            error = result.get('error')
            if error:
                # Surface API failures instead of hiding them behind the fallback.
                self.logger.warning(f"OpenAI API returned an error: {error}")
            return 'No response'

        return content

    async def execute_async(self, state: dict) -> dict:
        """
        Processes images from the state, generates answers,
        consolidates the results, and updates the state asynchronously.

        Args:
            state (dict): The current graph state; "screenshots" and
                "user_prompt" are read, "answer" is written.

        Returns:
            dict: The same state with "answer" set.

        Raises:
            ValueError: If the configured model is not one of the supported models.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        llm_config = self.node_config["config"]["llm"]
        model = llm_config["model"]

        supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo")

        if model not in supported_models:
            raise ValueError(
                f"Model '{model}' is not supported. "
                f"Supported models are: {', '.join(supported_models)}."
            )

        api_key = llm_config.get("api_key", "")
        images = state.get('screenshots', [])
        user_prompt = state.get("user_prompt", "Extract information from the image")

        async with aiohttp.ClientSession() as session:
            # All images are analyzed concurrently; gather preserves their order.
            analyses = await asyncio.gather(*(
                self.process_image(session, api_key, image_data, user_prompt)
                for image_data in images
            ))

        state['answer'] = {
            "consolidated_analysis": " ".join(analyses)
        }

        return state

    def execute(self, state: dict) -> dict:
        """
        Wrapper to run the asynchronous execute_async function in a synchronous context.

        When no event loop is running in the current thread, the coroutine is
        run with ``asyncio.run``. When a loop is already running in this
        thread, the coroutine is run on a fresh loop in a worker thread,
        because ``run_until_complete`` cannot be called on a running loop.

        Args:
            state (dict): The current graph state.

        Returns:
            dict: The state returned by execute_async.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: the simple path is safe.
            return asyncio.run(self.execute_async(state))

        # Local import keeps this block self-contained.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                lambda: asyncio.run(self.execute_async(state))
            )
            return future.result()
|
||||
39
tests/graphs/screenshot_scraper_test.py
Normal file
39
tests/graphs/screenshot_scraper_test.py
Normal file
@ -0,0 +1,39 @@
|
||||
import os
|
||||
import pytest
|
||||
import json
|
||||
from scrapegraphai.graphs import ScreenshotScraperGraph
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Define a fixture for the graph configuration
|
||||
@pytest.fixture
def graph_config():
    """
    Provide the configuration dictionary used to build the graph under test.

    The OpenAI API key is read from the OPENAI_API_KEY environment variable.
    """
    llm_settings = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "model": "gpt-4o",
    }
    config = {"llm": llm_settings}
    config["verbose"] = True
    config["headless"] = False
    return config
|
||||
|
||||
def test_screenshot_scraper_graph(graph_config):
    """
    Run ScreenshotScraperGraph end to end and check that it returns a result.

    The result is also printed as indented JSON for manual inspection.
    """
    scraper_kwargs = {
        "prompt": "List me all the projects",
        "source": "https://perinim.github.io/projects/",
        "config": graph_config,
    }
    scraper = ScreenshotScraperGraph(**scraper_kwargs)

    result = scraper.run()

    assert result is not None, "The result should not be None"

    print(json.dumps(result, indent=4))
|
||||
Loading…
Reference in New Issue
Block a user