feat: add integration for infos

This commit is contained in:
Marco Vinciguerra 2024-07-01 21:19:16 +02:00
parent e3a19c2059
commit 3bf5f570a8
16 changed files with 96 additions and 6 deletions

View File

@ -0,0 +1,50 @@
"""
Basic example of scraping pipeline using SmartScraper
"""
import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
load_dotenv()
# ************************************************
# Define the configuration for the graph
# ************************************************
openai_key = os.getenv("OPENAI_APIKEY")
prompt = "Some more info"
graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
"additional_info": prompt,
"verbose": True,
"headless": False,
}
# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************
smart_scraper_graph = SmartScraperGraph(
prompt="List me all the projects with their description",
# also accepts a string with the already downloaded HTML code
source="https://perinim.github.io/projects/",
config=graph_config,
)
result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
# ************************************************
# Get graph execution info
# ************************************************
graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

View File

@ -50,6 +50,7 @@ class CSVScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema,
}
)

View File

@ -95,6 +95,7 @@ class DeepScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema
}
)

View File

@ -75,6 +75,7 @@ class JSONScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema
}
)

View File

@ -76,6 +76,7 @@ class MDScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema,
}
)

View File

@ -18,7 +18,6 @@ from ..nodes import (
from ..models import OpenAIImageToText
class OmniScraperGraph(AbstractGraph):
"""
OmniScraper is a scraping pipeline that automates the process of
@ -60,7 +59,6 @@ class OmniScraperGraph(AbstractGraph):
super().__init__(prompt, config, source, schema)
self.input_key = "url" if source.startswith("http") else "local_dir"
def _create_graph(self) -> BaseGraph:
"""
@ -104,6 +102,7 @@ class OmniScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema
}
)

View File

@ -89,6 +89,7 @@ class PDFScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema
}
)

View File

@ -84,6 +84,7 @@ class ScriptCreatorGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema,
},
library=self.library,

View File

@ -91,6 +91,7 @@ class SmartScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema,
}
)

View File

@ -84,6 +84,7 @@ class SpeechGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema
}
)

View File

@ -77,6 +77,7 @@ class XMLScraperGraph(AbstractGraph):
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": self.config.get("additional_info"),
"schema": self.schema
}
)

View File

@ -58,11 +58,14 @@ class GenerateAnswerCSVNode(BaseNode):
node_name (str): name of the node
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)
self.additional_info = node_config.get("additional_info")
def execute(self, state):
"""
@ -99,9 +102,14 @@ class GenerateAnswerCSVNode(BaseNode):
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()
if self.additional_info is not None:
template_no_chunks_csv += self.additional_info
template_chunks_csv += self.additional_info
template_merge_csv += self.additional_info
format_instructions = output_parser.get_format_instructions()
chains_dict = {}
# Use tqdm to add progress bar

View File

@ -54,6 +54,7 @@ class GenerateAnswerNode(BaseNode):
False if node_config is None else node_config.get("script_creator", False)
)
self.additional_info = node_config.get("additional_info")
def execute(self, state: dict) -> dict:
"""
@ -98,6 +99,11 @@ class GenerateAnswerNode(BaseNode):
template_chunks_prompt = template_chunks
template_merge_prompt = template_merge
if self.additional_info is not None:
template_no_chunks_prompt += self.additional_info
template_chunks_prompt += self.additional_info
template_merge_prompt += self.additional_info
chains_dict = {}
# Use tqdm to add progress bar
@ -118,7 +124,6 @@ class GenerateAnswerNode(BaseNode):
partial_variables={"context": chunk.page_content,
"chunk_id": i + 1,
"format_instructions": format_instructions})
# Dynamically name the chains based on their index
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser

View File

@ -46,11 +46,13 @@ class GenerateAnswerOmniNode(BaseNode):
self.llm_model = node_config["llm_model"]
if isinstance(node_config["llm_model"], Ollama):
self.llm_model.format="json"
self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)
self.additional_info = node_config.get("additional_info")
def execute(self, state: dict) -> dict:
"""
Generates an answer by constructing a prompt from the user's input and the scraped
@ -86,6 +88,11 @@ class GenerateAnswerOmniNode(BaseNode):
else:
output_parser = JsonOutputParser()
if self.additional_info is not None:
template_no_chunk_omni += self.additional_info
template_chunks_omni += self.additional_info
template_merge_omni += self.additional_info
format_instructions = output_parser.get_format_instructions()

View File

@ -61,10 +61,13 @@ class GenerateAnswerPDFNode(BaseNode):
self.llm_model = node_config["llm_model"]
if isinstance(node_config["llm_model"], Ollama):
self.llm_model.format="json"
self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)
self.additional_info = node_config.get("additional_info")
def execute(self, state):
"""
Generates an answer by constructing a prompt from the user's input and the scraped
@ -100,6 +103,11 @@ class GenerateAnswerPDFNode(BaseNode):
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()
if self.additional_info is not None:
template_no_chunks_pdf += self.additional_info
template_chunks_pdf += self.additional_info
template_merge_pdf += self.additional_info
format_instructions = output_parser.get_format_instructions()

View File

@ -54,6 +54,8 @@ class GenerateScraperNode(BaseNode):
False if node_config is None else node_config.get("verbose", False)
)
self.additional_info = node_config.get("additional_info")
def execute(self, state: dict) -> dict:
"""
Generates a python script for scraping a website using the specified library.
@ -106,6 +108,8 @@ class GenerateScraperNode(BaseNode):
USER QUESTION: {question}
SCHEMA INSTRUCTIONS: {schema_instructions}
"""
if self.additional_info is not None:
template_no_chunks += self.additional_info
if len(doc) > 1:
raise NotImplementedError(