dev: first graph implementation

This commit is contained in:
Marco Perini 2024-02-13 20:42:06 +01:00 committed by GitHub
parent cbf654906c
commit ea7616f9be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 263 additions and 0 deletions

5
yosoai/graph/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .base_graph import BaseGraph
from .conditional_node import ConditionalNode
from .get_probable_tags_node import GetProbableTagsNode
from .generate_answer_node import GenerateAnswerNode
from .parse_html_node import ParseHTMLNode

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,33 @@
from .conditional_node import ConditionalNode
class BaseGraph:
    """
    Minimal directed graph that executes nodes one after another.

    Nodes are indexed by their ``node_name``. A regular node has at most one
    outgoing edge, while a conditional node selects its successor itself by
    returning the successor's name from ``execute``.
    """

    def __init__(self, nodes, edges, entry_point):
        """
        Build the graph from node objects and the edges between them.

        Args:
            nodes: Iterable of node objects exposing ``node_name``,
                ``node_type`` and ``execute(state)``.
            edges: Iterable of ``(from_node, to_node)`` pairs of node objects.
            entry_point: The node object at which execution starts.
        """
        self.nodes = {node.node_name: node for node in nodes}
        self.edges = self._create_edges(edges)
        self.entry_point = entry_point.node_name

    def _create_edges(self, edges):
        """
        Map each source node name to its destination node name.

        Args:
            edges: Iterable of ``(from_node, to_node)`` pairs of node objects.

        Returns:
            dict: ``{from_node.node_name: to_node.node_name}``. When several
            pairs share the same source node, the last one wins.
        """
        return {
            from_node.node_name: to_node.node_name
            for from_node, to_node in edges
        }

    def execute(self, initial_state):
        """
        Run the graph from the entry point until a node has no successor.

        Args:
            initial_state: The state object handed to the first node.

        Returns:
            The state after the last executed node.

        Raises:
            KeyError: If a node name selected for execution is not part of
                the graph.
        """
        current_node_name = self.entry_point
        state = initial_state

        while current_node_name is not None:
            current_node = self.nodes[current_node_name]
            result = current_node.execute(state)

            if current_node.node_type == "conditional_node":
                # A conditional node returns the name of the next node, not a
                # state, so the state is carried over untouched.
                current_node_name = result
            else:
                # Regular nodes return the updated state. Honour that return
                # value so nodes that build a new state object are not
                # ignored; a None return means the node mutated in place.
                if result is not None:
                    state = result
                # No outgoing edge yields None, which ends the loop.
                current_node_name = self.edges.get(current_node_name)

        return state

28
yosoai/graph/base_node.py Normal file
View File

@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
class BaseNode(ABC):
    """
    Abstract base for every node that can be placed in a graph.

    Subclasses must implement ``execute``.
    """

    def __init__(self, node_name: str, node_type: str):
        """
        Store the node's name and validate its type.

        Args:
            node_name (str): The unique identifier name for the node.
            node_type (str): Either "node" or "conditional_node".

        Raises:
            ValueError: If node_type is neither "node" nor "conditional_node".
        """
        self.node_name = node_name

        if node_type in ("node", "conditional_node"):
            self.node_type = node_type
        else:
            raise ValueError(f"node_type must be 'node' or 'conditional_node', got '{node_type}'")

    @abstractmethod
    def execute(self, state):
        """
        Run this node's logic against the given state.

        Args:
            state: The current state of the graph.

        Returns:
            The updated state after executing this node.
        """

View File

@ -0,0 +1,37 @@
from .base_node import BaseNode
class ConditionalNode(BaseNode):
    """
    Node that routes execution to one of two successors, depending on whether
    a given key holds a value in ``state["keys"]``.
    """

    def __init__(self, key_name, next_nodes, node_name="ConditionalNode"):
        """
        Initialize the node with the key to check and the two candidate successors.

        Args:
            key_name (str): The name of the key to check in ``state["keys"]``.
            next_nodes (list): Exactly two node objects (each exposing
                ``node_name``). The first is chosen when the key holds a
                value, the second otherwise.
            node_name (str): The unique identifier name for this node.

        Raises:
            ValueError: If next_nodes does not contain exactly two elements.
        """
        super().__init__(node_name, "conditional_node")
        self.key_name = key_name
        if len(next_nodes) != 2:
            raise ValueError("next_nodes must contain exactly two elements.")
        self.next_nodes = next_nodes

    def execute(self, state):
        """
        Decide which node runs next based on the value stored under the key.

        The first successor is selected when the key is present in
        ``state["keys"]`` with a non-empty value. A missing key, a ``None``
        value, or an empty sized value (``""``, ``[]``, ``{}``) selects the
        second successor. Values that have no length (for example numbers)
        count as present.

        Args:
            state (dict): The current state of the graph.

        Returns:
            str: The ``node_name`` of the next node to execute.
        """
        value = state.get("keys", {}).get(self.key_name)

        try:
            has_value = len(value) > 0
        except TypeError:
            # None and other unsized values have no len(); only None is
            # treated as "nothing stored" instead of crashing the graph.
            has_value = value is not None

        chosen = self.next_nodes[0] if has_value else self.next_nodes[1]
        return chosen.node_name

View File

