removed rag node

This commit is contained in:
Marco Vinciguerra 2024-07-14 10:24:38 +02:00
parent d3e63d91be
commit a7249685cb
11 changed files with 21 additions and 91 deletions

View File

@ -10,7 +10,6 @@ from .abstract_graph import AbstractGraph
from ..nodes import (
FetchNode,
RAGNode,
GenerateAnswerCSVNode
)
@ -37,14 +36,7 @@ class CSVScraperGraph(AbstractGraph):
input="csv | csv_dir",
output=["doc"],
)
rag_node = RAGNode(
input="user_prompt & doc",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model,
}
)
generate_answer_node = GenerateAnswerCSVNode(
input="user_prompt & (relevant_chunks | doc)",
output=["answer"],
@ -58,12 +50,10 @@ class CSVScraperGraph(AbstractGraph):
return BaseGraph(
nodes=[
fetch_node,
rag_node,
generate_answer_node,
],
edges=[
(fetch_node, rag_node),
(rag_node, generate_answer_node)
(fetch_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__

View File

@ -10,7 +10,6 @@ from .abstract_graph import AbstractGraph
from ..nodes import (
FetchNode,
RAGNode,
GenerateAnswerNode
)
@ -62,14 +61,7 @@ class JSONScraperGraph(AbstractGraph):
input="json | json_dir",
output=["doc", "link_urls", "img_urls"],
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
@ -83,12 +75,10 @@ class JSONScraperGraph(AbstractGraph):
return BaseGraph(
nodes=[
fetch_node,
rag_node,
generate_answer_node,
],
edges=[
(fetch_node, rag_node),
(rag_node, generate_answer_node)
(fetch_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__

View File

@ -3,7 +3,7 @@ import logging
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from ..nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
from ..nodes import FetchNode, ParseNode, GenerateAnswerNode
class MDScraperGraph(AbstractGraph):
"""
@ -63,14 +63,6 @@ class MDScraperGraph(AbstractGraph):
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
@ -86,13 +78,11 @@ class MDScraperGraph(AbstractGraph):
nodes=[
fetch_node,
parse_node,
rag_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
(parse_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__

View File

@ -12,7 +12,6 @@ from ..nodes import (
FetchNode,
ParseNode,
ImageToTextNode,
RAGNode,
GenerateAnswerOmniNode
)
@ -89,14 +88,7 @@ class OmniScraperGraph(AbstractGraph):
"max_images": self.max_images
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_omni_node = GenerateAnswerOmniNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc) & img_desc",
output=["answer"],
@ -112,14 +104,12 @@ class OmniScraperGraph(AbstractGraph):
fetch_node,
parse_node,
image_to_text_node,
rag_node,
generate_answer_omni_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, image_to_text_node),
(image_to_text_node, rag_node),
(rag_node, generate_answer_omni_node)
(image_to_text_node, generate_answer_omni_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
@ -136,4 +126,4 @@ class OmniScraperGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -12,7 +12,6 @@ from .abstract_graph import AbstractGraph
from ..nodes import (
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerPDFNode
)
@ -76,14 +75,6 @@ class PDFScraperGraph(AbstractGraph):
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node_pdf = GenerateAnswerPDFNode(
input="user_prompt & (relevant_chunks | doc)",
output=["answer"],
@ -98,13 +89,11 @@ class PDFScraperGraph(AbstractGraph):
nodes=[
fetch_node,
parse_node,
rag_node,
generate_answer_node_pdf,
],
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node_pdf)
(parse_node, generate_answer_node_pdf)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__

View File

@ -78,14 +78,7 @@ class SmartScraperGraph(AbstractGraph):
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
@ -100,13 +93,11 @@ class SmartScraperGraph(AbstractGraph):
nodes=[
fetch_node,
parse_node,
rag_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
(parse_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__

View File

@ -10,7 +10,6 @@ from .abstract_graph import AbstractGraph
from ..nodes import (
FetchNode,
RAGNode,
GenerateAnswerNode
)
@ -64,14 +63,7 @@ class XMLScraperGraph(AbstractGraph):
input="xml | xml_dir",
output=["doc", "link_urls", "img_urls"]
)
rag_node = RAGNode(
input="user_prompt & doc",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | doc)",
output=["answer"],
@ -85,12 +77,10 @@ class XMLScraperGraph(AbstractGraph):
return BaseGraph(
nodes=[
fetch_node,
rag_node,
generate_answer_node,
],
edges=[
(fetch_node, rag_node),
(rag_node, generate_answer_node)
(fetch_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__

View File

@ -125,7 +125,7 @@ class GenerateAnswerCSVNode(BaseNode):
template=template_no_chunks_csv_prompt,
input_variables=["question"],
partial_variables={
"context": chunk.page_content,
"context": chunk,
"format_instructions": format_instructions,
},
)
@ -137,7 +137,7 @@ class GenerateAnswerCSVNode(BaseNode):
template=template_chunks_csv_prompt,
input_variables=["question"],
partial_variables={
"context": chunk.page_content,
"context": chunk,
"chunk_id": i + 1,
"format_instructions": format_instructions,
},

View File

@ -115,7 +115,7 @@ class GenerateAnswerNode(BaseNode):
prompt = PromptTemplate(
template=template_no_chunks_prompt,
input_variables=["question"],
partial_variables={"context": chunk.page_content,
partial_variables={"context": chunk,
"format_instructions": format_instructions})
chain = prompt | self.llm_model | output_parser
answer = chain.invoke({"question": user_prompt})
@ -124,7 +124,7 @@ class GenerateAnswerNode(BaseNode):
prompt = PromptTemplate(
template=template_chunks_prompt,
input_variables=["question"],
partial_variables={"context": chunk.page_content,
partial_variables={"context": chunk,
"chunk_id": i + 1,
"format_instructions": format_instructions})
# Dynamically name the chains based on their index

View File

@ -110,7 +110,7 @@ class GenerateAnswerOmniNode(BaseNode):
template=template_no_chunk_omni_prompt,
input_variables=["question"],
partial_variables={
"context": chunk.page_content,
"context": chunk,
"format_instructions": format_instructions,
"img_desc": imag_desc,
},
@ -123,7 +123,7 @@ class GenerateAnswerOmniNode(BaseNode):
template=template_chunks_omni_prompt,
input_variables=["question"],
partial_variables={
"context": chunk.page_content,
"context": chunk,
"chunk_id": i + 1,
"format_instructions": format_instructions,
},

View File

@ -124,7 +124,7 @@ class GenerateAnswerPDFNode(BaseNode):
template=template_no_chunks_pdf_prompt,
input_variables=["question"],
partial_variables={
"context":chunk.page_content,
"context":chunk,
"format_instructions": format_instructions,
},
)