fix: refactoring of fetch_node

This commit is contained in:
Matteo Vedovati 2024-08-07 11:56:10 +02:00
parent 82e63213ae
commit 29ad140fa3
5 changed files with 228 additions and 75 deletions

6
examples/local_models/package-lock.json generated Normal file
View File

@ -0,0 +1,6 @@
{
"name": "local_models",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

View File

@ -0,0 +1 @@
{}

View File

@ -6,6 +6,8 @@
# features: [] # features: []
# all-features: false # all-features: false
# with-sources: false # with-sources: false
# generate-hashes: false
# universal: false
-e file:. -e file:.
aiofiles==24.1.0 aiofiles==24.1.0
@ -110,6 +112,7 @@ filelock==3.15.4
# via huggingface-hub # via huggingface-hub
# via torch # via torch
# via transformers # via transformers
# via triton
fireworks-ai==0.14.0 fireworks-ai==0.14.0
# via langchain-fireworks # via langchain-fireworks
fonttools==4.53.1 fonttools==4.53.1
@ -185,6 +188,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.9.0 groq==0.9.0
# via langchain-groq # via langchain-groq
grpc-google-iam-v1==0.13.1 grpc-google-iam-v1==0.13.1
@ -353,6 +357,34 @@ numpy==1.26.4
# via shapely # via shapely
# via streamlit # via streamlit
# via transformers # via transformers
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.6.20
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
openai==1.37.0 openai==1.37.0
# via burr # via burr
# via langchain-fireworks # via langchain-fireworks
@ -593,6 +625,8 @@ tqdm==4.66.4
transformers==4.43.3 transformers==4.43.3
# via langchain-huggingface # via langchain-huggingface
# via sentence-transformers # via sentence-transformers
triton==2.2.0
# via torch
typer==0.12.3 typer==0.12.3
# via fastapi-cli # via fastapi-cli
typing-extensions==4.12.2 typing-extensions==4.12.2
@ -635,6 +669,8 @@ uvicorn==0.30.3
# via fastapi # via fastapi
uvloop==0.19.0 uvloop==0.19.0
# via uvicorn # via uvicorn
watchdog==4.0.1
# via streamlit
watchfiles==0.22.0 watchfiles==0.22.0
# via uvicorn # via uvicorn
websockets==12.0 websockets==12.0

View File

@ -6,6 +6,8 @@
# features: [] # features: []
# all-features: false # all-features: false
# with-sources: false # with-sources: false
# generate-hashes: false
# universal: false
-e file:. -e file:.
aiohttp==3.9.5 aiohttp==3.9.5
@ -67,6 +69,7 @@ filelock==3.15.4
# via huggingface-hub # via huggingface-hub
# via torch # via torch
# via transformers # via transformers
# via triton
fireworks-ai==0.14.0 fireworks-ai==0.14.0
# via langchain-fireworks # via langchain-fireworks
free-proxy==1.1.1 free-proxy==1.1.1
@ -133,6 +136,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.9.0 groq==0.9.0
# via langchain-groq # via langchain-groq
grpc-google-iam-v1==0.13.1 grpc-google-iam-v1==0.13.1
@ -258,6 +262,34 @@ numpy==1.26.4
# via sentence-transformers # via sentence-transformers
# via shapely # via shapely
# via transformers # via transformers
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.6.20
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
openai==1.37.0 openai==1.37.0
# via langchain-fireworks # via langchain-fireworks
# via langchain-openai # via langchain-openai
@ -408,6 +440,8 @@ tqdm==4.66.4
transformers==4.43.3 transformers==4.43.3
# via langchain-huggingface # via langchain-huggingface
# via sentence-transformers # via sentence-transformers
triton==2.2.0
# via torch
typing-extensions==4.12.2 typing-extensions==4.12.2
# via anthropic # via anthropic
# via anyio # via anyio

View File

@ -102,81 +102,150 @@ class FetchNode(BaseNode):
input_data = [state[key] for key in input_keys] input_data = [state[key] for key in input_keys]
source = input_data[0] source = input_data[0]
if ( input_type = input_keys[0]
input_keys[0] == "json_dir"
or input_keys[0] == "xml_dir" handlers = {
or input_keys[0] == "csv_dir" "json_dir": self.handle_directory,
or input_keys[0] == "pdf_dir" "xml_dir": self.handle_directory,
or input_keys[0] == "md_dir" "csv_dir": self.handle_directory,
): "pdf_dir": self.handle_directory,
compressed_document = [ "md_dir": self.handle_directory,
source "pdf": self.handle_file,
] "csv": self.handle_file,
"json": self.handle_file,
state.update({self.output[0]: compressed_document}) "xml": self.handle_file,
return state "md": self.handle_file,
# handling pdf }
elif input_keys[0] == "pdf":
loader = PyPDFLoader(source) if input_type in handlers:
compressed_document = loader.load() return handlers[input_type](state, input_type, source)
state.update({self.output[0]: compressed_document})
return state
elif input_keys[0] == "csv":
compressed_document = [
Document(
page_content=str(pd.read_csv(source)), metadata={"source": "csv"}
)
]
state.update({self.output[0]: compressed_document})
return state
elif input_keys[0] == "json":
f = open(source, encoding="utf-8")
compressed_document = [
Document(page_content=str(json.load(f)), metadata={"source": "json"})
]
state.update({self.output[0]: compressed_document})
return state
elif input_keys[0] == "xml":
with open(source, "r", encoding="utf-8") as f:
data = f.read()
compressed_document = [
Document(page_content=data, metadata={"source": "xml"})
]
state.update({self.output[0]: compressed_document})
return state
elif input_keys[0] == "md":
with open(source, "r", encoding="utf-8") as f:
data = f.read()
compressed_document = [
Document(page_content=data, metadata={"source": "md"})
]
state.update({self.output[0]: compressed_document})
return state
elif self.input == "pdf_dir": elif self.input == "pdf_dir":
pass pass
elif not source.startswith("http"): elif not source.startswith("http"):
self.logger.info(f"--- (Fetching HTML from: {source}) ---") return self.handle_local_source(state, source)
if not source.strip(): else:
raise ValueError("No HTML body content found in the local source.") return self.handle_web_source(state, source)
def handle_directory(self, state, input_type, source):
"""
Handles the directory by compressing the source document and updating the state.
Parameters:
state (dict): The current state of the graph.
input_type (str): The type of input being processed.
source (str): The source document to be compressed.
Returns:
dict: The updated state with the compressed document.
"""
compressed_document = [
source
]
state.update({self.output[0]: compressed_document})
return state
def handle_file(self, state, input_type, source):
"""
Loads the content of a file based on its input type.
Parameters:
state (dict): The current state of the graph.
input_type (str): The type of the input file (e.g., "pdf", "csv", "json", "xml", "md").
source (str): The path to the source file.
Returns:
dict: The updated state with the compressed document.
The function supports the following input types:
- "pdf": Uses PyPDFLoader to load the content of a PDF file.
- "csv": Reads the content of a CSV file using pandas and converts it to a string.
- "json": Loads the content of a JSON file.
- "xml": Reads the content of an XML file as a string.
- "md": Reads the content of a Markdown file as a string.
"""
compressed_document = self.load_file_content(source, input_type)
return self.update_state(state, compressed_document)
def load_file_content(self, source, input_type):
"""
Loads the content of a file based on its input type.
Parameters:
source (str): The path to the source file.
input_type (str): The type of the input file (e.g., "pdf", "csv", "json", "xml", "md").
Returns:
list: A list containing a Document object with the loaded content and metadata.
"""
if input_type == "pdf":
loader = PyPDFLoader(source)
return loader.load()
elif input_type == "csv":
return [Document(page_content=str(pd.read_csv(source)), metadata={"source": "csv"})]
elif input_type == "json":
with open(source, encoding="utf-8") as f:
return [Document(page_content=str(json.load(f)), metadata={"source": "json"})]
elif input_type == "xml" or input_type == "md":
with open(source, "r", encoding="utf-8") as f:
data = f.read()
return [Document(page_content=data, metadata={"source": input_type})]
def handle_local_source(self, state, source):
"""
Handles the local source by fetching HTML content, optionally converting it to Markdown,
and updating the state.
Parameters:
state (dict): The current state of the graph.
source (str): The HTML content from the local source.
Returns:
dict: The updated state with the processed content.
Raises:
ValueError: If the source is empty or contains only whitespace.
"""
self.logger.info(f"--- (Fetching HTML from: {source}) ---")
if not source.strip():
raise ValueError("No HTML body content found in the local source.")
parsed_content = source
if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator:
parsed_content = convert_to_md(source)
else:
parsed_content = source parsed_content = source
if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator: compressed_document = [
Document(page_content=parsed_content, metadata={"source": "local_dir"})
]
return self.update_state(state, compressed_document)
def handle_web_source(self, state, source):
"""
Handles the web source by fetching HTML content from a URL, optionally converting it to Markdown,
and updating the state.
parsed_content = convert_to_md(source) Parameters:
else: state (dict): The current state of the graph.
parsed_content = source source (str): The URL of the web source to fetch HTML content from.
compressed_document = [ Returns:
Document(page_content=parsed_content, metadata={"source": "local_dir"}) dict: The updated state with the processed content.
]
elif self.use_soup: Raises:
self.logger.info(f"--- (Fetching HTML from: {source}) ---") ValueError: If the fetched HTML content is empty or contains only whitespace.
"""
self.logger.info(f"--- (Fetching HTML from: {source}) ---")
if self.use_soup:
response = requests.get(source) response = requests.get(source)
if response.status_code == 200: if response.status_code == 200:
if not response.text.strip(): if not response.text.strip():
@ -194,9 +263,7 @@ class FetchNode(BaseNode):
self.logger.warning( self.logger.warning(
f"Failed to retrieve contents from the webpage at url: {source}" f"Failed to retrieve contents from the webpage at url: {source}"
) )
else: else:
self.logger.info(f"--- (Fetching HTML from: {source}) ---")
loader_kwargs = {} loader_kwargs = {}
if self.node_config is not None: if self.node_config is not None:
@ -219,15 +286,24 @@ class FetchNode(BaseNode):
if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator and not self.openai_md_enabled: if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator and not self.openai_md_enabled:
parsed_content = convert_to_md(document[0].page_content, input_data[0]) parsed_content = convert_to_md(document[0].page_content, input_data[0])
compressed_document = [ compressed_document = [
Document(page_content=parsed_content, metadata={"source": "html file"}) Document(page_content=parsed_content, metadata={"source": "html file"})
] ]
return self.update_state(state, compressed_document)
def update_state(self, state, compressed_document):
"""
Updates the state with the output data from the node.
state.update( Args:
{ state (dict): The current state of the graph.
self.output[0]: compressed_document, compressed_document (List[Document]): The compressed document content fetched
} by the node.
)
return state Returns:
dict: The updated state with the output data.
"""
state.update({self.output[0]: compressed_document,})
return state