mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
fix(pdf_scraper): fix the pdf scraper graph
This commit is contained in:
parent
00a392bdbe
commit
d00cde6030
@ -181,6 +181,7 @@ class AbstractGraph(ABC):
|
|||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["ollama"][llm_params["model"]]
|
self.model_token = models_tokens["ollama"][llm_params["model"]]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
|
print("model not found, using default token size (8192)")
|
||||||
self.model_token = 8192
|
self.model_token = 8192
|
||||||
else:
|
else:
|
||||||
self.model_token = 8192
|
self.model_token = 8192
|
||||||
@ -191,16 +192,18 @@ class AbstractGraph(ABC):
|
|||||||
elif "hugging_face" in llm_params["model"]:
|
elif "hugging_face" in llm_params["model"]:
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
|
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
|
||||||
except KeyError as exc:
|
except KeyError:
|
||||||
raise KeyError("Model not supported") from exc
|
print("model not found, using default token size (8192)")
|
||||||
|
self.model_token = 8192
|
||||||
return HuggingFace(llm_params)
|
return HuggingFace(llm_params)
|
||||||
elif "groq" in llm_params["model"]:
|
elif "groq" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["groq"][llm_params["model"]]
|
self.model_token = models_tokens["groq"][llm_params["model"]]
|
||||||
except KeyError as exc:
|
except KeyError:
|
||||||
raise KeyError("Model not supported") from exc
|
print("model not found, using default token size (8192)")
|
||||||
|
self.model_token = 8192
|
||||||
return Groq(llm_params)
|
return Groq(llm_params)
|
||||||
elif "bedrock" in llm_params["model"]:
|
elif "bedrock" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||||
@ -208,8 +211,9 @@ class AbstractGraph(ABC):
|
|||||||
client = llm_params.get('client', None)
|
client = llm_params.get('client', None)
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["bedrock"][llm_params["model"]]
|
self.model_token = models_tokens["bedrock"][llm_params["model"]]
|
||||||
except KeyError as exc:
|
except KeyError:
|
||||||
raise KeyError("Model not supported") from exc
|
print("model not found, using default token size (8192)")
|
||||||
|
self.model_token = 8192
|
||||||
return Bedrock({
|
return Bedrock({
|
||||||
"client": client,
|
"client": client,
|
||||||
"model_id": model_id,
|
"model_id": model_id,
|
||||||
@ -218,13 +222,18 @@ class AbstractGraph(ABC):
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
elif "claude-3-" in llm_params["model"]:
|
elif "claude-3-" in llm_params["model"]:
|
||||||
self.model_token = models_tokens["claude"]["claude3"]
|
try:
|
||||||
|
self.model_token = models_tokens["claude"]["claude3"]
|
||||||
|
except KeyError:
|
||||||
|
print("model not found, using default token size (8192)")
|
||||||
|
self.model_token = 8192
|
||||||
return Anthropic(llm_params)
|
return Anthropic(llm_params)
|
||||||
elif "deepseek" in llm_params["model"]:
|
elif "deepseek" in llm_params["model"]:
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
||||||
except KeyError as exc:
|
except KeyError:
|
||||||
raise KeyError("Model not supported") from exc
|
print("model not found, using default token size (8192)")
|
||||||
|
self.model_token = 8192
|
||||||
return DeepSeek(llm_params)
|
return DeepSeek(llm_params)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -312,10 +321,7 @@ class AbstractGraph(ABC):
|
|||||||
models_tokens["bedrock"][embedder_config["model"]]
|
models_tokens["bedrock"][embedder_config["model"]]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise KeyError("Model not supported") from exc
|
raise KeyError("Model not supported") from exc
|
||||||
return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
|
return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Model provided by the configuration not supported")
|
|
||||||
|
|
||||||
def get_state(self, key=None) -> dict:
|
def get_state(self, key=None) -> dict:
|
||||||
"""""
|
"""""
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from ..nodes import (
|
|||||||
FetchNode,
|
FetchNode,
|
||||||
ParseNode,
|
ParseNode,
|
||||||
RAGNode,
|
RAGNode,
|
||||||
GenerateAnswerNode
|
GenerateAnswerPDFNode
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
|
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
|
||||||
super().__init__(prompt, config, source, schema)
|
super().__init__(prompt, config, source)
|
||||||
|
|
||||||
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
|
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
|
||||||
|
|
||||||
@ -64,41 +64,21 @@ class PDFScraperGraph(AbstractGraph):
|
|||||||
input='pdf | pdf_dir',
|
input='pdf | pdf_dir',
|
||||||
output=["doc", "link_urls", "img_urls"],
|
output=["doc", "link_urls", "img_urls"],
|
||||||
)
|
)
|
||||||
parse_node = ParseNode(
|
generate_answer_node_pdf = GenerateAnswerPDFNode(
|
||||||
input="doc",
|
|
||||||
output=["parsed_doc"],
|
|
||||||
node_config={
|
|
||||||
"chunk_size": self.model_token,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
rag_node = RAGNode(
|
|
||||||
input="user_prompt & (parsed_doc | doc)",
|
|
||||||
output=["relevant_chunks"],
|
|
||||||
node_config={
|
|
||||||
"llm_model": self.llm_model,
|
|
||||||
"embedder_model": self.embedder_model,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
generate_answer_node = GenerateAnswerNode(
|
|
||||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||||
output=["answer"],
|
output=["answer"],
|
||||||
node_config={
|
node_config={
|
||||||
"llm_model": self.llm_model,
|
"llm_model": self.llm_model,
|
||||||
"schema": self.schema,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return BaseGraph(
|
return BaseGraph(
|
||||||
nodes=[
|
nodes=[
|
||||||
fetch_node,
|
fetch_node,
|
||||||
parse_node,
|
generate_answer_node_pdf,
|
||||||
rag_node,
|
|
||||||
generate_answer_node,
|
|
||||||
],
|
],
|
||||||
edges=[
|
edges=[
|
||||||
(fetch_node, parse_node),
|
(fetch_node, generate_answer_node_pdf)
|
||||||
(parse_node, rag_node),
|
|
||||||
(rag_node, generate_answer_node)
|
|
||||||
],
|
],
|
||||||
entry_point=fetch_node
|
entry_point=fetch_node
|
||||||
)
|
)
|
||||||
@ -114,4 +94,4 @@ class PDFScraperGraph(AbstractGraph):
|
|||||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||||
|
|
||||||
return self.final_state.get("answer", "No answer found.")
|
return self.final_state.get("answer", "No answer found.")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user