diff --git a/scrapegraphai/nodes/base_node.py b/scrapegraphai/nodes/base_node.py index 43e43fa4..f2f1001b 100644 --- a/scrapegraphai/nodes/base_node.py +++ b/scrapegraphai/nodes/base_node.py @@ -39,7 +39,7 @@ class BaseNode(ABC): raised to indicate the incorrect usage. """ - def __init__(self, node_name: str, node_type: str, input: str, output: List[str], min_input_len: int = 1, model = None): + def __init__(self, node_name: str, node_type: str, input: str, output: List[str], min_input_len: int = 1, model_config: Optional[dict] = None): """ Initialize the node with a unique identifier and a specified node type. @@ -54,7 +54,7 @@ class BaseNode(ABC): self.input = input self.output = output self.min_input_len = min_input_len - self.model = model + self.model_config = model_config if node_type not in ["node", "conditional_node"]: raise ValueError( diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 3524e187..c1f8d291 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -11,7 +11,7 @@ from langchain_core.runnables import RunnableParallel # Imports from the library from .base_node import BaseNode - +from typing import List class GenerateAnswerNode(BaseNode): """ @@ -38,17 +38,17 @@ class GenerateAnswerNode(BaseNode): updating the state with the generated answer under the 'answer' key. """ - def __init__(self, llm, node_name: str): + def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"): """ Initializes the GenerateAnswerNode with a language model client and a node name. Args: llm (OpenAIImageToText): An instance of the OpenAIImageToText class. node_name (str): name of the node """ - super().__init__(node_name, "node") - self.llm = llm + super().__init__(node_name, "node", input, output, 2, model_config) + self.llm_model = model_config["llm_model"] - def execute(self, state: dict) -> dict: + def execute(self, state): """ Generates an answer by constructing a prompt from the user's input and the scraped content, querying the language model, and parsing its response. @@ -67,23 +67,16 @@ class GenerateAnswerNode(BaseNode): that the necessary information for generating an answer is missing. """ - print("---GENERATING ANSWER---") - try: - user_input = state["user_input"] - document = state["document"] - except KeyError as e: - print(f"Error: {e} not found in state.") - raise + print(f"--- Executing {self.node_name} Node ---") + + # Interpret input keys based on the provided input expression + input_keys = self.get_input_keys(state) + + # Fetching data from the state based on the input keys + input_data = [state[key] for key in input_keys] - parsed_document = state.get("parsed_document", None) - relevant_chunks = state.get("relevant_chunks", None) - - if relevant_chunks: - context = relevant_chunks - elif parsed_document: - context = parsed_document - else: - context = document + user_prompt = input_data[0] + doc = input_data[1] output_parser = JsonOutputParser() format_instructions = output_parser.get_format_instructions() @@ -106,7 +99,7 @@ class GenerateAnswerNode(BaseNode): chains_dict = {} # Use tqdm to add progress bar - for i, chunk in enumerate(tqdm(context, desc="Processing chunks")): + for i, chunk in enumerate(tqdm(doc, desc="Processing chunks")): prompt = PromptTemplate( template=template_chunks, input_variables=["question"], @@ -115,12 +108,12 @@ class GenerateAnswerNode(BaseNode): ) # Dynamically name the chains based on their index chain_name = f"chunk{i+1}" - chains_dict[chain_name] = prompt | self.llm | output_parser + chains_dict[chain_name] = prompt | self.llm_model | output_parser # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel map_chain = RunnableParallel(**chains_dict) # Chain - answer_map = map_chain.invoke({"question": user_input}) + answer_map = map_chain.invoke({"question": user_prompt}) # Merge the answers from the chunks merge_prompt = PromptTemplate( @@ -128,10 +121,10 @@ class GenerateAnswerNode(BaseNode): input_variables=["context", "question"], partial_variables={"format_instructions": format_instructions}, ) - merge_chain = merge_prompt | self.llm | output_parser + merge_chain = merge_prompt | self.llm_model | output_parser answer = merge_chain.invoke( - {"context": answer_map, "question": user_input}) + {"context": answer_map, "question": user_prompt}) # Update the state with the generated answer - state.update({"answer": answer}) + state.update({self.output[0]: answer}) return state \ No newline at end of file diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py index c8ac24b4..62e0f0f9 100644 --- a/scrapegraphai/nodes/parse_node.py +++ b/scrapegraphai/nodes/parse_node.py @@ -77,6 +77,8 @@ class ParseNode(BaseNode): docs_transformed = Html2TextTransformer( ).transform_documents(input_data[0])[0] + # TODO: keep the metadata + chunks = text_splitter.split_text(docs_transformed.page_content) state.update({self.output[0]: chunks}) diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index 2b1e6898..b37684b9 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -8,7 +8,7 @@ from langchain.retrievers.document_compressors import EmbeddingsFilter, Document from langchain_community.document_transformers import EmbeddingsRedundantFilter from langchain_community.vectorstores import FAISS from langchain_openai import OpenAIEmbeddings - +from typing import List from .base_node import BaseNode @@ -33,12 +33,12 @@ class RAGNode(BaseNode): the specified tags, if provided, and updates the state with the parsed content. """ - def __init__(self, llm, node_name="RagNode"): + def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAGNode"): """ Initializes the ParseHTMLNode with a node name. """ - super().__init__(node_name, "node") - self.llm = llm + super().__init__(node_name, "node", input, output, 2, model_config) + self.llm_model = model_config["llm_model"] def execute(self, state): """ @@ -56,25 +56,20 @@ class RAGNode(BaseNode): information for parsing is missing. """ - print("---RAG STARTED---") - try: - user_input = state["user_input"] - document = state["document"] - except KeyError as e: - print(f"Error: {e} not found in state.") - raise + print(f"--- Executing {self.node_name} Node ---") + + # Interpret input keys based on the provided input expression + input_keys = self.get_input_keys(state) + + # Fetching data from the state based on the input keys + input_data = [state[key] for key in input_keys] - parsed_document = state.get("parsed_document", None) - - if parsed_document: - chunks = parsed_document - else: - print("Parsed document not found. Using original document.") - chunks = document + user_prompt = input_data[0] + doc = input_data[1] chunked_docs = [] - for i, chunk in enumerate(chunks): + for i, chunk in enumerate(doc): doc = Document( page_content=chunk, metadata={ @@ -85,7 +80,7 @@ class RAGNode(BaseNode): print("---UPDATED CHUNKS METADATA---") - openai_key = self.llm.openai_api_key + openai_key = self.llm_model.openai_api_key retriever = FAISS.from_documents(chunked_docs, OpenAIEmbeddings(api_key=openai_key)).as_retriever() # could be any embedding of your choice @@ -108,9 +103,9 @@ class RAGNode(BaseNode): # ) compressed_docs = compression_retriever.get_relevant_documents( - user_input) + user_prompt) print("---TOKENS COMPRESSED AND VECTOR STORED---") - state.update({"relevant_chunks": compressed_docs}) + state.update({self.output[0]: compressed_docs}) return state diff --git a/scrapegraphai/utils/test_node.py b/scrapegraphai/utils/test_node.py index a0f20763..fb6fcd82 100644 --- a/scrapegraphai/utils/test_node.py +++ b/scrapegraphai/utils/test_node.py @@ -1,4 +1,19 @@ from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode +import os +from dotenv import load_dotenv +from scrapegraphai.models import OpenAI +load_dotenv() + +# Define the configuration for the language model +openai_key = os.getenv("OPENAI_APIKEY") + +llm_config = { + "api_key": openai_key, + "model_name": "gpt-3.5-turbo", + "temperature": 0, + "streaming": True +} +llm_model = OpenAI(llm_config) state = { "user_prompt": "List me all the projects", @@ -18,4 +33,22 @@ parse_node = ParseNode( node_name="parse_document" ) -parse_node.execute(updated_state) \ No newline at end of file +updated_state = parse_node.execute(updated_state) + +rag_node = RAGNode( + input="user_prompt & (parsed_doc | doc)", + output=["relevant_chunks"], + model_config={"llm_model": llm_model}, + node_name="rag_node" + ) + +updated_state = rag_node.execute(updated_state) + +generate_answer_node = GenerateAnswerNode( + input="user_prompt & (relevant_chunks | parsed_doc | doc)", + output=["answer"], + model_config={"llm_model": llm_model}, + node_name="generate_answer" + ) + +print(generate_answer_node.execute(updated_state)) \ No newline at end of file