feat: add integration for infos

This commit is contained in:
Marco Vinciguerra 2024-07-01 21:19:16 +02:00
parent e3a19c2059
commit 3bf5f570a8
16 changed files with 96 additions and 6 deletions

View File

@ -0,0 +1,50 @@
"""
Basic example of scraping pipeline using SmartScraper
"""
import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
# Run python-dotenv's loader first so the os.getenv() lookup below happens
# after it.
load_dotenv()

# ------------------------------------------------------------------
# Graph configuration
# ------------------------------------------------------------------

# Free-form text forwarded to the graph under the "additional_info" key.
extra_info = "Some more info"

# LLM settings; the API key is looked up in the environment (None if unset).
llm_settings = dict(
    api_key=os.getenv("OPENAI_APIKEY"),
    model="gpt-3.5-turbo",
)

graph_config = dict(
    llm=llm_settings,
    additional_info=extra_info,
    verbose=True,
    headless=False,
)

# ------------------------------------------------------------------
# Build the SmartScraperGraph and execute it
# ------------------------------------------------------------------

scraper = SmartScraperGraph(
    config=graph_config,
    # also accepts a string with the already downloaded HTML code
    source="https://perinim.github.io/projects/",
    prompt="List me all the projects with their description",
)

answer = scraper.run()
print(json.dumps(answer, indent=4))

# ------------------------------------------------------------------
# Report the graph execution info
# ------------------------------------------------------------------

print(prettify_exec_info(scraper.get_execution_info()))

View File

@ -50,6 +50,7 @@ class CSVScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema, "schema": self.schema,
} }
) )

View File

@ -95,6 +95,7 @@ class DeepScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema "schema": self.schema
} }
) )

View File

@ -75,6 +75,7 @@ class JSONScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema "schema": self.schema
} }
) )

View File

@ -76,6 +76,7 @@ class MDScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema, "schema": self.schema,
} }
) )

View File

@ -18,7 +18,6 @@ from ..nodes import (
from ..models import OpenAIImageToText from ..models import OpenAIImageToText
class OmniScraperGraph(AbstractGraph): class OmniScraperGraph(AbstractGraph):
""" """
OmniScraper is a scraping pipeline that automates the process of OmniScraper is a scraping pipeline that automates the process of
@ -60,7 +59,6 @@ class OmniScraperGraph(AbstractGraph):
super().__init__(prompt, config, source, schema) super().__init__(prompt, config, source, schema)
self.input_key = "url" if source.startswith("http") else "local_dir" self.input_key = "url" if source.startswith("http") else "local_dir"
def _create_graph(self) -> BaseGraph: def _create_graph(self) -> BaseGraph:
""" """
@ -104,6 +102,7 @@ class OmniScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema "schema": self.schema
} }
) )

View File

@ -89,6 +89,7 @@ class PDFScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema "schema": self.schema
} }
) )

View File

@ -84,6 +84,7 @@ class ScriptCreatorGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema, "schema": self.schema,
}, },
library=self.library, library=self.library,

View File

@ -91,6 +91,7 @@ class SmartScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema, "schema": self.schema,
} }
) )

View File

@ -84,6 +84,7 @@ class SpeechGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema "schema": self.schema
} }
) )

View File

@ -77,6 +77,7 @@ class XMLScraperGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema "schema": self.schema
} }
) )

View File

@ -58,11 +58,14 @@ class GenerateAnswerCSVNode(BaseNode):
node_name (str): name of the node node_name (str): name of the node
""" """
super().__init__(node_name, "node", input, output, 2, node_config) super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"] self.llm_model = node_config["llm_model"]
self.verbose = ( self.verbose = (
False if node_config is None else node_config.get("verbose", False) False if node_config is None else node_config.get("verbose", False)
) )
self.additional_info = node_config.get("additional_info")
def execute(self, state): def execute(self, state):
""" """
@ -99,9 +102,14 @@ class GenerateAnswerCSVNode(BaseNode):
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else: else:
output_parser = JsonOutputParser() output_parser = JsonOutputParser()
if self.additional_info is not None:
template_no_chunks_csv += self.additional_info
template_chunks_csv += self.additional_info
template_merge_csv += self.additional_info
format_instructions = output_parser.get_format_instructions() format_instructions = output_parser.get_format_instructions()
chains_dict = {} chains_dict = {}
# Use tqdm to add progress bar # Use tqdm to add progress bar

