Scrapegraph-ai/scrapegraphai/nodes/generate_answer_node.py
2024-03-05 18:07:21 +01:00

145 lines
5.6 KiB
Python

"""
Module for generating the answer node
"""
# Imports from third-party libraries
from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
# Imports from the library
from .base_node import BaseNode
from langchain.text_splitter import RecursiveCharacterTextSplitter
class GenerateAnswerNode(BaseNode):
    """
    A node that generates an answer using a language model (LLM) based on the user's input
    and the content extracted from a webpage.

    The content is processed in two stages:

    1. Map: every chunk of content gets its own prompt -> LLM -> JSON-parser chain, and
       all of those chains are run together through a ``RunnableParallel``.
    2. Merge: when there is more than one chunk, the per-chunk answers are passed to a
       second prompt asking the LLM to combine them into a single answer.

    Attributes:
        llm: The language model client (e.g. ChatOpenAI) used for generating answers.
        node_name (str): The unique identifier name for the node (forwarded to ``BaseNode``).
        node_type (str): The node type, always "node" (forwarded to ``BaseNode``).

    Args:
        llm: An instance of the language model client used for generating answers.
        node_name (str): The unique identifier name for the node.

    Methods:
        execute(state): Processes the input and content from the state to generate an
            answer, updating the state with the generated answer under the 'answer' key.
    """

    def __init__(self, llm, node_name: str):
        """
        Initializes the GenerateAnswerNode with a language model client and a node name.

        Args:
            llm: The language model client (e.g. ChatOpenAI) that the prompt chains pipe into.
            node_name (str): Name of the node.
        """
        super().__init__(node_name, "node")
        self.llm = llm

    def execute(self, state: dict) -> dict:
        """
        Generates an answer from the user's question and the scraped content.

        Each content chunk is answered independently by the LLM (in parallel). If there
        is more than one chunk, the partial answers are merged by a second LLM call.

        Args:
            state (dict): The current state of the graph. It must contain 'user_input'
                and 'document_chunks'. It may also contain 'relevant_chunks' or
                'parsed_document'. When present and non-empty, these are used as the
                content instead of 'document_chunks', in that order of preference.

        Returns:
            dict: The same state, updated with an 'answer' key holding the output of
            ``JsonOutputParser`` (the parsed JSON returned by the LLM).

        Raises:
            KeyError: If 'user_input' or 'document_chunks' is not found in the state.
        """
        print("---GENERATING ANSWER---")
        try:
            user_input = state["user_input"]
            document = state["document_chunks"]
        except KeyError as e:
            print(f"Error: {e} not found in state.")
            raise

        # Prefer the most refined content available: RAG-selected chunks, then the
        # parsed document, falling back to the raw document chunks.
        context = (
            state.get("relevant_chunks")
            or state.get("parsed_document")
            or document
        )
        # A bare string would otherwise be iterated character by character, creating
        # one chain (and one LLM call) per character.
        if isinstance(context, str):
            context = [context]

        output_parser = JsonOutputParser()
        format_instructions = output_parser.get_format_instructions()

        template_chunks = """You are a website scraper and you have just scraped the
following content from a website.
You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n
The website is big so I am giving you one chunk at the time to be merged later with the other chunks.\n
Content of {chunk_id}: {context}
Question: {question}
"""
        template_merge = """You are a website scraper and you have just scraped the
following content from a website.
You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n
You have scraped many chunks since the website is big and now you are asked to merge them into a single answer without repetitions (if there are any).\n
Content to merge: {context}
Question: {question}
"""

        # Map stage: build one prompt -> LLM -> parser chain per chunk, keyed
        # "chunk1", "chunk2", ... so the parallel runner returns a dict of answers.
        chains_dict = {}
        for i, chunk in enumerate(tqdm(context, desc="Processing chunks")):
            prompt = PromptTemplate(
                template=template_chunks,
                input_variables=["question"],
                partial_variables={
                    "context": chunk,
                    "chunk_id": i + 1,
                    "format_instructions": format_instructions,
                },
            )
            chains_dict[f"chunk{i + 1}"] = prompt | self.llm | output_parser

        # The chains must actually be invoked; stringifying chains_dict would only
        # yield the repr of the Runnable objects, not any LLM output.
        if chains_dict:
            answer_map = RunnableParallel(**chains_dict).invoke(
                {"question": user_input})
        else:
            answer_map = {}

        if len(answer_map) == 1:
            # A single chunk's answer is already complete; a merge call would be redundant.
            answer = next(iter(answer_map.values()))
        else:
            # Merge stage: one LLM call combines all partial answers. The parsed
            # outputs are usually dicts/lists (unhashable), so they cannot be
            # deduplicated with set() — deduplication is delegated to the LLM.
            merge_prompt = PromptTemplate(
                template=template_merge,
                input_variables=["context", "question"],
                partial_variables={"format_instructions": format_instructions},
            )
            merge_chain = merge_prompt | self.llm | output_parser
            answer = merge_chain.invoke(
                {"context": answer_map, "question": user_input})

        state.update({"answer": answer})
        return state