feat: refactoring of ScrapeGraph to SmartScraperLiteGraph

This commit is contained in:
Marco Vinciguerra 2024-10-21 10:12:53 +02:00
parent b84883bfd1
commit 52b6bf5fb8
3 changed files with 8 additions and 7 deletions

View File

@@ -26,4 +26,4 @@ from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
from .code_generator_graph import CodeGeneratorGraph from .code_generator_graph import CodeGeneratorGraph
from .depth_search_graph import DepthSearchGraph from .depth_search_graph import DepthSearchGraph
from .smart_scraper_multi_lite_graph import SmartScraperMultiLiteGraph from .smart_scraper_multi_lite_graph import SmartScraperMultiLiteGraph
from .scrape_graph import ScrapeGraph from .smart_scraper_lite_graph import SmartScraperLiteGraph

View File

@@ -10,9 +10,9 @@ from ..nodes import (
ParseNode, ParseNode,
) )
class ScrapeGraph(AbstractGraph): class SmartScraperLiteGraph(AbstractGraph):
""" """
ScrapeGraph is a scraping pipeline that automates the process of SmartScraperLiteGraph is a scraping pipeline that automates the process of
extracting information from web pages. extracting information from web pages.
Attributes: Attributes:
@@ -30,7 +30,7 @@ class ScrapeGraph(AbstractGraph):
schema (BaseModel): The schema for the graph output. schema (BaseModel): The schema for the graph output.
Example: Example:
>>> scraper = ScraperGraph( >>> scraper = SmartScraperLiteGraph(
... "https://en.wikipedia.org/wiki/Chioggia", ... "https://en.wikipedia.org/wiki/Chioggia",
... {"llm": {"model": "openai/gpt-3.5-turbo"}} ... {"llm": {"model": "openai/gpt-3.5-turbo"}}
... ) ... )
@@ -38,7 +38,8 @@ class ScrapeGraph(AbstractGraph):
) )
""" """
def __init__(self, source: str, config: dict, prompt: str = "", schema: Optional[BaseModel] = None): def __init__(self, source: str, config: dict, prompt: str = "",
schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema) super().__init__(prompt, config, source, schema)
self.input_key = "url" if source.startswith("http") else "local_dir" self.input_key = "url" if source.startswith("http") else "local_dir"

View File

@@ -6,7 +6,7 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .base_graph import BaseGraph from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph from .abstract_graph import AbstractGraph
from .scrape_graph import ScrapeGraph from .smart_scraper_lite_graph import SmartScraperLiteGraph
from ..nodes import ( from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode, MergeAnswersNode,
@@ -63,7 +63,7 @@ class SmartScraperMultiLiteGraph(AbstractGraph):
input="user_prompt & urls", input="user_prompt & urls",
output=["parsed_doc"], output=["parsed_doc"],
node_config={ node_config={
"graph_instance": ScrapeGraph, "graph_instance": SmartScraperLiteGraph,
"scraper_config": self.copy_config, "scraper_config": self.copy_config,
}, },
schema=self.copy_schema schema=self.copy_schema