feat: update search_link_graph

This commit is contained in:
Marco Vinciguerra 2024-09-23 08:26:36 +02:00
parent 369332b39c
commit de10b281ba
3 changed files with 31 additions and 31 deletions

View File

@@ -6,9 +6,11 @@ import logging
from pydantic import BaseModel from pydantic import BaseModel
from .base_graph import BaseGraph from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph from .abstract_graph import AbstractGraph
from ..nodes import ( FetchNode, ParseNode, SearchLinkNode ) from ..nodes import (FetchNode,
SearchLinkNode,
SearchLinksWithContext)
class SearchLinkGraph(AbstractGraph): class SearchLinkGraph(AbstractGraph):
""" """
SearchLinkGraph is a scraping pipeline that automates the process of SearchLinkGraph is a scraping pipeline that automates the process of
extracting information from web pages using a natural language model extracting information from web pages using a natural language model
@@ -30,13 +32,7 @@ class SearchLinkGraph(AbstractGraph):
config (dict): Configuration parameters for the graph. config (dict): Configuration parameters for the graph.
schema (BaseModel, optional): The schema for the graph output. Defaults to None. schema (BaseModel, optional): The schema for the graph output. Defaults to None.
Example:
>>> smart_scraper = SearchLinkGraph(
... "List me all the attractions in Chioggia.",
... "https://en.wikipedia.org/wiki/Chioggia",
... {"llm": {"model": "openai/gpt-3.5-turbo"}}
... )
>>> result = smart_scraper.run()
""" """
def __init__(self, source: str, config: dict, schema: Optional[BaseModel] = None): def __init__(self, source: str, config: dict, schema: Optional[BaseModel] = None):
@@ -51,28 +47,33 @@ class SearchLinkGraph(AbstractGraph):
Returns: Returns:
BaseGraph: A graph instance representing the web scraping workflow. BaseGraph: A graph instance representing the web scraping workflow.
""" """
fetch_node = FetchNode( fetch_node = FetchNode(
input="url| local_dir", input="url| local_dir",
output=["doc"], output=["doc"],
node_config={ node_config={
"llm_model": self.llm_model, "force": self.config.get("force", False),
"force": self.config.get("force", False), "cut": self.config.get("cut", True),
"cut": self.config.get("cut", True), "loader_kwargs": self.config.get("loader_kwargs", {}),
"loader_kwargs": self.config.get("loader_kwargs", {}), }
} )
)
search_link_node = SearchLinkNode( if self.config.get("llm_style") == (True, None):
input="doc", search_link_node = SearchLinksWithContext(
output=["parsed_doc"], input="doc",
node_config={ output=["parsed_doc"],
"llm_model": self.llm_model, node_config={
"chunk_size": self.model_token, "llm_model": self.llm_model,
"filter_links": self.config.get("filter_links", None), "chunk_size": self.model_token,
"filter_config": self.config.get("filter_config", None) }
} )
) else:
search_link_node = SearchLinkNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
}
)
return BaseGraph( return BaseGraph(
nodes=[ nodes=[

View File

@@ -23,3 +23,4 @@ from .merge_generated_scripts import MergeGeneratedScriptsNode
from .fetch_screen_node import FetchScreenNode from .fetch_screen_node import FetchScreenNode
from .generate_answer_from_image_node import GenerateAnswerFromImageNode from .generate_answer_from_image_node import GenerateAnswerFromImageNode
from .concat_answers_node import ConcatAnswersNode from .concat_answers_node import ConcatAnswersNode
from .search_node_with_context import SearchLinksWithContext

View File

@@ -40,8 +40,6 @@ class SearchLinkNode(BaseNode):
): ):
super().__init__(node_name, "node", input, output, 1, node_config) super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm_model"]
if node_config.get("filter_links", False) or "filter_config" in node_config: if node_config.get("filter_links", False) or "filter_config" in node_config:
provided_filter_config = node_config.get("filter_config", {}) provided_filter_config = node_config.get("filter_config", {})
self.filter_config = {**default_filters.filter_dict, **provided_filter_config} self.filter_config = {**default_filters.filter_dict, **provided_filter_config}