mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: refactoring of the code
This commit is contained in:
parent
8e3d5deaaa
commit
5eb3cff64f
@ -29,8 +29,8 @@ graph_config = {
|
|||||||
# ************************************************
|
# ************************************************
|
||||||
|
|
||||||
smart_scraper_graph = ScreenshotScraperGraph(
|
smart_scraper_graph = ScreenshotScraperGraph(
|
||||||
prompt="List me the email of the company",
|
prompt="List me all the projects",
|
||||||
source="https://scrapegraphai.com/",
|
source="https://perinim.github.io/projects/",
|
||||||
config=graph_config
|
config=graph_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
"""
|
||||||
|
fetch_screen_node module
|
||||||
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from playwright.sync_api import sync_playwright
|
from playwright.sync_api import sync_playwright
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
@ -18,7 +21,9 @@ class FetchScreenNode(BaseNode):
|
|||||||
self.url = node_config.get("link")
|
self.url = node_config.get("link")
|
||||||
|
|
||||||
def execute(self, state: dict) -> dict:
|
def execute(self, state: dict) -> dict:
|
||||||
"""Captures screenshots from the input URL and stores them in the state dictionary as bytes."""
|
"""
|
||||||
|
Captures screenshots from the input URL and stores them in the state dictionary as bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
screenshots = []
|
screenshots = []
|
||||||
|
|
||||||
@ -29,28 +34,25 @@ class FetchScreenNode(BaseNode):
|
|||||||
|
|
||||||
viewport_height = page.viewport_size["height"]
|
viewport_height = page.viewport_size["height"]
|
||||||
|
|
||||||
# Initialize screenshot counter
|
|
||||||
screenshot_counter = 1
|
screenshot_counter = 1
|
||||||
|
|
||||||
# List to keep track of screenshot data
|
|
||||||
screenshot_data_list = []
|
screenshot_data_list = []
|
||||||
|
|
||||||
# Function to capture screenshots
|
|
||||||
def capture_screenshot(scroll_position, counter):
|
def capture_screenshot(scroll_position, counter):
|
||||||
page.evaluate(f"window.scrollTo(0, {scroll_position});")
|
page.evaluate(f"window.scrollTo(0, {scroll_position});")
|
||||||
screenshot_data = page.screenshot()
|
screenshot_data = page.screenshot()
|
||||||
screenshot_data_list.append(screenshot_data)
|
screenshot_data_list.append(screenshot_data)
|
||||||
|
|
||||||
# Capture screenshots
|
capture_screenshot(0, screenshot_counter)
|
||||||
capture_screenshot(0, screenshot_counter) # First screenshot
|
|
||||||
screenshot_counter += 1
|
screenshot_counter += 1
|
||||||
capture_screenshot(viewport_height, screenshot_counter) # Second screenshot
|
capture_screenshot(viewport_height, screenshot_counter)
|
||||||
|
|
||||||
browser.close()
|
browser.close()
|
||||||
|
|
||||||
# Store screenshot data as bytes in the state dictionary
|
|
||||||
for screenshot_data in screenshot_data_list:
|
for screenshot_data in screenshot_data_list:
|
||||||
screenshots.append(screenshot_data)
|
screenshots.append(screenshot_data)
|
||||||
|
|
||||||
state["link"] = self.url
|
state["link"] = self.url
|
||||||
state['screenshots'] = screenshots
|
state['screenshots'] = screenshots
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import requests
|
|||||||
class GenerateAnswerFromImageNode(BaseNode):
|
class GenerateAnswerFromImageNode(BaseNode):
|
||||||
"""
|
"""
|
||||||
GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
|
GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
|
||||||
and updates the state with the generated answers.
|
and updates the state with the consolidated answers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -19,20 +19,28 @@ class GenerateAnswerFromImageNode(BaseNode):
|
|||||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||||
|
|
||||||
def execute(self, state: dict) -> dict:
|
def execute(self, state: dict) -> dict:
|
||||||
"""Processes images from the state, generates answers, and updates the state."""
|
"""
|
||||||
# Retrieve the image data from the state dictionary
|
Processes images from the state, generates answers,
|
||||||
|
consolidates the results, and updates the state.
|
||||||
|
"""
|
||||||
images = state.get('screenshots', [])
|
images = state.get('screenshots', [])
|
||||||
results = []
|
analyses = []
|
||||||
|
|
||||||
|
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
|
||||||
|
|
||||||
|
supported_models = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
|
||||||
|
|
||||||
|
if self.node_config["config"]["llm"]["model"] not in supported_models:
|
||||||
|
raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
|
||||||
|
is not supported. Supported models are:
|
||||||
|
{', '.join(supported_models)}.""")
|
||||||
|
|
||||||
# OpenAI API Key
|
|
||||||
for image_data in images:
|
for image_data in images:
|
||||||
# Encode the image data to base64
|
|
||||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
|
||||||
# Prepare API request
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {self.node_config.get("config").get("llm").get("api_key")}"
|
"Authorization": f"Bearer {api_key}"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@ -43,7 +51,8 @@ class GenerateAnswerFromImageNode(BaseNode):
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": state.get("user_prompt", "Extract information from the image")
|
"text": state.get("user_prompt",
|
||||||
|
"Extract information from the image")
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@ -57,18 +66,20 @@ class GenerateAnswerFromImageNode(BaseNode):
|
|||||||
"max_tokens": 300
|
"max_tokens": 300
|
||||||
}
|
}
|
||||||
|
|
||||||
# Make the API request
|
response = requests.post("https://api.openai.com/v1/chat/completions",
|
||||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=10 )
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
# Extract the response text
|
response_text = result.get('choices',
|
||||||
response_text = result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
|
[{}])[0].get('message', {}).get('content', 'No response')
|
||||||
|
analyses.append(response_text)
|
||||||
|
|
||||||
# Append the result to the results list
|
consolidated_analysis = " ".join(analyses)
|
||||||
results.append({
|
|
||||||
"analysis": response_text
|
state['answer'] = {
|
||||||
})
|
"consolidated_analysis": consolidated_analysis
|
||||||
|
}
|
||||||
|
|
||||||
# Update the state dictionary with the results
|
|
||||||
state['answer'] = results
|
|
||||||
return state
|
return state
|
||||||
|
|||||||
39
tests/graphs/screenshot_scraper_test.py
Normal file
39
tests/graphs/screenshot_scraper_test.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from scrapegraphai.graphs import ScreenshotScraperGraph
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Define a fixture for the graph configuration
|
||||||
|
@pytest.fixture
def graph_config():
    """
    Build the configuration dictionary handed to ScreenshotScraperGraph.

    The OpenAI API key is read from the OPENAI_API_KEY environment
    variable. When the variable is unset or empty the requesting test is
    skipped, because os.getenv would otherwise return None and the graph
    would be constructed with ``api_key=None``.

    Returns:
        dict: configuration with an "llm" section ("api_key", "model")
        plus the "verbose" and "headless" flags.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Skip instead of running against the API without credentials.
        pytest.skip("OPENAI_API_KEY is not set")

    return {
        "llm": {
            "api_key": api_key,
            "model": "gpt-4o",
        },
        "verbose": True,
        # NOTE(review): a headed browser presumably needs a display;
        # confirm this is intended for CI runs.
        "headless": False,
    }
|
||||||
|
|
||||||
|
def test_screenshot_scraper_graph(graph_config):
    """
    Build a ScreenshotScraperGraph for the projects page, run it, and
    assert that the returned result is not None. The result is then
    printed as indented JSON.
    """
    scraper = ScreenshotScraperGraph(
        config=graph_config,
        source="https://perinim.github.io/projects/",
        prompt="List me all the projects",
    )
    result = scraper.run()

    assert result is not None, "The result should not be None"

    serialized = json.dumps(result, indent=4)
    print(serialized)
|
||||||
Loading…
Reference in New Issue
Block a user