@ -0,0 +1,64 @@
from .base_node import BaseNode
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
class GenerateAnswerNode(BaseNode):
    """
    Node that asks the language model to answer the user's question using the
    best context available in the state.
    """

    def __init__(self, llm, node_name="GenerateAnswerNode"):
        """
        Args:
            llm: The language model placed in the middle of the prompt chain.
            node_name (str): The unique identifier name for this node.
        """
        super().__init__(node_name, "node")
        self.llm = llm

    def execute(self, state):
        """
        Generate an answer and store it under 'answer' in ``state["keys"]``.

        The context given to the model is the first non-empty value among
        'relevant_chunks', 'parsed_document' and 'document'.

        Args:
            state: The current state of the graph. ``state["keys"]`` must
                contain 'user_input' and 'document'; 'relevant_chunks' and
                'parsed_document' are optional.

        Returns:
            The same state, with 'answer' added to ``state["keys"]``.

        Raises:
            KeyError: If 'user_input' or 'document' is missing (re-raised
                after printing a message).
        """
        print("---GENERATE ANSWER---")

        try:
            question = state["keys"]["user_input"]
            raw_document = state["keys"]["document"]
        except KeyError as e:
            print(f"Error: {e} not found in state.")
            raise

        keys = state["keys"]
        parsed = keys.get("parsed_document", None)
        chunks = keys.get("relevant_chunks", None)

        # Most specific context first, falling back to the raw document.
        context = chunks or parsed or raw_document

        parser = JsonOutputParser()
        instructions = parser.get_format_instructions()

        # The literal's continuation lines stay at column 0 on purpose:
        # any leading whitespace would become part of the prompt text.
        template = """You are a website scraper and you have just scraped the following content from a website. You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n The content is as follows:
{context}
Question: {question}
"""

        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
            partial_variables={"format_instructions": instructions},
        )

        # prompt -> model -> JSON parser
        chain = prompt | self.llm | parser
        answer = chain.invoke({"context": context, "question": question})

        keys["answer"] = answer
        return state

View File

@ -0,0 +1,51 @@
from .base_node import BaseNode
from langchain.prompts import PromptTemplate
from langchain.output_parsers import CommaSeparatedListOutputParser
class GetProbableTagsNode(BaseNode):
    """
    Node that asks the language model which HTML tags are likely to contain
    the answer to the user's question for a given URL.
    """

    def __init__(self, llm, node_name="GetProbableTagsNode"):
        """
        Args:
            llm: The language model placed in the middle of the prompt chain.
            node_name (str): The unique identifier name for this node.
        """
        super().__init__(node_name, "node")
        self.llm = llm

    def execute(self, state):
        """
        Ask the model for probable HTML tags and store them under 'tags'.

        Args:
            state (dict): The current state of the graph. ``state["keys"]``
                must contain 'user_input' and 'url'.

        Returns:
            dict: The same state, with 'tags' added to ``state["keys"]``.

        Raises:
            KeyError: If 'user_input' or 'url' is missing (re-raised after
                printing a message).
        """
        print("---GET PROBABLE TAGS---")

        try:
            question = state["keys"]["user_input"]
            page_url = state["keys"]["url"]
        except KeyError as e:
            print(f"Error: {e} not found in state.")
            raise

        parser = CommaSeparatedListOutputParser()
        instructions = parser.get_format_instructions()

        # The literal's continuation lines stay at column 0 on purpose:
        # any leading whitespace would become part of the prompt text.
        template = """You are a website scraper that knows all the types of html tags. You are now asked to list all the html tags where you think you can find the information of the asked question.\n {format_instructions} \n The webpage is: {webpage} \n The asked question is the following:
{question}
"""

        prompt = PromptTemplate(
            template=template,
            input_variables=["question"],
            partial_variables={"format_instructions": instructions, "webpage": page_url},
        )

        # prompt -> model -> comma-separated list parser
        chain = prompt | self.llm | parser
        probable_tags = chain.invoke({"question": question})

        print("Possible tags: ", *probable_tags)

        state["keys"]["tags"] = probable_tags
        return state

View File

@ -0,0 +1,45 @@
from .base_node import BaseNode
from langchain_community.document_transformers import BeautifulSoupTransformer
class ParseHTMLNode(BaseNode):
    """
    Node that narrows the document down to the HTML tags listed in the state.
    """

    def __init__(self, node_name="ParseHTMLNode"):
        """
        Args:
            node_name (str): The unique identifier name for this node.
        """
        super().__init__(node_name, "node")

    def execute(self, state):
        """
        Parse the document with the tags found in the state, if any.

        When ``state["keys"]`` holds a non-empty 'tags' value, the document is
        passed through ``BeautifulSoupTransformer.transform_documents`` with
        those tags and the result is stored under 'parsed_document'. When
        there are no tags, the state is returned unchanged and
        'parsed_document' is not set.

        Args:
            state (dict): The current state of the graph. ``state["keys"]``
                must contain 'document' and may contain 'tags'.

        Returns:
            dict: The same state, with 'parsed_document' added to
            ``state["keys"]`` only when tags were available.

        Raises:
            KeyError: If 'document' is missing (re-raised after printing a
                message).
        """
        print("---PARSE HTML DOCUMENT---")

        try:
            document = state["keys"]["document"]
        except KeyError as e:
            print(f"Error: {e} not found in state.")
            raise

        tags = state["keys"].get("tags", None)

        # Nothing to filter on: leave the state untouched.
        if not tags:
            print("No specific tags provided; returning document as is.")
            return state

        transformer = BeautifulSoupTransformer()
        parsed_document = transformer.transform_documents(document, tags_to_extract=tags)
        print("Document parsed with specified tags.")

        state["keys"]["parsed_document"] = parsed_document
        return state