mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: add integration for infos
This commit is contained in:
parent
e3a19c2059
commit
3bf5f570a8
50
examples/extras/custom_prompt.py
Normal file
50
examples/extras/custom_prompt.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""
|
||||
Basic example of scraping pipeline using SmartScraper
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
from scrapegraphai.utils import prettify_exec_info
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ************************************************
|
||||
# Define the configuration for the graph
|
||||
# ************************************************
|
||||
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
|
||||
prompt = "Some more info"
|
||||
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": openai_key,
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
"additional_info": prompt,
|
||||
"verbose": True,
|
||||
"headless": False,
|
||||
}
|
||||
|
||||
# ************************************************
|
||||
# Create the SmartScraperGraph instance and run it
|
||||
# ************************************************
|
||||
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt="List me all the projects with their description",
|
||||
# also accepts a string with the already downloaded HTML code
|
||||
source="https://perinim.github.io/projects/",
|
||||
config=graph_config,
|
||||
)
|
||||
|
||||
result = smart_scraper_graph.run()
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
# ************************************************
|
||||
# Get graph execution info
|
||||
# ************************************************
|
||||
|
||||
graph_exec_info = smart_scraper_graph.get_execution_info()
|
||||
print(prettify_exec_info(graph_exec_info))
|
||||
@ -50,6 +50,7 @@ class CSVScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema,
|
||||
}
|
||||
)
|
||||
|
||||
@ -95,6 +95,7 @@ class DeepScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema
|
||||
}
|
||||
)
|
||||
|
||||
@ -75,6 +75,7 @@ class JSONScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema
|
||||
}
|
||||
)
|
||||
|
||||
@ -76,6 +76,7 @@ class MDScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema,
|
||||
}
|
||||
)
|
||||
|
||||
@ -18,7 +18,6 @@ from ..nodes import (
|
||||
|
||||
from ..models import OpenAIImageToText
|
||||
|
||||
|
||||
class OmniScraperGraph(AbstractGraph):
|
||||
"""
|
||||
OmniScraper is a scraping pipeline that automates the process of
|
||||
@ -60,7 +59,6 @@ class OmniScraperGraph(AbstractGraph):
|
||||
super().__init__(prompt, config, source, schema)
|
||||
|
||||
self.input_key = "url" if source.startswith("http") else "local_dir"
|
||||
|
||||
|
||||
def _create_graph(self) -> BaseGraph:
|
||||
"""
|
||||
@ -104,6 +102,7 @@ class OmniScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema
|
||||
}
|
||||
)
|
||||
|
||||
@ -89,6 +89,7 @@ class PDFScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema
|
||||
}
|
||||
)
|
||||
|
||||
@ -84,6 +84,7 @@ class ScriptCreatorGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema,
|
||||
},
|
||||
library=self.library,
|
||||
|
||||
@ -91,6 +91,7 @@ class SmartScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema,
|
||||
}
|
||||
)
|
||||
|
||||
@ -84,6 +84,7 @@ class SpeechGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema
|
||||
}
|
||||
)
|
||||
|
||||
@ -77,6 +77,7 @@ class XMLScraperGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"additional_info": self.config.get("additional_info"),
|
||||
"schema": self.schema
|
||||
}
|
||||
)
|
||||
|
||||
@ -58,11 +58,14 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
node_name (str): name of the node
|
||||
"""
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
|
||||
self.llm_model = node_config["llm_model"]
|
||||
|
||||
self.verbose = (
|
||||
False if node_config is None else node_config.get("verbose", False)
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
def execute(self, state):
|
||||
"""
|
||||
@ -99,9 +102,14 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
|
||||
if self.additional_info is not None:
|
||||
template_no_chunks_csv += self.additional_info
|
||||
template_chunks_csv += self.additional_info
|
||||
template_merge_csv += self.additional_info
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
|
||||
chains_dict = {}
|
||||
|
||||
# Use tqdm to add progress bar
|
||||
|
||||
@ -54,6 +54,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
False if node_config is None else node_config.get("script_creator", False)
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
@ -98,6 +99,11 @@ class GenerateAnswerNode(BaseNode):
|
||||
template_chunks_prompt = template_chunks
|
||||
template_merge_prompt = template_merge
|
||||
|
||||
if self.additional_info is not None:
|
||||
template_no_chunks_prompt += self.additional_info
|
||||
template_chunks_prompt += self.additional_info
|
||||
template_merge_prompt += self.additional_info
|
||||
|
||||
chains_dict = {}
|
||||
|
||||
# Use tqdm to add progress bar
|
||||
@ -118,7 +124,6 @@ class GenerateAnswerNode(BaseNode):
|
||||
partial_variables={"context": chunk.page_content,
|
||||
"chunk_id": i + 1,
|
||||
"format_instructions": format_instructions})
|
||||
|
||||
# Dynamically name the chains based on their index
|
||||
chain_name = f"chunk{i+1}"
|
||||
chains_dict[chain_name] = prompt | self.llm_model | output_parser
|
||||
|
||||
@ -46,11 +46,13 @@ class GenerateAnswerOmniNode(BaseNode):
|
||||
self.llm_model = node_config["llm_model"]
|
||||
if isinstance(node_config["llm_model"], Ollama):
|
||||
self.llm_model.format="json"
|
||||
|
||||
|
||||
self.verbose = (
|
||||
False if node_config is None else node_config.get("verbose", False)
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
Generates an answer by constructing a prompt from the user's input and the scraped
|
||||
@ -86,6 +88,11 @@ class GenerateAnswerOmniNode(BaseNode):
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
|
||||
if self.additional_info is not None:
|
||||
template_no_chunk_omni += self.additional_info
|
||||
template_chunks_omni += self.additional_info
|
||||
template_merge_omni += self.additional_info
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
|
||||
|
||||
@ -61,10 +61,13 @@ class GenerateAnswerPDFNode(BaseNode):
|
||||
self.llm_model = node_config["llm_model"]
|
||||
if isinstance(node_config["llm_model"], Ollama):
|
||||
self.llm_model.format="json"
|
||||
|
||||
self.verbose = (
|
||||
False if node_config is None else node_config.get("verbose", False)
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
def execute(self, state):
|
||||
"""
|
||||
Generates an answer by constructing a prompt from the user's input and the scraped
|
||||
@ -100,6 +103,11 @@ class GenerateAnswerPDFNode(BaseNode):
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
|
||||
if self.additional_info is not None:
|
||||
template_no_chunks_pdf += self.additional_info
|
||||
template_chunks_pdf += self.additional_info
|
||||
template_merge_pdf += self.additional_info
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
|
||||
@ -54,6 +54,8 @@ class GenerateScraperNode(BaseNode):
|
||||
False if node_config is None else node_config.get("verbose", False)
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
Generates a python script for scraping a website using the specified library.
|
||||
@ -106,6 +108,8 @@ class GenerateScraperNode(BaseNode):
|
||||
USER QUESTION: {question}
|
||||
SCHEMA INSTRUCTIONS: {schema_instructions}
|
||||
"""
|
||||
if self.additional_info is not None:
|
||||
template_no_chunks += self.additional_info
|
||||
|
||||
if len(doc) > 1:
|
||||
raise NotImplementedError(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user