From a626401f8a8d883ce563c8ed06625d15883cdd2c Mon Sep 17 00:00:00 2001 From: Perinim Date: Wed, 21 Feb 2024 18:43:38 +0100 Subject: [PATCH] merged multiple chunks answer --- scrapegraphai/nodes/generate_answer_node.py | 42 +++++++++++++-------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 71f80252..164726eb 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -79,9 +79,19 @@ class GenerateAnswerNode(BaseNode): output_parser = JsonOutputParser() format_instructions = output_parser.get_format_instructions() - template = """You are a website scraper and you have just scraped the + template_chunks = """You are a website scraper and you have just scraped the following content from a website. - You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n The content is as follows: {context} + You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n + The website is big so I am giving you one chunk at the time to be merged later with the other chunks.\n + Content of {chunk_id}: {context} + Question: {question} + """ + + template_merge = """You are a website scraper and you have just scraped the + following content from a website. + You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n + You have scraped many chunks since the website is big and now you are asked to merge them into a single answer without repetitions (if there are any).\n + Content to merge: {context} Question: {question} """ @@ -89,28 +99,28 @@ class GenerateAnswerNode(BaseNode): for i, chunk in enumerate(context): prompt = PromptTemplate( - template=template, + template=template_chunks, input_variables=["question"], - partial_variables={"format_instructions": format_instructions, "context": chunk.page_content}, + partial_variables={"context": chunk.page_content, "chunk_id": i + 1, "format_instructions": format_instructions}, ) # Dynamically name the chains based on their index - chain_name = f"chunk{i}" + chain_name = f"chunk{i+1}" chains_dict[chain_name] = prompt | self.llm | output_parser # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel - map_chain = RunnableParallel(**chains_dict) - # schema_prompt = PromptTemplate( - # template=template, - # input_variables=["context", "question"], - # partial_variables={"format_instructions": format_instructions}, - # ) - # schema_chain = schema_prompt | self.llm | output_parser - # answer = schema_chain.invoke( - # {"context": context, "question": user_input}) - + map_chain = RunnableParallel(**chains_dict) # Chain - answer = map_chain.invoke({"question": user_input}) + answer_map = map_chain.invoke({"question": user_input}) + # Merge the answers from the chunks + merge_prompt = PromptTemplate( + template=template_merge, + input_variables=["context", "question"], + partial_variables={"format_instructions": format_instructions}, + ) + merge_chain = merge_prompt | self.llm | output_parser + answer = merge_chain.invoke( + {"context": answer_map, "question": user_input}) # Update the state with the generated answer state.update({"answer": answer})