diff --git a/scrapegraphai/nodes/description_node.py b/scrapegraphai/nodes/description_node.py index 683aabe1..6175133a 100644 --- a/scrapegraphai/nodes/description_node.py +++ b/scrapegraphai/nodes/description_node.py @@ -67,7 +67,11 @@ class DescriptionNode(BaseNode): temp_res = {} for i, (summary, document) in enumerate(zip(batch_results, docs)): - temp_res[summary] = document + temp_res[summary] = { + "id": i, + "summary": summary, + "document": document + } state["descriptions"] = temp_res diff --git a/scrapegraphai/nodes/generate_answer_node_k_level.py b/scrapegraphai/nodes/generate_answer_node_k_level.py index 24235e71..10977617 100644 --- a/scrapegraphai/nodes/generate_answer_node_k_level.py +++ b/scrapegraphai/nodes/generate_answer_node_k_level.py @@ -102,6 +102,7 @@ class GenerateAnswerNodeKLevel(BaseNode): query_text=state["question"] ) + ## TODO: from the id get the data results_db = [elem for elem in state[answer_db]] chains_dict = {} diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index c137b987..cac41a99 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -49,13 +49,13 @@ class RAGNode(BaseNode): else: raise ValueError("client_type provided not correct") - docs = [elem for elem in state.get("descriptions").keys()] - metadata = [] + docs = [elem.get("summary") for elem in state.get("descriptions", {})] + ids = [elem.get("id") for elem in state.get("descriptions", {})] client.add( collection_name="vectorial_collection", documents=docs, - metadata=metadata, + ids=ids ) state["vectorial_db"] = client