View File

@ -54,6 +54,7 @@ class GenerateAnswerNode(BaseNode):
False if node_config is None else node_config.get("script_creator", False) False if node_config is None else node_config.get("script_creator", False)
) )
self.additional_info = node_config.get("additional_info")
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
""" """
@ -98,6 +99,11 @@ class GenerateAnswerNode(BaseNode):
template_chunks_prompt = template_chunks template_chunks_prompt = template_chunks
template_merge_prompt = template_merge template_merge_prompt = template_merge
if self.additional_info is not None:
template_no_chunks_prompt += self.additional_info
template_chunks_prompt += self.additional_info
template_merge_prompt += self.additional_info
chains_dict = {} chains_dict = {}
# Use tqdm to add progress bar # Use tqdm to add progress bar
@ -118,7 +124,6 @@ class GenerateAnswerNode(BaseNode):
partial_variables={"context": chunk.page_content, partial_variables={"context": chunk.page_content,
"chunk_id": i + 1, "chunk_id": i + 1,
"format_instructions": format_instructions}) "format_instructions": format_instructions})
# Dynamically name the chains based on their index # Dynamically name the chains based on their index
chain_name = f"chunk{i+1}" chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser chains_dict[chain_name] = prompt | self.llm_model | output_parser

View File

@ -46,11 +46,13 @@ class GenerateAnswerOmniNode(BaseNode):
self.llm_model = node_config["llm_model"] self.llm_model = node_config["llm_model"]
if isinstance(node_config["llm_model"], Ollama): if isinstance(node_config["llm_model"], Ollama):
self.llm_model.format="json" self.llm_model.format="json"
self.verbose = ( self.verbose = (
False if node_config is None else node_config.get("verbose", False) False if node_config is None else node_config.get("verbose", False)
) )
self.additional_info = node_config.get("additional_info")
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
""" """
Generates an answer by constructing a prompt from the user's input and the scraped Generates an answer by constructing a prompt from the user's input and the scraped
@ -86,6 +88,11 @@ class GenerateAnswerOmniNode(BaseNode):
else: else:
output_parser = JsonOutputParser() output_parser = JsonOutputParser()
if self.additional_info is not None:
template_no_chunk_omni += self.additional_info
template_chunks_omni += self.additional_info
template_merge_omni += self.additional_info
format_instructions = output_parser.get_format_instructions() format_instructions = output_parser.get_format_instructions()

View File

@ -61,10 +61,13 @@ class GenerateAnswerPDFNode(BaseNode):
self.llm_model = node_config["llm_model"] self.llm_model = node_config["llm_model"]
if isinstance(node_config["llm_model"], Ollama): if isinstance(node_config["llm_model"], Ollama):
self.llm_model.format="json" self.llm_model.format="json"
self.verbose = ( self.verbose = (
False if node_config is None else node_config.get("verbose", False) False if node_config is None else node_config.get("verbose", False)
) )
self.additional_info = node_config.get("additional_info")
def execute(self, state): def execute(self, state):
""" """
Generates an answer by constructing a prompt from the user's input and the scraped Generates an answer by constructing a prompt from the user's input and the scraped
@ -100,6 +103,11 @@ class GenerateAnswerPDFNode(BaseNode):
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else: else:
output_parser = JsonOutputParser() output_parser = JsonOutputParser()
if self.additional_info is not None:
template_no_chunks_pdf += self.additional_info
template_chunks_pdf += self.additional_info
template_merge_pdf += self.additional_info
format_instructions = output_parser.get_format_instructions() format_instructions = output_parser.get_format_instructions()

View File

@ -54,6 +54,8 @@ class GenerateScraperNode(BaseNode):
False if node_config is None else node_config.get("verbose", False) False if node_config is None else node_config.get("verbose", False)
) )
self.additional_info = node_config.get("additional_info")
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
""" """
Generates a python script for scraping a website using the specified library. Generates a python script for scraping a website using the specified library.
@ -106,6 +108,8 @@ class GenerateScraperNode(BaseNode):
USER QUESTION: {question} USER QUESTION: {question}
SCHEMA INSTRUCTIONS: {schema_instructions} SCHEMA INSTRUCTIONS: {schema_instructions}
""" """
if self.additional_info is not None:
template_no_chunks += self.additional_info
if len(doc) > 1: if len(doc) > 1:
raise NotImplementedError( raise NotImplementedError(