Merge remote-tracking branch 'origin/main' into refactoring_nodes

This commit is contained in:
Perinim 2024-03-18 11:07:42 +01:00
commit b5bdd4e729
12 changed files with 197 additions and 20 deletions

1
.gitignore vendored
View File

@ -27,3 +27,4 @@ venv/
*.mp3
*.sqlite
examples/graph_examples/ScrapeGraphAI_generated_graph
main.py

View File

@ -4,6 +4,7 @@
# 🕷️ ScrapeGraphAI: You Only Scrape Once
[![Downloads](https://static.pepy.tech/badge/scrapegraphai)](https://pepy.tech/project/scrapegraphai)
[![linting: pylint](https://img.shields.io/badge/linting-pylint-yellowgreen)](https://github.com/pylint-dev/pylint)
[![Pylint](https://github.com/VinciGit00/Scrapegraph-ai/actions/workflows/pylint.yml/badge.svg)](https://github.com/VinciGit00/Scrapegraph-ai/actions/workflows/pylint.yml)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@ -27,7 +28,7 @@ Official streamlit demo:
[![My Skills](https://skillicons.dev/icons?i=react)](https://scrapegraph-ai-demo.streamlit.app/)
Is it possible to try also the colab version
Try it directly on the web using Google Colab:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1sEZBonBMGP44CtO6GQTwAlL0BGJXjtfd?usp=sharing)
@ -36,7 +37,9 @@ Follow the procedure on the following link to setup your OpenAI API key: [link](
## 📖 Documentation
The documentation for ScrapeGraphAI can be found [here](https://scrapegraph-ai.readthedocs.io/en/latest/).
Behind this there is also the docusaurus documentation [here](https://scrapegraph-doc.onrender.com/).
Also check out the Docusaurus [documentation](https://scrapegraph-doc.onrender.com/).
## 💻 Usage
### Case 1: Extracting information using a prompt
@ -77,25 +80,18 @@ The output will be a dictionary with the extracted information, for example:
## 🤝 Contributing
Scrapegraph-ai is [MIT LICENSED](https://github.com/VinciGit00/Scrapegraph-ai/blob/main/LICENSE).
Contributions are welcome! Please check out the todos below, and feel free to open a pull request.
Feel free to contribute and join our Discord server to discuss improvements with us and give us suggestions!
For more information, please see the [contributing guidelines](https://github.com/VinciGit00/Scrapegraph-ai/blob/main/CONTRIBUTING.md).
Join our Discord server to discuss with us improvements and give us suggestions!
[![My Skills](https://skillicons.dev/icons?i=discord)](https://discord.gg/DujC7HG8)
You can also follow all the updates on LinkedIn!
[![My Skills](https://skillicons.dev/icons?i=linkedin)](https://www.linkedin.com/company/scrapegraphai/)
## ❤️ Contributors
[![Contributors](https://contrib.rocks/image?repo=VinciGit00/Scrapegraph-ai)](https://github.com/VinciGit00/Scrapegraph-ai/graphs/contributors)
### Citations
If you want to use our library for research purposes please quote us with the following reference
## 🎓 Citations
If you have used our library for research purposes, please cite us using the following reference:
```text
@misc{scrapegraph-ai,
author = {Marco Perini, Lorenzo Padoan, Marco Vinciguerra},
@ -120,7 +116,7 @@ If you want to use our library for research purposes please quote us with the fo
## 📜 License
ScrapeGraphAI is licensed under the Apache 2.0 License. See the [LICENSE](https://github.com/VinciGit00/Scrapegraph-ai/blob/main/LICENSE) file for more information.
ScrapeGraphAI is licensed under the MIT License. See the [LICENSE](https://github.com/VinciGit00/Scrapegraph-ai/blob/main/LICENSE) file for more information.
## Acknowledgements

View File

@ -0,0 +1,49 @@
"""
Example of custom graph using existing node using Gemini APIs
"""
import os
from dotenv import load_dotenv
from scrapegraphai.models import Gemini
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchHTMLNode, ParseNode, GenerateAnswerNodeVanilla
# Pull GOOGLE_API_KEY (and anything else) from a local .env file, if present.
load_dotenv()

# Configuration handed to the Gemini wrapper.
llm_config = {
    "api_key": os.getenv("GOOGLE_API_KEY"),
    "model_name": "gemini-pro",
}
model = Gemini(llm_config)

# Pipeline stages: fetch the page, parse/chunk it, then ask the LLM.
fetch_html_node = FetchHTMLNode("fetch_html")
parse_document_node = ParseNode(
    doc_type="html",
    chunks_size=4000,
    node_name="parse_document",
)
generate_answer_node = GenerateAnswerNodeVanilla(model, "generate_answer")

# Wire the stages together: fetch -> parse -> generate answer.
pipeline_nodes = {
    fetch_html_node,
    parse_document_node,
    generate_answer_node,
}
pipeline_edges = {
    (fetch_html_node, parse_document_node),
    (parse_document_node, generate_answer_node),
}
graph = BaseGraph(
    nodes=pipeline_nodes,
    edges=pipeline_edges,
    entry_point=fetch_html_node,
)

# Run the graph on the target page with the user's question.
inputs = {
    "user_input": "List me the projects with their description",
    "url": "https://perinim.github.io/projects/",
}
result = graph.execute(inputs)

# Show the generated answer, or a fallback message when the key is absent.
print(result.get("answer", "No answer found."))

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "scrapegraphai"
version = "0.0.8"
version = "0.0.9"
description = "A web scraping library based on LangChain which uses LLM and direct graph logic to create scraping pipelines."
authors = [
"Marco Vinciguerra <mvincig11@gmail.com>",

View File

@ -2,6 +2,7 @@ langchain==0.1.6
langchain_community==0.0.19
langchain_core==0.1.22
langchain_openai==0.0.5
langchain_google_genai==0.0.11
faiss-cpu==1.7.4
html2text==2020.1.16
beautifulsoup4==4.12.3

View File

@ -4,7 +4,7 @@ Module for making the graph building
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_extraction_chain
from ..models import OpenAI
from ..models import OpenAI, Gemini
from ..helpers import nodes_metadata, graph_schema
@ -67,6 +67,13 @@ class GraphBuilder:
llm_params = {**llm_defaults, **llm_config}
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
# select the model based on the model name
if "gpt-" in llm_params["model_name"]:
return OpenAI(llm_params)
elif "gemini" in llm_params["model_name"]:
return Gemini(llm_params)
return OpenAI(llm_params)
def _generate_nodes_description(self):

View File

@ -5,3 +5,4 @@
from .openai import OpenAI
from .openai_itt import OpenAIImageToText
from .openai_tts import OpenAITextToSpeech
from .gemini import Gemini

View File

@ -0,0 +1,19 @@
from langchain_google_genai import ChatGoogleGenerativeAI
class Gemini(ChatGoogleGenerativeAI):
    """Wrapper around ``ChatGoogleGenerativeAI`` that is configured from a dict.

    It accepts the library-wide ``model_name`` configuration key and forwards
    it to the superclass under the ``model`` keyword. It could be extended
    with additional methods if needed.
    """

    def __init__(self, llm_config: dict):
        """
        Initializes the Gemini model from a configuration dictionary.

        Args:
            llm_config (dict): Configuration parameters for the language model,
                such as ``model_name="gemini-pro"`` and ``api_key``. Every entry
                is forwarded to the superclass as a keyword argument, and the
                value of ``model_name`` is additionally passed as ``model``.
                The dictionary itself is left unmodified.

        Raises:
            KeyError: If ``llm_config`` has no ``model_name`` entry.
        """
        # Build a new dict instead of writing into the caller's one, so a
        # config object shared between several models is not altered here.
        params = {**llm_config, "model": llm_config["model_name"]}
        # Initialize the superclass (ChatGoogleGenerativeAI) with the parameters
        super().__init__(**params)

View File

@ -9,4 +9,4 @@ from .generate_answer_node import GenerateAnswerNode
from .parse_node import ParseNode
from .rag_node import RAGNode
from .text_to_speech_node import TextToSpeechNode
from .image_to_text_node import ImageToTextNode
from .image_to_text_node import ImageToTextNode

View File

@ -13,7 +13,7 @@ from langchain_core.runnables import RunnableParallel
from .base_node import BaseNode
from typing import List
class GenerateAnswerNode(BaseNode):
class GenerateAnswerNodeFromRag(BaseNode):
"""
A node that generates an answer using a language model (LLM) based on the user's input
and the content extracted from a webpage. It constructs a prompt from the user's input

View File

@ -0,0 +1,103 @@
"""
Module for generating the answer node
"""
# Imports from standard library
from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
# Imports from the library
from .base_node import BaseNode
class GenerateAnswerNodeVanilla(BaseNode):
    """
    A node that generates an answer using a language model (LLM) based on the user's input
    and the content extracted from a webpage. It constructs a single prompt from the user's
    input and the first element of the scraped document, feeds it to the LLM, and parses
    the LLM's response as JSON to produce an answer.

    Attributes:
        llm: The language model client used for generating answers.

    Args:
        llm: An instance of a language model client (e.g., ChatOpenAI or Gemini) that can
            be piped in a chain between a prompt and an output parser.
        node_name (str, optional): The unique identifier name for the node.
            Defaults to "GenerateAnswerNodeVanilla".

    Methods:
        execute(state): Processes the input and document from the state to generate an answer,
            updating the state with the generated answer under the 'answer' key.
    """

    def __init__(self, llm, node_name: str = "GenerateAnswerNodeVanilla"):
        """
        Initializes the GenerateAnswerNodeVanilla with a language model client and a node name.

        Args:
            llm: The language model client used to generate the answer.
            node_name (str, optional): Name of the node.
                Defaults to "GenerateAnswerNodeVanilla".
        """
        # "node" is forwarded to BaseNode as the node type.
        super().__init__(node_name, "node")
        self.llm = llm

    def execute(self, state: dict) -> dict:
        """
        Generates an answer by constructing a prompt from the user's input and the scraped
        content, querying the language model, and parsing its response.
        The method updates the state with the generated answer under the 'answer' key.

        Args:
            state (dict): The current state of the graph, expected to contain 'user_input'
                and 'document'. Only the first element of 'document' is used as context.

        Returns:
            dict: The same state object, updated with the 'answer' key containing the
                parsed answer.

        Raises:
            KeyError: If 'user_input' or 'document' is not found in the state, indicating
                that the necessary information for generating an answer is missing.
        """
        print("---GENERATING ANSWER---")

        try:
            user_input = state["user_input"]
            # Only the first entry of the document is sent to the model.
            context = state["document"][0]
        except KeyError as e:
            # Report the missing key, then let the caller handle the error.
            print(f"Error: {e} not found in state.")
            raise

        output_parser = JsonOutputParser()
        format_instructions = output_parser.get_format_instructions()

        template_json = """You are a website scraper and you have just scraped the
following content from a website.
You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n
This is the scraped text:\n
{context} \n
Question: {question}
"""

        # Single prompt: format instructions are fixed up front, while the
        # context and question are supplied when the chain is invoked.
        prompt = PromptTemplate(
            template=template_json,
            input_variables=["context", "question"],
            partial_variables={"format_instructions": format_instructions},
        )

        # prompt -> LLM -> JSON parser
        chain = prompt | self.llm | output_parser
        answer = chain.invoke({"context": context, "question": user_input})

        # Update the state with the generated answer
        state.update({"answer": answer})
        return state

View File

@ -6,7 +6,7 @@ import unittest
from unittest.mock import patch
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchTextNode, ParseNode, RAGNode, GenerateAnswerNode
from scrapegraphai.nodes import FetchTextNode, ParseNode, RAGNode, GenerateAnswerNodeFromRag
class TestCustomGraph(unittest.TestCase):
@ -59,7 +59,7 @@ class TestCustomGraph(unittest.TestCase):
parse_document_node = ParseNode(
doc_type="text", chunks_size=20, node_name="parse_document")
rag_node = RAGNode(model, "rag")
generate_answer_node = GenerateAnswerNode(model, "generate_answer")
generate_answer_node = GenerateAnswerNodeFromRag(model, "generate_answer")
graph = BaseGraph(
nodes={