refactored RagNode and GenerateAnswerNode

This commit is contained in:
Perinim 2024-03-17 16:48:58 +01:00
parent 3585cd81e5
commit 875a7cc4f0
5 changed files with 75 additions and 52 deletions

View File

@ -39,7 +39,7 @@ class BaseNode(ABC):
raised to indicate the incorrect usage.
"""
def __init__(self, node_name: str, node_type: str, input: str, output: List[str], min_input_len: int = 1, model = None):
def __init__(self, node_name: str, node_type: str, input: str, output: List[str], min_input_len: int = 1, model_config: Optional[dict] = None):
"""
Initialize the node with a unique identifier and a specified node type.
@ -54,7 +54,7 @@ class BaseNode(ABC):
self.input = input
self.output = output
self.min_input_len = min_input_len
self.model = model
self.model_config = model_config
if node_type not in ["node", "conditional_node"]:
raise ValueError(

View File

@ -11,7 +11,7 @@ from langchain_core.runnables import RunnableParallel
# Imports from the library
from .base_node import BaseNode
from typing import List
class GenerateAnswerNode(BaseNode):
"""
@ -38,17 +38,17 @@ class GenerateAnswerNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
def __init__(self, llm, node_name: str):
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"):
"""
Initializes the GenerateAnswerNode with a language model client and a node name.
Args:
llm (OpenAIImageToText): An instance of the OpenAIImageToText class.
node_name (str): name of the node
"""
super().__init__(node_name, "node")
self.llm = llm
super().__init__(node_name, "node", input, output, 2, model_config)
self.llm_model = model_config["llm_model"]
def execute(self, state: dict) -> dict:
def execute(self, state):
"""
Generates an answer by constructing a prompt from the user's input and the scraped
content, querying the language model, and parsing its response.
@ -67,23 +67,16 @@ class GenerateAnswerNode(BaseNode):
that the necessary information for generating an answer is missing.
"""
print("---GENERATING ANSWER---")
try:
user_input = state["user_input"]
document = state["document"]
except KeyError as e:
print(f"Error: {e} not found in state.")
raise
print(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
parsed_document = state.get("parsed_document", None)
relevant_chunks = state.get("relevant_chunks", None)
if relevant_chunks:
context = relevant_chunks
elif parsed_document:
context = parsed_document
else:
context = document
user_prompt = input_data[0]
doc = input_data[1]
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()
@ -106,7 +99,7 @@ class GenerateAnswerNode(BaseNode):
chains_dict = {}
# Use tqdm to add progress bar
for i, chunk in enumerate(tqdm(context, desc="Processing chunks")):
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks")):
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
@ -115,12 +108,12 @@ class GenerateAnswerNode(BaseNode):
)
# Dynamically name the chains based on their index
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm | output_parser
chains_dict[chain_name] = prompt | self.llm_model | output_parser
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
map_chain = RunnableParallel(**chains_dict)
# Chain
answer_map = map_chain.invoke({"question": user_input})
answer_map = map_chain.invoke({"question": user_prompt})
# Merge the answers from the chunks
merge_prompt = PromptTemplate(
@ -128,10 +121,10 @@ class GenerateAnswerNode(BaseNode):
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm | output_parser
merge_chain = merge_prompt | self.llm_model | output_parser
answer = merge_chain.invoke(
{"context": answer_map, "question": user_input})
{"context": answer_map, "question": user_prompt})
# Update the state with the generated answer
state.update({"answer": answer})
state.update({self.output[0]: answer})
return state

View File

@ -77,6 +77,8 @@ class ParseNode(BaseNode):
docs_transformed = Html2TextTransformer(
).transform_documents(input_data[0])[0]
# TODO: keep the metadata
chunks = text_splitter.split_text(docs_transformed.page_content)
state.update({self.output[0]: chunks})

View File

@ -8,7 +8,7 @@ from langchain.retrievers.document_compressors import EmbeddingsFilter, Document
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from typing import List
from .base_node import BaseNode
@ -33,12 +33,12 @@ class RAGNode(BaseNode):
the specified tags, if provided, and updates the state with the parsed content.
"""
def __init__(self, llm, node_name="RagNode"):
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAGNode"):
"""
Initializes the ParseHTMLNode with a node name.
"""
super().__init__(node_name, "node")
self.llm = llm
super().__init__(node_name, "node", input, output, 2, model_config)
self.llm_model = model_config["llm_model"]
def execute(self, state):
"""
@ -56,25 +56,20 @@ class RAGNode(BaseNode):
information for parsing is missing.
"""
print("---RAG STARTED---")
try:
user_input = state["user_input"]
document = state["document"]
except KeyError as e:
print(f"Error: {e} not found in state.")
raise
print(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
parsed_document = state.get("parsed_document", None)
if parsed_document:
chunks = parsed_document
else:
print("Parsed document not found. Using original document.")
chunks = document
user_prompt = input_data[0]
doc = input_data[1]
chunked_docs = []
for i, chunk in enumerate(chunks):
for i, chunk in enumerate(doc):
doc = Document(
page_content=chunk,
metadata={
@ -85,7 +80,7 @@ class RAGNode(BaseNode):
print("---UPDATED CHUNKS METADATA---")
openai_key = self.llm.openai_api_key
openai_key = self.llm_model.openai_api_key
retriever = FAISS.from_documents(chunked_docs,
OpenAIEmbeddings(api_key=openai_key)).as_retriever()
# could be any embedding of your choice
@ -108,9 +103,9 @@ class RAGNode(BaseNode):
# )
compressed_docs = compression_retriever.get_relevant_documents(
user_input)
user_prompt)
print("---TOKENS COMPRESSED AND VECTOR STORED---")
state.update({"relevant_chunks": compressed_docs})
state.update({self.output[0]: compressed_docs})
return state

View File

@ -1,4 +1,19 @@
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
import os
from dotenv import load_dotenv
from scrapegraphai.models import OpenAI
load_dotenv()
# Define the configuration for the language model
openai_key = os.getenv("OPENAI_APIKEY")
llm_config = {
"api_key": openai_key,
"model_name": "gpt-3.5-turbo",
"temperature": 0,
"streaming": True
}
llm_model = OpenAI(llm_config)
state = {
"user_prompt": "List me all the projects",
@ -18,4 +33,22 @@ parse_node = ParseNode(
node_name="parse_document"
)
parse_node.execute(updated_state)
updated_state = parse_node.execute(updated_state)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
model_config={"llm_model": llm_model},
node_name="rag_node"
)
updated_state = rag_node.execute(updated_state)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
model_config={"llm_model": llm_model},
node_name="generate_answer"
)
print(generate_answer_node.execute(updated_state))