mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
refactored RagNode and GenerateAnswerNode
This commit is contained in:
parent
3585cd81e5
commit
875a7cc4f0
@ -39,7 +39,7 @@ class BaseNode(ABC):
|
||||
raised to indicate the incorrect usage.
|
||||
"""
|
||||
|
||||
def __init__(self, node_name: str, node_type: str, input: str, output: List[str], min_input_len: int = 1, model = None):
|
||||
def __init__(self, node_name: str, node_type: str, input: str, output: List[str], min_input_len: int = 1, model_config: Optional[dict] = None):
|
||||
"""
|
||||
Initialize the node with a unique identifier and a specified node type.
|
||||
|
||||
@ -54,7 +54,7 @@ class BaseNode(ABC):
|
||||
self.input = input
|
||||
self.output = output
|
||||
self.min_input_len = min_input_len
|
||||
self.model = model
|
||||
self.model_config = model_config
|
||||
|
||||
if node_type not in ["node", "conditional_node"]:
|
||||
raise ValueError(
|
||||
|
||||
@ -11,7 +11,7 @@ from langchain_core.runnables import RunnableParallel
|
||||
|
||||
# Imports from the library
|
||||
from .base_node import BaseNode
|
||||
|
||||
from typing import List
|
||||
|
||||
class GenerateAnswerNode(BaseNode):
|
||||
"""
|
||||
@ -38,17 +38,17 @@ class GenerateAnswerNode(BaseNode):
|
||||
updating the state with the generated answer under the 'answer' key.
|
||||
"""
|
||||
|
||||
def __init__(self, llm, node_name: str):
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"):
|
||||
"""
|
||||
Initializes the GenerateAnswerNode with its input expression, output keys, model configuration, and node name.
|
||||
Args:
|
||||
model_config (dict): Configuration dictionary; the language model is read from its "llm_model" key.
|
||||
node_name (str): name of the node
|
||||
"""
|
||||
super().__init__(node_name, "node")
|
||||
self.llm = llm
|
||||
super().__init__(node_name, "node", input, output, 2, model_config)
|
||||
self.llm_model = model_config["llm_model"]
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
def execute(self, state):
|
||||
"""
|
||||
Generates an answer by constructing a prompt from the user's input and the scraped
|
||||
content, querying the language model, and parsing its response.
|
||||
@ -67,23 +67,16 @@ class GenerateAnswerNode(BaseNode):
|
||||
that the necessary information for generating an answer is missing.
|
||||
"""
|
||||
|
||||
print("---GENERATING ANSWER---")
|
||||
try:
|
||||
user_input = state["user_input"]
|
||||
document = state["document"]
|
||||
except KeyError as e:
|
||||
print(f"Error: {e} not found in state.")
|
||||
raise
|
||||
print(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
parsed_document = state.get("parsed_document", None)
|
||||
relevant_chunks = state.get("relevant_chunks", None)
|
||||
|
||||
if relevant_chunks:
|
||||
context = relevant_chunks
|
||||
elif parsed_document:
|
||||
context = parsed_document
|
||||
else:
|
||||
context = document
|
||||
user_prompt = input_data[0]
|
||||
doc = input_data[1]
|
||||
|
||||
output_parser = JsonOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
@ -106,7 +99,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
chains_dict = {}
|
||||
|
||||
# Use tqdm to add progress bar
|
||||
for i, chunk in enumerate(tqdm(context, desc="Processing chunks")):
|
||||
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks")):
|
||||
prompt = PromptTemplate(
|
||||
template=template_chunks,
|
||||
input_variables=["question"],
|
||||
@ -115,12 +108,12 @@ class GenerateAnswerNode(BaseNode):
|
||||
)
|
||||
# Dynamically name the chains based on their index
|
||||
chain_name = f"chunk{i+1}"
|
||||
chains_dict[chain_name] = prompt | self.llm | output_parser
|
||||
chains_dict[chain_name] = prompt | self.llm_model | output_parser
|
||||
|
||||
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
|
||||
map_chain = RunnableParallel(**chains_dict)
|
||||
# Chain
|
||||
answer_map = map_chain.invoke({"question": user_input})
|
||||
answer_map = map_chain.invoke({"question": user_prompt})
|
||||
|
||||
# Merge the answers from the chunks
|
||||
merge_prompt = PromptTemplate(
|
||||
@ -128,10 +121,10 @@ class GenerateAnswerNode(BaseNode):
|
||||
input_variables=["context", "question"],
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
merge_chain = merge_prompt | self.llm | output_parser
|
||||
merge_chain = merge_prompt | self.llm_model | output_parser
|
||||
answer = merge_chain.invoke(
|
||||
{"context": answer_map, "question": user_input})
|
||||
{"context": answer_map, "question": user_prompt})
|
||||
|
||||
# Update the state with the generated answer
|
||||
state.update({"answer": answer})
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
@ -77,6 +77,8 @@ class ParseNode(BaseNode):
|
||||
docs_transformed = Html2TextTransformer(
|
||||
).transform_documents(input_data[0])[0]
|
||||
|
||||
# TODO: keep the metadata
|
||||
|
||||
chunks = text_splitter.split_text(docs_transformed.page_content)
|
||||
|
||||
state.update({self.output[0]: chunks})
|
||||
|
||||
@ -8,7 +8,7 @@ from langchain.retrievers.document_compressors import EmbeddingsFilter, Document
|
||||
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
from typing import List
|
||||
|
||||
from .base_node import BaseNode
|
||||
|
||||
@ -33,12 +33,12 @@ class RAGNode(BaseNode):
|
||||
the specified tags, if provided, and updates the state with the parsed content.
|
||||
"""
|
||||
|
||||
def __init__(self, llm, node_name="RagNode"):
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAGNode"):
|
||||
"""
|
||||
Initializes the RAGNode with its input expression, output keys, model configuration, and node name.
|
||||
"""
|
||||
super().__init__(node_name, "node")
|
||||
self.llm = llm
|
||||
super().__init__(node_name, "node", input, output, 2, model_config)
|
||||
self.llm_model = model_config["llm_model"]
|
||||
|
||||
def execute(self, state):
|
||||
"""
|
||||
@ -56,25 +56,20 @@ class RAGNode(BaseNode):
|
||||
information for parsing is missing.
|
||||
"""
|
||||
|
||||
print("---RAG STARTED---")
|
||||
try:
|
||||
user_input = state["user_input"]
|
||||
document = state["document"]
|
||||
except KeyError as e:
|
||||
print(f"Error: {e} not found in state.")
|
||||
raise
|
||||
print(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
parsed_document = state.get("parsed_document", None)
|
||||
|
||||
if parsed_document:
|
||||
chunks = parsed_document
|
||||
else:
|
||||
print("Parsed document not found. Using original document.")
|
||||
chunks = document
|
||||
user_prompt = input_data[0]
|
||||
doc = input_data[1]
|
||||
|
||||
chunked_docs = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
for i, chunk in enumerate(doc):
|
||||
doc = Document(
|
||||
page_content=chunk,
|
||||
metadata={
|
||||
@ -85,7 +80,7 @@ class RAGNode(BaseNode):
|
||||
|
||||
print("---UPDATED CHUNKS METADATA---")
|
||||
|
||||
openai_key = self.llm.openai_api_key
|
||||
openai_key = self.llm_model.openai_api_key
|
||||
retriever = FAISS.from_documents(chunked_docs,
|
||||
OpenAIEmbeddings(api_key=openai_key)).as_retriever()
|
||||
# could be any embedding of your choice
|
||||
@ -108,9 +103,9 @@ class RAGNode(BaseNode):
|
||||
# )
|
||||
|
||||
compressed_docs = compression_retriever.get_relevant_documents(
|
||||
user_input)
|
||||
user_prompt)
|
||||
|
||||
print("---TOKENS COMPRESSED AND VECTOR STORED---")
|
||||
|
||||
state.update({"relevant_chunks": compressed_docs})
|
||||
state.update({self.output[0]: compressed_docs})
|
||||
return state
|
||||
|
||||
@ -1,4 +1,19 @@
|
||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.models import OpenAI
|
||||
load_dotenv()
|
||||
|
||||
# Define the configuration for the language model
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
|
||||
llm_config = {
|
||||
"api_key": openai_key,
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"temperature": 0,
|
||||
"streaming": True
|
||||
}
|
||||
llm_model = OpenAI(llm_config)
|
||||
|
||||
state = {
|
||||
"user_prompt": "List me all the projects",
|
||||
@ -18,4 +33,22 @@ parse_node = ParseNode(
|
||||
node_name="parse_document"
|
||||
)
|
||||
|
||||
parse_node.execute(updated_state)
|
||||
updated_state = parse_node.execute(updated_state)
|
||||
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
model_config={"llm_model": llm_model},
|
||||
node_name="rag_node"
|
||||
)
|
||||
|
||||
updated_state = rag_node.execute(updated_state)
|
||||
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
model_config={"llm_model": llm_model},
|
||||
node_name="generate_answer"
|
||||
)
|
||||
|
||||
print(generate_answer_node.execute(updated_state))
|
||||
Loading…
Reference in New Issue
Block a user